[compiler] Fix MXFP4 scale-read coalescing and 8-wave pingpong lowering #1069
Merged
harsh-nod merged 3 commits into iree-org:main on Mar 7, 2026
Conversation
79a5371 to e347251
When M, N, K are dynamic, the read coalescer can fail to merge per-thread B-scale byte reads into contiguous vector loads because floordiv/Mod expressions are evaluated at probe points that do not respect divisibility constraints. The wider merged scale packets also exposed a pingpong scheduling bug where scale bytes from different K-halves were carried together and later regrouped incorrectly.

Changes:
- Apply divisibility forward subs before numeric probing so floordiv and Mod expressions evaluate consistently across probe sets.
- Rework pairwise merge validation around probe-based flat-offset and per-dimension delta checks, and factor reusable symbolic probe helpers into symbol_utils.
- Preserve uniform transformed bounds masks and scalar dense-window masks so bounded mapped reads can still be re-merged across merge levels.
- Flatten bounds checks into precomputed sympy booleans before emitting merged reads.
- Partition the pingpong scale read and bitcast path by K so loop-carried scale packets stay coherent across the two K-halves.
- Update the related wave codegen, opsel, and ASM handling to match the wider merged scale packets.
- Add and refresh lit and unit coverage for dynamic probing and MXFP4 scale-read merging behavior.

Co-authored-by: xintin <gaurav.verma@amd.com>
Co-authored-by: Hardcode84 <ivan.butygin@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Harsh Menon <harsh.menon@amd.com>
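The divisibility forward substitution in the first bullet can be sketched with sympy. This is a minimal illustration, not the actual wave implementation: the helper name `apply_divisibility_subs` and the assumption that K is divisible by 256 are hypothetical.

```python
import sympy

def apply_divisibility_subs(expr, divisibility):
    # Hypothetical helper: rewrite each symbol s known to be divisible
    # by d as d * s_q (s_q a fresh integer symbol) so floor/Mod
    # expressions fold deterministically instead of depending on which
    # probe point they happen to be evaluated at.
    subs = {
        sym: d * sympy.Symbol(sym.name + "_q", integer=True, nonnegative=True)
        for sym, d in divisibility.items()
    }
    return expr.subs(subs)

K = sympy.Symbol("K", integer=True, nonnegative=True)
t = sympy.Symbol("t", integer=True, nonnegative=True)  # lane id
expr = sympy.Mod(t, 4) + 4 * sympy.floor(K / 256)

rewritten = apply_divisibility_subs(expr, {K: 256})
# With K = 256*K_q the floordiv collapses to K_q exactly, so every
# probe point now evaluates the same closed form.
print(rewritten)
```

After the substitution there is no floor left to round differently at different probe points, which is why the merge decision becomes consistent across probe sets.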
Contributor
So, what is the difference from the original branch?
Collaborator (Author)
Apart from the refactoring, I had to modify the following.
Collaborator (Author)
Also made it faster in the second commit.
The iterative pairwise doubling loop (ept=1->2->4->8->16) was the dominant cost in merge_contiguous_reads, requiring 5 full passes with repeated ProbeEvaluator construction, symbolic simplification, and intermediate-read reprocessing. Add a single-pass wide-merge algorithm that runs before the iterative fallback and handles the bulk of the work:

- Groups reads into maximal power-of-2 dense windows along the fastest physical dimension using numeric probing across 6 diverse probe sets.
- Handles both identity/unmapped reads and mapped reads with uniform transformed bounds via a group-level N-way bounds check that replaces the per-pair _try_build_uniform_transformed_bounds_mask calls.
- Caches transformed-index and mask-preservation info per read to avoid repeated transform_index_on_mapping calls across merge attempts.
- Precomputes all probe values eagerly during ProbeEvaluator construction so verification is pure integer arithmetic.
- Extracts reusable symbolic probe helpers (eval_expr, get_start_expr, expr_is_zero_under_probes) into symbol_utils.

The iterative pairwise/multiway fallback is retained for residual reads the fast path cannot handle (shared-memory reads from gather expansion, sparse byte windows, mixed mapping groups).

Result: merge_contiguous_reads drops from ~68s to ~28s (2.4x) on dynamic-dim MXFP4 preshuffle kernels. The iterative loop now completes in one vacuous pass (~1.6s) since the fast path consumes all the expensive mapped-bounded reads.

Signed-off-by: Harsh Menon <harsh.menon@amd.com>
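The fast-path grouping described above can be sketched as follows. This is a simplified model operating on flat byte offsets; the real pass works on symbolic per-thread indices verified across probe sets, and the requirement that a window start be aligned to its own width is an assumption here.

```python
def widest_pow2_windows(offsets, max_width=16):
    # Sketch: split sorted candidate byte offsets into maximal
    # power-of-two dense windows in a single pass, instead of five
    # pairwise doubling passes. A window of width w is accepted when
    # its offsets are consecutive (dense) and its start is w-aligned.
    offsets = sorted(offsets)
    windows = []
    i = 0
    while i < len(offsets):
        width = 1
        # Grow the window while the next power-of-2 block stays dense
        # and the window start stays aligned to the doubled width.
        while width * 2 <= max_width and i + width * 2 <= len(offsets):
            nxt = offsets[i : i + width * 2]
            dense = all(nxt[j] == nxt[0] + j for j in range(len(nxt)))
            aligned = nxt[0] % (width * 2) == 0
            if not (dense and aligned):
                break
            width *= 2
        windows.append((offsets[i], width))  # (start offset, element count)
        i += width
    return windows

# 8 dense bytes merge to one wide window; the sparse tail stays scalar.
print(widest_pow2_windows([0, 1, 2, 3, 4, 5, 6, 7, 9, 10]))
```

Offsets the single pass leaves as width-1 windows correspond to the residual reads that fall through to the retained iterative fallback.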
Collaborator (Author)
I am reworking this a bit; will post here when ready.
Reorganize the merge legality checks in partition_strided_operators.py around three named proof obligations instead of ad-hoc repair rules:

1. Address equivalence: A_r(k) = A_W(o_r + k) for all original lanes.
2. Mask equivalence: V_r(k) = V_W(o_r + k) for all original lanes.
3. Lowering support: the backend codegen path realizes V_W exactly.

Changes:
- Add _prove_mask_equivalent() that unifies _can_preserve_bounds_mask, _try_build_uniform_transformed_bounds_mask, and _try_build_group_uniform_bounds_mask into one cheapest-first strategy ladder called by both pairwise and fast-path merging.
- Add _check_lowering_ok() that makes hardware and backend realizability checks a named function.
- Simplify _do_merge() and _resolve_group_mask() to delegate to the unified proof functions.
- Remove the now-dead _try_build_uniform_transformed_bounds_mask and _try_build_group_uniform_bounds_mask functions.
- Restrict fast-path eligibility to mapped reads only, since identity/unmapped reads in multi-dimensional tensors with stride > 1 have flat offsets that diverge from fastest-dim offsets.
- Guard fast path against symbolic ept values.
- Make all docstrings self-contained with inline proof obligation descriptions.

Signed-off-by: Harsh Menon <harsh.menon@amd.com>
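The cheapest-first strategy ladder behind _prove_mask_equivalent() can be sketched like this. The read representation and the two toy strategies are illustrative placeholders, not the actual implementation.

```python
def prove_mask_equivalent(read_a, read_b, strategies):
    # Cheapest-first ladder: try each proof strategy in order and
    # return the first merged-mask result. None means mask equivalence
    # could not be proven and the reads must stay separate.
    for strategy in strategies:
        result = strategy(read_a, read_b)
        if result is not None:
            return result
    return None

def identical_masks(a, b):
    # Cheapest strategy: the two masks are syntactically identical
    # (and present), so the wide read can reuse either one.
    return a["mask"] if a["mask"] is not None and a["mask"] == b["mask"] else None

def both_unmasked(a, b):
    # Both reads are unconditionally in-bounds, so the wide read is too.
    return True if a["mask"] is None and b["mask"] is None else None

LADDER = [identical_masks, both_unmasked]
print(prove_mask_equivalent({"mask": "m0"}, {"mask": "m0"}, LADDER))
```

The point of the ladder shape is that both the pairwise path and the fast path call one function with one ordering, so a cheaper proof is always attempted before a more expensive one.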
Collaborator (Author)
Okay, should be good to review now.
Hardcode84 approved these changes on Mar 7, 2026
…rnels
When M, N, K are dynamic, the read coalescer fails to merge per-thread B-scale byte reads into contiguous vector<16xi8> loads due to inconsistent floordiv/Mod evaluation at probe points that don't respect divisibility constraints.
Changes:
- Apply divisibility forward subs (e.g. K -> 256*K') before numeric probing so floordiv/Mod evaluate consistently across all probe sets.
- Replace the symbolic diff approach with numeric probing for pairwise merging to avoid symbolic explosion on complex preshuffle index expressions. Uses parity-diverse probe generators (5 probes with three different parity patterns) to catch Mod wrapping bugs that uniform-parity probes miss.
- Allow re-merging reads with precomputed masks across merge levels.
- Precompute bounds checks as flat sympy booleans before merging.
- Support partial iter_arg groups (>=2 offsets) in opsel_scaled_mfma coalescing, and look through arith.select from flatten-bounds masking.
- Fix extract_strided_slice in the ASM backend to convert element offsets to physical register indices based on element type width.
- Extract a ProbeEvaluator class shared between merge functions, use ceildiv, and remove the dead _resolve_symbolic_diff.
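The element-offset-to-register conversion behind the extract_strided_slice fix can be illustrated as follows. This is a sketch assuming elements pack densely into 32-bit registers; the function name is hypothetical.

```python
def element_offset_to_register_index(elem_offset, elem_bits, reg_bits=32):
    # Sketch: a 32-bit register holds reg_bits // elem_bits elements,
    # so an element offset must be scaled down by that packing factor
    # to get the physical register index the slice starts in.
    elems_per_reg = reg_bits // elem_bits
    if elem_offset % elems_per_reg != 0:
        raise ValueError("slice must start on a register boundary")
    return elem_offset // elems_per_reg

# In a 16 x i8 scale packet, element offset 8 starts in register 2
# (4 bytes per 32-bit register); for f16, offset 2 starts in register 1.
print(element_offset_to_register_index(8, 8))
print(element_offset_to_register_index(2, 16))
```

Using the raw element offset as the register index is only correct for 32-bit elements, which is the kind of mismatch the wider i8 scale packets exposed.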