
fix: FP8 fallback for AIU addons running on CPU#200

Merged
chichun-charlie-liu merged 10 commits into main from fp8_cpu
Mar 20, 2026

Conversation

@andrea-fasoli
Collaborator

@andrea-fasoli andrea-fasoli commented Mar 19, 2026

Description of the change

Starting from PyTorch 2.10, torch._scaled_mm no longer supports FP8 matmul on CPU for any quantization scheme other than per-tensor. The FP8 AIU addons currently call torch._scaled_mm (through addmm_float8_unwrapped_inference) when the model runs on CPU.

This PR implements a fallback for this scenario: we perform a mock FP8 x FP8 matmul on CPU using torch.nn.functional.linear between quantized/dequantized activations and dequantized weights. Note that we do not simply dequantize the FP8 weights; we also mock the activations as FP8, so the CPU path reproduces the rounding behavior of a real FP8 kernel.

Related issues or PRs

[internal issue]

How to verify the PR

Example of a test that should pass, run on a pod with 4 AIUs, in PF mode, in a PyTorch 2.10 environment (set env vars according to your setup; AFTU = aiu-fms-testing-utils repo):

torchrun --nproc-per-node 4 ${AFTU_PATH}/scripts/drive_paged_programs.py --model_variant ${FP8_MODEL_PATH} --max_new_tokens 128 --timing per-token --dataset_type sharegpt --dataset_path ${DATASET_PATH} --test_type metrics --program_criteria_json_path ${PROGRAMS_FILE} --programs ${SELECTED_PROGRAM} --attention_type paged_fp8 --save_validation_info_outputs --validation_info_outputs_dir ${OUTPUT_DIR} --prefill_chunk_size 1024 --cross_entropy_threshold 2.6 --failure_rate_threshold 0.1 --prioritize_large_batch_sizes --enforce_homogeneous_prompt_programs --distributed

Was the PR tested

  • I have ensured all unit tests pass

Checklist for passing CI/CD:

  • All commits are signed showing "Signed-off-by: Name <email@domain.com>" with git commit --signoff or equivalent
  • PR title and commit messages adhere to Conventional Commits
  • Contribution is formatted with pre-commit
  • Contribution passes all unit tests with tox -e unit

Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
@andrea-fasoli
Collaborator Author

@ani300 need your eyes on this

Contributor

@ani300 ani300 left a comment


lgtm! the fix makes sense

@ani300
Contributor

ani300 commented Mar 19, 2026

is it worth adding a test to check if the combination that was failing before works now and in the future?

Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
@andrea-fasoli
Collaborator Author

@ani300 I added some tests to verify the FP8 CPU support. I also fixed a bug in FP8Linear where the non-quantized activation path could lead to mismatched dtypes in the matmul.

As you are aware, there have been changes to FP8 in torchao > 0.11. The handling of scales for FP8 tensors seems to be different, which breaks the new FP8-on-CPU fallback path: dequantize() throws a shape error. For now, I raise an error if the fallback path is reached. FMS-MO enforces torchao == 0.11, although this will eventually change.
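A minimal sketch of the version guard described here, raising on the fallback path when the installed torchao is newer than 0.11. The function names and the error message are assumptions, not the merged code:

```python
def fallback_supported(torchao_version: str) -> bool:
    """Return True if the FP8 CPU fallback works with this torchao version."""
    major, minor = (int(p) for p in torchao_version.split(".")[:2])
    return (major, minor) <= (0, 11)

def check_fallback(torchao_version: str) -> None:
    """Raise if the FP8 CPU fallback cannot be used (scale handling
    changed after torchao 0.11 and dequantize() raises a shape error)."""
    if not fallback_supported(torchao_version):
        raise NotImplementedError(
            f"FP8 CPU fallback is not supported with torchao=={torchao_version}; "
            "pin torchao==0.11."
        )

print(fallback_supported("0.11.0"))  # True
print(fallback_supported("0.12.1"))  # False
```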

@andrea-fasoli
Collaborator Author

@ani300 unrelated to this PR, I noticed a suspicious assignment in fp8_shard_linear:

        for module_name, module_info in module_sharding_info.items():
            linear_mod: torch.nn.Module = module_info.linear_module
            weight_strategy = getattr(linear_mod, "linear_config")["input_activations"][
                "strategy"
            ]

I suspect we should be using "weights" instead of "input_activations" to load the weight_strategy. Do you recall if there was any specific reason for this choice?
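For reference, a hypothetical illustration of the suspected fix: read the sharding strategy from the "weights" section of linear_config rather than "input_activations". The config layout and the strategy values below are assumptions for the sake of the example:

```python
# Hypothetical linear_config layout mirroring the snippet above.
linear_config = {
    "input_activations": {"strategy": "tensor"},
    "weights": {"strategy": "channel"},
}

# Current code (suspect): picks the activation quantization strategy.
weight_strategy_current = linear_config["input_activations"]["strategy"]

# Proposed: pick the weight strategy, which is presumably what sharding
# of the weight tensor should depend on.
weight_strategy_proposed = linear_config["weights"]["strategy"]

print(weight_strategy_current, weight_strategy_proposed)  # tensor channel
```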

@ani300
Contributor

ani300 commented Mar 20, 2026

@andrea-fasoli it's been a while and I don't remember why I picked this particular field, but it probably has to do with how the FP8 checkpoint comes out of llm-compressor

@thanh-lam

Thanks @andrea-fasoli for this PR! We can test the fixes and verify this [issue](https://github.ibm.com/ai-foundation/aiu-app-sw-tracker/issues/1732).

@chichun-charlie-liu
Collaborator

lgtm, just a suggestion: maybe we can add a reminder comment in pyproject.toml on the torchao line about the version constraint for now. Next time the GitHub bot asks us to upgrade torchao, we will remember to decline the upgrade.

Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
Collaborator

@chichun-charlie-liu chichun-charlie-liu left a comment


lgtm

@chichun-charlie-liu
Collaborator

=========================== short test summary info ============================
ERROR tests/models/test_qmodelprep.py::test_bert_dynamo[qat_int8] - OSError: ...
ERROR tests/models/test_qmodelprep.py::test_bert_dynamo[ptq_int8] - OSError: ...
ERROR tests/models/test_qmodelprep.py::test_bert_dynamo_wi_qbmm[qat_int8] - O...
ERROR tests/models/test_qmodelprep.py::test_bert_dynamo_wi_qbmm[ptq_int8] - O...
============ 5261 passed, 58 skipped, 4 errors in 675.97s (0:11:15) ============

Tests failed because Hugging Face does not allow GitHub to access BERT (unclear whether it's too many accesses within a certain time window or overall; we may need to disable these tests in the future). Merging for now.

@chichun-charlie-liu chichun-charlie-liu merged commit 90fc888 into main Mar 20, 2026
11 of 14 checks passed
@chichun-charlie-liu chichun-charlie-liu deleted the fp8_cpu branch March 20, 2026 19:24
