
[compiler] Fix MXFP4 scale-read coalescing and 8-wave pingpong lowering#1069

Merged
harsh-nod merged 3 commits into iree-org:main from harsh-nod:coalesce
Mar 7, 2026
Conversation

@harsh-nod
Collaborator

[compiler] Coalesce B-scale reads for dynamic-dim MXFP4 preshuffle kernels

When M, N, K are dynamic, the read coalescer fails to merge per-thread B-scale byte reads into contiguous vector<16xi8> loads due to inconsistent floordiv/Mod evaluation at probe points that don't respect divisibility constraints.

Changes:

  • Apply divisibility forward subs (e.g. K -> 256*K') before numeric probing so floordiv/Mod evaluate consistently across all probe sets.

  • Replace symbolic diff approach with numeric probing for pairwise merge to avoid symbolic explosion on complex preshuffle index expressions. Uses parity-diverse probe generators (5 probes with three different parity patterns) to catch Mod wrapping bugs that uniform-parity probes miss.

  • Allow re-merging reads with precomputed masks across merge levels.

  • Pre-compute bounds checks as flat sympy booleans before merging.

  • Support partial iter_arg groups (>=2 offsets) in opsel_scaled_mfma coalescing, and look through arith.select from flatten-bounds masking.

  • Fix extract_strided_slice in ASM backend to convert element offsets to physical register indices based on element type width.

  • Extract ProbeEvaluator class shared between merge functions, use ceildiv, remove dead _resolve_symbolic_diff.
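The divisibility-forward-substitution idea in the first bullet can be illustrated with a minimal sympy sketch (hypothetical helper name; the real pass operates on the coalescer's index expressions, not free-standing formulas). Substituting K -> 256*K' makes the floordiv/Mod pair collapse deterministically, so every numeric probe point evaluates it the same way:

```python
import sympy as sp

def apply_divisibility_subs(expr, divisibility):
    # For each symbol known to be divisible by d, substitute d * fresh_symbol
    # so floordiv/Mod expressions evaluate consistently at every probe point.
    subs = {}
    for sym, d in divisibility.items():
        subs[sym] = d * sp.Symbol(sym.name + "_q", integer=True, nonnegative=True)
    return expr.subs(subs)

K = sp.Symbol("K", integer=True, nonnegative=True)
# Equals K whenever 256 | K, but probing it at K values that are not
# multiples of 256 gives inconsistent floordiv/Mod splits.
expr = sp.floor(K / 256) * 256 + sp.Mod(K, 256)
resolved = apply_divisibility_subs(expr, {K: 256})
print(sp.simplify(resolved))  # 256*K_q -- the floordiv/Mod pair has collapsed
```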

@harsh-nod harsh-nod force-pushed the coalesce branch 4 times, most recently from 79a5371 to e347251 Compare March 7, 2026 03:42
When M, N, K are dynamic, the read coalescer can fail to merge
per-thread B-scale byte reads into contiguous vector loads because
floordiv/Mod expressions are evaluated at probe points that do not
respect divisibility constraints. The wider merged scale packets also
exposed a pingpong scheduling bug where scale bytes from different
K-halves were carried together and later regrouped incorrectly.

Changes:

- Apply divisibility forward subs before numeric probing so floordiv
  and Mod expressions evaluate consistently across probe sets.

- Rework pairwise merge validation around probe-based flat-offset and
  per-dimension delta checks, and factor reusable symbolic probe
  helpers into symbol_utils.

- Preserve uniform transformed bounds masks and scalar dense-window
  masks so bounded mapped reads can still be re-merged across merge
  levels.

- Flatten bounds checks into precomputed sympy booleans before
  emitting merged reads.

- Partition the pingpong scale read and bitcast path by K so loop-
  carried scale packets stay coherent across the two K-halves.

- Update the related wave codegen, opsel, and ASM handling to match
  the wider merged scale packets.

- Add and refresh lit and unit coverage for dynamic probing and MXFP4
  scale-read merging behavior.

Co-authored-by: xintin <gaurav.verma@amd.com>
Co-authored-by: Hardcode84 <ivan.butygin@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Harsh Menon <harsh.menon@amd.com>
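The probe-based flat-offset delta check described above can be sketched roughly as follows (a simplified model; `PROBE_SEEDS` and `probe_delta` are hypothetical stand-ins, not the actual helpers factored into symbol_utils). Two reads are merge candidates only if their flat-offset delta is the same constant under every probe; parity-diverse probe values straddle Mod boundaries so wrapping cases are refuted:

```python
import sympy as sp

# Hypothetical parity-diverse probe values: mixing even/odd assignments
# catches Mod wrapping that uniform-parity probes would miss.
PROBE_SEEDS = [3, 4, 5, 8, 9, 16]

def probe_delta(expr_a, expr_b, sym):
    """Return the constant flat-offset delta between two index expressions
    under all probes, or None if the delta varies (merge is illegal)."""
    deltas = {int((expr_b - expr_a).subs({sym: v})) for v in PROBE_SEEDS}
    return deltas.pop() if len(deltas) == 1 else None

t = sp.Symbol("t", integer=True, nonnegative=True)
a = sp.Mod(t, 4) * 16
print(probe_delta(a, a + 1, t))                  # 1    -> contiguous, merge OK
print(probe_delta(a, sp.Mod(t + 1, 4) * 16, t))  # None -> Mod wraps, reject
```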
@harsh-nod harsh-nod changed the title from "[compiler] Coalesce B-scale reads for dynamic-dim MXFP4 preshuffle kernels" to "[compiler] Fix MXFP4 scale-read coalescing and 8-wave pingpong lowering" Mar 7, 2026
@harsh-nod harsh-nod requested a review from Hardcode84 March 7, 2026 03:49
@Hardcode84
Contributor

So, what is the difference from the original branch?

@harsh-nod
Collaborator Author

> So, what is the difference from the original branch?

Apart from the refactoring, I had to modify the following:

  • The schedule in the ping pong kernel to match the other schedules (split bitcast along K). Without this, the opsel pass was merging wrong reads and producing wrong results.
  • In partition_strided_operators.py, there were some major changes:
      • Stronger bounds mask correctness (_should_use_transformed_index(), _can_preserve_bounds_mask(), _try_build_uniform_transformed_bounds_mask(), _dense_window_has_uniform_mask())
      • More robust probing (MERGE_PROBES increased to 6 generators)
      • Safer multi-way coalescing rules (filtering to identity-mapped or unmapped reads, adding verification via verify_per_dim_delta)

@harsh-nod
Collaborator Author

harsh-nod commented Mar 7, 2026

Also made it faster in the second commit.

The iterative pairwise doubling loop (ept=1->2->4->8->16) was the
dominant cost in merge_contiguous_reads, requiring 5 full passes with
repeated ProbeEvaluator construction, symbolic simplification, and
intermediate-read reprocessing.

Add a single-pass wide-merge algorithm that runs before the iterative
fallback and handles the bulk of the work:

- Groups reads into maximal power-of-2 dense windows along the fastest
  physical dimension using numeric probing across 6 diverse probe sets.

- Handles both identity/unmapped reads and mapped reads with uniform
  transformed bounds via a group-level N-way bounds check that replaces
  the per-pair _try_build_uniform_transformed_bounds_mask calls.

- Caches transformed-index and mask-preservation info per read to
  avoid repeated transform_index_on_mapping calls across merge attempts.

- Precomputes all probe values eagerly during ProbeEvaluator
  construction so verification is pure integer arithmetic.

- Extracts reusable symbolic probe helpers (eval_expr, get_start_expr,
  expr_is_zero_under_probes) into symbol_utils.

The iterative pairwise/multiway fallback is retained for residual reads
the fast path cannot handle (shared-memory reads from gather expansion,
sparse byte windows, mixed mapping groups).

Result: merge_contiguous_reads drops from ~68s to ~28s (2.4x) on
dynamic-dim MXFP4 preshuffle kernels. The iterative loop now completes
in one vacuous pass (~1.6s) since the fast path consumes all the
expensive mapped-bounded reads.

Signed-off-by: Harsh Menon <harsh.menon@amd.com>
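The grouping step of the single-pass fast path can be sketched on plain byte offsets (a simplified model with hypothetical names; the real pass groups read nodes by probed flat offsets along the fastest physical dimension, not raw integers). Each window must be a power-of-2 run of consecutive offsets starting at an aligned base:

```python
def group_dense_windows(offsets, max_width=16):
    """Greedily group sorted byte offsets into maximal power-of-2 dense
    windows: each window covers w consecutive offsets starting at a
    w-aligned base, mirroring the single-pass wide merge."""
    offsets = sorted(offsets)
    groups, i = [], 0
    while i < len(offsets):
        w = max_width
        while w > 1:
            window = offsets[i:i + w]
            if (len(window) == w
                    and offsets[i] % w == 0
                    and window == list(range(offsets[i], offsets[i] + w))):
                break  # widest dense, aligned window found
            w //= 2
        groups.append(offsets[i:i + w])
        i += w
    return groups

print(group_dense_windows([0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 33]))
# [[0, 1, 2, 3, 4, 5, 6, 7], [16, 17, 18, 19], [33]]
```

Residual singletons like the `[33]` group here are what the retained iterative pairwise/multiway fallback would pick up.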
@harsh-nod
Collaborator Author

I am reworking this a bit, will post here when ready.

Reorganize the merge legality checks in partition_strided_operators.py
around three named proof obligations instead of ad-hoc repair rules:

1. Address equivalence: A_r(k) = A_W(o_r + k) for all original lanes.
2. Mask equivalence: V_r(k) = V_W(o_r + k) for all original lanes.
3. Lowering support: the backend codegen path realizes V_W exactly.

Changes:

- Add _prove_mask_equivalent() that unifies _can_preserve_bounds_mask,
  _try_build_uniform_transformed_bounds_mask, and
  _try_build_group_uniform_bounds_mask into one cheapest-first strategy
  ladder called by both pairwise and fast-path merging.

- Add _check_lowering_ok() that makes hardware and backend
  realizability checks a named function.

- Simplify _do_merge() and _resolve_group_mask() to delegate to
  the unified proof functions.

- Remove the now-dead _try_build_uniform_transformed_bounds_mask and
  _try_build_group_uniform_bounds_mask functions.

- Restrict fast-path eligibility to mapped reads only, since
  identity/unmapped reads in multi-dimensional tensors with stride > 1
  have flat offsets that diverge from fastest-dim offsets.

- Guard fast path against symbolic ept values.

- Make all docstrings self-contained with inline proof obligation
  descriptions.

Signed-off-by: Harsh Menon <harsh.menon@amd.com>
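A minimal sketch of the "cheapest-first strategy ladder" idea behind _prove_mask_equivalent() (hypothetical signature, with masks modeled as sympy booleans): try free syntactic identity first, use numeric probes as a cheap refutation filter, and only then pay for a symbolic proof.

```python
import sympy as sp

def prove_mask_equivalent(mask_a, mask_b, probe_sets):
    # Step 1: syntactic identity is free.
    if mask_a == mask_b:
        return True
    # Step 2: numeric probes cheaply refute most non-equivalent pairs.
    for vals in probe_sets:
        if bool(mask_a.subs(vals)) != bool(mask_b.subs(vals)):
            return False
    # Step 3: only now pay for a full symbolic equivalence proof.
    return sp.simplify(sp.Equivalent(mask_a, mask_b)) == sp.true

t, N = sp.symbols("t N", integer=True)
probes = [{t: 3, N: 8}, {t: 8, N: 8}, {t: 9, N: 8}]
print(prove_mask_equivalent(t < N, sp.Lt(t, N), probes))  # True  (step 1)
print(prove_mask_equivalent(t < N, t <= N, probes))       # False (refuted at t == N)
```

The probe at t == N is what makes step 2 effective: boundary-straddling probe sets refute near-equivalent masks before any symbolic work is done.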
@harsh-nod
Copy link
Collaborator Author

Okay, should be good to review now.

@harsh-nod harsh-nod merged commit dcac23e into iree-org:main Mar 7, 2026
17 checks passed