Skip to content

Conversation

@nastya236
Copy link
Collaborator

@nastya236 nastya236 commented Jan 20, 2026

Add per tensor scale for nvfp4 quantization for cuda and cpu.

qqmm, quantize, dequantize inputs optional 1D float32 array (global_scale) if mode == "nvfp4".

Also some files related to qqmm were refactored.

Important details:

  • [qqmm] currently if global_scale is provided for the first input, it must be provided for the second input as well. This is because we pass global scales as inputs in QQMatmul::eval_gpu() and we can't distinguish between global_scale_x and global_scale_w.
  • alpha and beta both should be device or host ptrs. Therefore, if alpha is a device ptr,
    Tensor scale will help with small inputs:
import mlx.core as mx

x = mx.random.uniform(shape=(2, 16)) / 1e5
xq_ns, scales_ns = mx.quantize(x, mode="nvfp4")
global_scale=mx.absmax(x).astype(mx.float32)
xq_s, scales_s = mx.quantize(x, mode="nvfp4", global_scale = global_scale)

print(mx.allclose(scales_ns, mx.zeros_like(scales_ns)))
print(mx.allclose(scales_s, mx.zeros_like(scales_s)))

TODO:

  • we probably want to support global_scale in metal as well but it requires changing all quantized operations
  • it is not yet clear what might be the best strategy for scale computation during training (for x, w as well as for cotan), therefore QQLinear does not have global scale support. I will add it after some exploration as a separate PR.
  • fp_qmv_impl was not updated yet to support global scale

@nastya236 nastya236 closed this Jan 20, 2026
@nastya236 nastya236 reopened this Jan 20, 2026
@nastya236 nastya236 changed the title Tensor scale nvfp4 [WIP] Tensor scale nvfp4 Jan 20, 2026
@nastya236 nastya236 changed the title [WIP] Tensor scale nvfp4 Tensor scale nvfp4 Jan 23, 2026
@nastya236 nastya236 marked this pull request as draft February 7, 2026 01:47
@nastya236 nastya236 marked this pull request as ready for review February 7, 2026 01:47
qmode);

qqmm.run(encoder, out, a, b, a_scale, b_scale, alpha);
if (scalars.uses_device_pointers()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name of that method is a little confusing. Maybe it would make more sense to call it has_values() or something?

Copy link
Member

@awni awni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I left one last minor comment. Otherwise this looks great, we should merge it!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants