Cache RoPE freqs on device to avoid repeated CPU-GPU copy in QwenImage#13406
Open
akshan-main wants to merge 1 commit intohuggingface:mainfrom
Open
Cache RoPE freqs on device to avoid repeated CPU-GPU copy in QwenImage#13406akshan-main wants to merge 1 commit intohuggingface:mainfrom
akshan-main wants to merge 1 commit intohuggingface:mainfrom
Conversation
Author
|
The profiling was done with 2 steps, but this sync happens every transformer forward call, so at 20 inference steps, this eliminates ~1.5s of CPU-GPU sync overhead per run. Under torch.compile the impact is larger since GPU queues are deeper and each sync stalls longer (80ms vs 76ms in eager). |
Author
|
oh and this fix applies to all QwenImage variants (Edit, EditPlus, Layered) since they share the same transformer |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Part of #13401
QwenEmbedRope.forward()copiespos_freqsandneg_freqsfrom CPU to GPU via.to(device)on every transformer forward call. These tensors are fixed at init and never change, so the repeated transfer triggers an unnecessarycudaStreamSynchronize(~76ms each).Added
_get_device_freqs()that caches the GPU copy on first call. Applied to bothQwenEmbedRopeandQwenEmbedLayer3DRope.(
register_buffercan't be used here because it drops the imaginary part of complex tensors)Profiling (A100 80GB, eager, 2 steps, 1024x1024)
Before (76ms cudaStreamSynchronize inside transformer_forward):
After (no sync gap):
Profiled with the tooling from #13356. Reproduction notebook.
Part of #13401
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
@sayakpaul @dg845