Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions src/diffusers/models/transformers/transformer_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):

# NOTE: Do not use register_buffer here — it would cause the complex-valued
# frequency tensors to lose their imaginary part.
self.scale_rope = scale_rope
self._device_freq_cache: dict[torch.device, tuple[torch.Tensor, torch.Tensor]] = {}

def rope_params(self, index, dim, theta=10000):
"""
Expand All @@ -233,6 +234,12 @@ def rope_params(self, index, dim, theta=10000):
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs

def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
"""Return pos_freqs and neg_freqs on the given device, caching the transfer."""
if device not in self._device_freq_cache:
self._device_freq_cache[device] = (self.pos_freqs.to(device), self.neg_freqs.to(device))
return self._device_freq_cache[device]

def forward(
self,
video_fhw: tuple[int, int, int, list[tuple[int, int, int]]],
Expand Down Expand Up @@ -300,8 +307,9 @@ def forward(
max_vid_index = max(height, width, max_vid_index)

max_txt_seq_len_int = int(max_txt_seq_len)
# Create device-specific copy for text freqs without modifying self.pos_freqs
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
# Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
pos_freqs_device, _ = self._get_device_freqs(device)
txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)

return vid_freqs, txt_freqs
Expand All @@ -311,8 +319,7 @@ def _compute_video_freqs(
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
) -> torch.Tensor:
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
pos_freqs, neg_freqs = self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)

freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
Expand Down Expand Up @@ -356,6 +363,7 @@ def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
)

self.scale_rope = scale_rope
self._device_freq_cache: dict[torch.device, tuple[torch.Tensor, torch.Tensor]] = {}

def rope_params(self, index, dim, theta=10000):
"""
Expand All @@ -367,6 +375,12 @@ def rope_params(self, index, dim, theta=10000):
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs

def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
"""Return pos_freqs and neg_freqs on the given device, caching the transfer."""
if device not in self._device_freq_cache:
self._device_freq_cache[device] = (self.pos_freqs.to(device), self.neg_freqs.to(device))
return self._device_freq_cache[device]

def forward(
self,
video_fhw: tuple[int, int, int, list[tuple[int, int, int]]],
Expand Down Expand Up @@ -421,17 +435,17 @@ def forward(

max_vid_index = max(max_vid_index, layer_num)
max_txt_seq_len_int = int(max_txt_seq_len)
# Create device-specific copy for text freqs without modifying self.pos_freqs
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
# Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
pos_freqs_device, _ = self._get_device_freqs(device)
txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)

return vid_freqs, txt_freqs

@lru_cache_unless_export(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
pos_freqs, neg_freqs = self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)

freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
Expand All @@ -452,8 +466,7 @@ def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device
@lru_cache_unless_export(maxsize=None)
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
pos_freqs, neg_freqs = self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)

freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
Expand Down