From 1994ea33f86000a245c2bcf2506df797c1b43893 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Fri, 3 Apr 2026 13:47:47 -0700 Subject: [PATCH] Cache RoPE freqs on device to avoid repeated CPU-GPU copy in QwenImage --- .../transformers/transformer_qwenimage.py | 33 +++++++++++++------ 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index d88aef4dcf2a..ca7b78066887 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -222,6 +222,7 @@ def __init__(self, theta: int, axes_dim: list[int], scale_rope=False): # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART self.scale_rope = scale_rope + self._device_freq_cache: dict[torch.device, tuple[torch.Tensor, torch.Tensor]] = {} def rope_params(self, index, dim, theta=10000): """ @@ -233,6 +234,12 @@ def rope_params(self, index, dim, theta=10000): freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs + def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + """Return pos_freqs and neg_freqs on the given device, caching the transfer.""" + if device not in self._device_freq_cache: + self._device_freq_cache[device] = (self.pos_freqs.to(device), self.neg_freqs.to(device)) + return self._device_freq_cache[device] + def forward( self, video_fhw: tuple[int, int, int, list[tuple[int, int, int]]], @@ -300,8 +307,9 @@ def forward( max_vid_index = max(height, width, max_vid_index) max_txt_seq_len_int = int(max_txt_seq_len) - # Create device-specific copy for text freqs without modifying self.pos_freqs - txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] + # Use cached device-transferred freqs to avoid CPU→GPU sync every forward call + pos_freqs_device, _ = self._get_device_freqs(device) + txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @@ -311,8 +319,7 @@ def _compute_video_freqs( self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None ) -> torch.Tensor: seq_lens = frame * height * width - pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs - neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + pos_freqs, neg_freqs = self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs) freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) @@ -356,6 +363,7 @@ def __init__(self, theta: int, axes_dim: list[int], scale_rope=False): ) self.scale_rope = scale_rope + self._device_freq_cache: dict[torch.device, tuple[torch.Tensor, torch.Tensor]] = {} def rope_params(self, index, dim, theta=10000): """ @@ -367,6 +375,12 @@ def rope_params(self, index, dim, theta=10000): freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs + def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + """Return pos_freqs and neg_freqs on the given device, caching the transfer.""" + if device not in self._device_freq_cache: + self._device_freq_cache[device] = (self.pos_freqs.to(device), self.neg_freqs.to(device)) + return self._device_freq_cache[device] + def forward( self, video_fhw: tuple[int, int, int, list[tuple[int, int, int]]], @@ -421,8 +435,9 @@ def forward( max_vid_index = max(max_vid_index, layer_num) max_txt_seq_len_int = int(max_txt_seq_len) - # Create device-specific copy for text freqs without modifying self.pos_freqs - txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] + # Use cached device-transferred freqs to avoid CPU→GPU sync every forward call + pos_freqs_device, _ = self._get_device_freqs(device) + txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @@ -430,8 +445,7 @@ def forward( @lru_cache_unless_export(maxsize=None) def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None): seq_lens = frame * height * width - pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs - neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + pos_freqs, neg_freqs = self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs) freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) @@ -452,8 +466,7 @@ def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device @lru_cache_unless_export(maxsize=None) def _compute_condition_freqs(self, frame, height, width, device: torch.device = None): seq_lens = frame * height * width - pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs - neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + pos_freqs, neg_freqs = self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs) freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)