add tma transpose auto scheduler #5982

Draft

liqiangxl wants to merge 18 commits into main from llu/transpose_tma_auto2

Conversation

@liqiangxl
Collaborator

@liqiangxl liqiangxl commented Feb 19, 2026

The heuristics are a basic version; current performance numbers are in this doc.

Base automatically changed from llu/transpose_tma_auto to main February 19, 2026 17:09
@github-actions

github-actions bot commented Feb 19, 2026

Review updated until commit 0c32eb6

Description

  • Implement TMA transpose auto scheduler with full scheduling logic

  • Add TMA-specific heuristics including tile sizing and chunking parameters

  • Enable TMA transpose through new option flag with fallback to non-TMA

  • Add comprehensive test coverage for TMA transpose functionality

Changes walkthrough

Relevant files

Enhancement

  transpose_tma.cpp: Complete TMA transpose scheduler implementation
  csrc/scheduler/transpose_tma.cpp
    • Implement complete TMA transpose heuristics with tile sizing and chunking
    • Add full TMA scheduling logic including tiling, shared memory swizzling, and register access
    • Support both TMA load and optional TMA store operations
    • Include debug output and comprehensive scheduling transforms
    • +233/-4

  transpose.cpp: Add TMA transpose option check with fallback
  csrc/scheduler/transpose.cpp
    • Add conditional check for the TMA transpose option before attempting the TMA path
    • Maintain fallback to the non-TMA scheduler when TMA is not applicable
    • +5/-3

  transpose_heuristic.h: Extend TransposeParams with TMA-specific parameters
  csrc/scheduler/transpose_heuristic.h
    • Add TMA store flag and chunking parameters to TransposeParams
    • Update equality check, hashing, and debug output for the new TMA parameters
    • +19/-0

Configuration changes

  options.h: Add TMA transpose enable option
  csrc/options.h
    • Add new EnableOption::TmaTranspose enum value
    • +1/-0

  options.cpp: Register TMA transpose option
  csrc/options.cpp
    • Register the "tma_transpose" string option in the enable options map
    • +1/-0

Tests

  test_transpose.cpp: Add comprehensive TMA transpose test coverage
  tests/cpp/test_transpose.cpp
    • Add TmaTransposeTestP test fixture with a TMA transpose enable guard
    • Implement parameterized tests for various data types and dimension combinations
    • Test corner cases with different inner dimension sizes
    • +55/-0

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Potential Division by Zero

    In line 56, there's a division by max_input_dtype_size without checking if it's zero. While dtype size should never be zero, adding a defensive check would make the code more robust.

    tparams->elements_per_chunk = kBytesPerChunk / max_input_dtype_size;
    Unbounded chunks_per_thread

    The chunks_per_thread calculation on line 55 could result in values outside the expected range [1, 8]. The comment mentions this range but there's no clamping logic to enforce it.

    const int64_t target_bdimx = (n_input == 1) ? 256 : 128;
    tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;
    tparams->elements_per_chunk = kBytesPerChunk / max_input_dtype_size;
    Missing Error Handling

    The code assumes certain tensor properties (like having LoadStoreOp for cached inputs) but doesn't handle cases where these assumptions might fail, potentially leading to null pointer dereferences.

    if (auto load_op = dynamic_cast<LoadStoreOp*>(cached_input->definition())) {
      load_op->setOpType(LoadStoreOpType::CpAsyncBulkTensorTile);
      cached_input->setMemoryType(MemoryType::Shared);
      tma_load_tvs.push_back(cached_input);
    }

    @liqiangxl liqiangxl marked this pull request as ready for review February 19, 2026 17:15
    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl liqiangxl requested a review from rdspring1 February 19, 2026 17:16
    @greptile-apps
    Contributor

    greptile-apps bot commented Feb 19, 2026

    Greptile Summary

    This PR implements a TMA (Tensor Memory Accelerator) transpose auto-scheduler for GPU kernels. The implementation adds a new scheduler that uses TMA load/store operations with swizzled shared memory to optimize transpose operations.

    Key Changes:

    • Added TmaTranspose enable option to control TMA transpose scheduling
    • Implemented getTransposeHeuristics() to calculate tile sizes and thread configuration based on input data type and count
    • Implemented scheduleTranspose() with 4-step scheduling: TMA tiling, TMA store scheduling, shared memory swizzle, and register scheduling
    • Added new parameters use_tma_store, chunks_per_thread, and elements_per_chunk to TransposeParams
    • Added comprehensive parameterized tests covering Float and BFloat16 data types with various dimensions

    Issues Found:

    • Missing validation for edge cases where tile_size1 could be 0 (if max_input_dtype_size > 128), which would cause chunks_per_thread to be 0 and lead to invalid split operations

    Confidence Score: 4/5

    • This PR is generally safe to merge but has one logical issue that could cause runtime failures in edge cases
    • The implementation is well-structured with clear scheduling steps and comprehensive tests. However, there's a missing validation check that could lead to division by zero or invalid split operations if max_input_dtype_size > 128. While this is unlikely with standard data types (max is typically 16 bytes for ComplexDouble), the lack of validation is a potential bug. The fix is straightforward - add a validation check before calculating chunks_per_thread.
    • Pay close attention to csrc/scheduler/transpose_tma.cpp - ensure the tile size calculation handles all possible input data type sizes

    Important Files Changed

    Filename Overview
    csrc/scheduler/transpose.cpp Added conditional check for TmaTranspose option before invoking TMA path
    csrc/scheduler/transpose_heuristic.h Added TMA store support and chunk-based parameters to TransposeParams class
    csrc/scheduler/transpose_tma.cpp Implemented TMA transpose auto-scheduler with heuristics and scheduling logic; potential validation gaps for edge cases

    Flowchart

    %%{init: {'theme': 'neutral'}}%%
    flowchart TD
        A[TransposeScheduler::computeHeuristics] --> B{TmaTranspose enabled?}
        B -->|Yes| C[transpose::tma::getTransposeHeuristics]
        B -->|No| D[transpose::non_tma::getTransposeHeuristics]
        C --> E{Returns valid params?}
        E -->|Yes| F[Use TMA scheduler]
        E -->|No| D
        D --> G[Return non-TMA params]
        
        F --> H[scheduleTranspose with TMA]
        H --> I[Step 1: TMA Tiling<br/>Split and reorder for BIDx, tile_2, tile_1]
        I --> J[Step 2: Schedule TMA Store<br/>Set Bulk parallel on tile dims]
        J --> K[Step 3: Schedule Input Shared Memory<br/>Apply XOR swizzle for TMA load]
        K --> L[Step 4: Schedule Register TVs<br/>Split by chunks_per_thread and elements_per_chunk<br/>Set TIDx and Unroll parallelization]
        L --> M[inlineMost]
        
        C --> N[Calculate Heuristics]
        N --> O[Compute max_input_dtype_size and n_input]
        O --> P[tile_size2 = 128 / max_input_dtype_size]
        P --> Q{n_input == 1?}
        Q -->|Yes| R[tile_size1 = tile_size2 * 2<br/>target_bdimx = 256]
        Q -->|No| S[tile_size1 = tile_size2<br/>target_bdimx = 128]
        R --> T[chunks_per_thread = tile_size1 * 8 / target_bdimx]
        S --> T
        T --> U[elements_per_chunk = 16 / max_input_dtype_size]
    

    Last reviewed commit: 0c32eb6

    @greptile-apps greptile-apps bot left a comment

    6 files reviewed, 4 comments

    @greptile-apps greptile-apps bot left a comment

    6 files reviewed, 2 comments

    @greptile-apps
    Contributor

    greptile-apps bot commented Feb 20, 2026

    Additional Comments (2)

    csrc/scheduler/transpose_heuristic.h
    New fields use_tma_store, chunks_per_thread, and elements_per_chunk missing from equality check

      bool sameAs(const HeuristicParams* other_base) const override {
        auto other = dynamic_cast<const TransposeParams*>(other_base);
        if (other == nullptr) {
          return false;
        }
        bool attr_equal = other->cparams == cparams &&
            other->use_tma_load == use_tma_load &&
            other->use_tma_store == use_tma_store &&
            other->chunks_per_thread == chunks_per_thread &&
            other->elements_per_chunk == elements_per_chunk &&
            other->split_before_tiling == split_before_tiling &&
            other->dims_merged_with_1 == dims_merged_with_1 &&
            other->dims_merged_with_2 == dims_merged_with_2 &&
            other->vectorize_factor1 == vectorize_factor1 &&
            other->vectorize_factor2 == vectorize_factor2 &&
            other->tile_size1 == tile_size1 && other->tile_size2 == tile_size2;
        return attr_equal;
      }
    

    csrc/scheduler/transpose_heuristic.h
    New fields use_tma_store, chunks_per_thread, and elements_per_chunk missing from hash calculation

      size_t hash() const override {
        return c10::get_hash(
            use_tma_load,
            use_tma_store,
            chunks_per_thread,
            elements_per_chunk,
            split_before_tiling,
            dims_merged_with_1,
            dims_merged_with_2,
            vectorize_factor1,
            vectorize_factor2,
            tile_size1,
            tile_size2);
      }
    

    liqiangxl and others added 7 commits February 20, 2026 07:01
    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
    @liqiangxl
    Collaborator Author

    !test

    @greptile-apps greptile-apps bot left a comment

    6 files reviewed, 3 comments

    // the inner most dim, range is [1, 8]
    // bdimx = tile_size1 * 8 / chunks_per_thread
    const int64_t target_bdimx = (n_input == 1) ? 256 : 128;
    tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;
    potential division by zero if target_bdimx is zero (though unlikely given the constants)

    @greptile-apps greptile-apps bot left a comment

    6 files reviewed, 2 comments

    // the inner most dim, range is [1, 8]
    // bdimx = tile_size1 * 8 / chunks_per_thread
    const int64_t target_bdimx = (n_input == 1) ? 256 : 128;
    tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;
    if tile_size1 is 0 (which could happen if max_input_dtype_size > 128), chunks_per_thread will be 0, causing issues with split at line 217

    Suggested change:
    - tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;
    + NVF_CHECK(tparams->tile_size1 > 0, "tile_size1 must be positive");
    + tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;

    Comment on lines +134 to +136
    if (max_output_dims == 0 && max_input_dims == 0) {
      return;
    }
    consider adding validation or logging when both max_output_dims and max_input_dims are 0 before early return


    @liqiangxl liqiangxl removed the request for review from rdspring1 February 20, 2026 15:38
    @liqiangxl liqiangxl marked this pull request as draft February 20, 2026 15:38