-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Tensor scale nvfp4 #3022
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Tensor scale nvfp4 #3022
Conversation
…into tensor-scale-nvfp4
| qmode); | ||
|
|
||
| qqmm.run(encoder, out, a, b, a_scale, b_scale, alpha); | ||
| if (scalars.uses_device_pointers()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The name of that method is a little confusing. Maybe it would make more sense to call it has_values() or something?
awni
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I left one last minor comment. Otherwise this looks great, we should merge it!
Add per tensor scale for
nvfp4quantization forcudaandcpu.qqmm,quantize,dequantizeinputs optional 1Dfloat32array (global_scale) ifmode == "nvfp4".Also some files related to
qqmmwere refactored.Important details:
qqmm] currently ifglobal_scaleis provided for the first input, it must be provided for the second input as well. This is because we pass global scales as inputs inQQMatmul::eval_gpu()and we can't distinguish betweenglobal_scale_xandglobal_scale_w.alphaandbetaboth should be device or host ptrs. Therefore, ifalphais a device ptr,Tensor scale will help with small inputs:
TODO:
global_scalein metal as well but it requires changing all quantized operationsx,was well as forcotan), thereforeQQLineardoes not have global scale support. I will add it after some exploration as a separate PR.fp_qmv_implwas not updated yet to support global scale