Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
90619c5
Reapply "Attention bug fixes, tokamax splash defaulting logic (#282)"…
eltsai Dec 15, 2025
d848983
Reapply "Cross self attention switch (#251)" (#288)
eltsai Dec 15, 2025
c29fdc4
Disable unsafe rng
eltsai Dec 15, 2025
f68c7b0
Integrate tokamax ring attention as optional attention kernel for WAN…
eltsai Dec 17, 2025
8a18686
Merge branch 'main' into elisatsai_disable_unsafe_rng
eltsai Dec 29, 2025
a7fa4f0
Fixed formatting issue
eltsai Dec 30, 2025
41d9353
Updated scheduler test values
eltsai Dec 30, 2025
d128e32
Updated values based on v5p-8 tests
eltsai Dec 30, 2025
70ce989
Fixing ring attention
eltsai Jan 5, 2026
ed47e5f
moving kernel init outside the sharding map
eltsai Feb 10, 2026
65e7f93
Revert "moving kernel init outside the sharding map"
eltsai Feb 15, 2026
a0c377f
jitting and sharding vae, refactored for loops in jitted VAE, 132 sec…
eltsai Feb 23, 2026
e7cd3c4
Renaming VAE sharding axis to vae_spatial
eltsai Feb 26, 2026
c236d56
Renaming VAE sharding axis to vae_spatial
eltsai Feb 26, 2026
9bcd458
ring-attention
coolkp Mar 2, 2026
0e60bbb
Merge remote-tracking branch 'origin/kunjanp-ring-attention' into eli…
eltsai Mar 4, 2026
10f2f33
Merge remote-tracking branch 'origin/main' into elisatsai_ring_attention
eltsai Mar 4, 2026
ffd7933
fixing attention from merging main
eltsai Mar 5, 2026
62e3b06
Fix attention_flax API regression from manual edits regarding context…
eltsai Mar 5, 2026
0a7d593
Merge branch 'elisatsai_ring_attention' of https://github.com/AI-Hype…
eltsai Mar 5, 2026
115fffa
Added sharding on ROPE
eltsai Mar 10, 2026
e04e78d
cfg cache
Mar 9, 2026
5b91824
Merged CFG cache, 220 sec using tokamax_flash
eltsai Mar 11, 2026
2d4eae1
Changed profiling logic
eltsai Mar 12, 2026
438fefd
Format fix
eltsai Mar 16, 2026
dff5c30
Merge remote-tracking branch 'origin/main' into elisatsai_ring_attention
eltsai Mar 16, 2026
7293017
updated vae config logic to be the consistent, update xprof logic
eltsai Mar 19, 2026
b193301
feat: sync pyink, add splash_attention __init__, and exclude kernel t…
eltsai Mar 30, 2026
5823603
Merge origin/main into elisatsai_ring_attention
eltsai Mar 30, 2026
7375d6e
fix: reformat attention_ltx2.py jnp.clip lines to pass pyink formatter
eltsai Mar 30, 2026
768416a
Fix pylink error
eltsai Mar 30, 2026
0fa8678
fixing kernel precision
eltsai Apr 6, 2026
f7b4145
Merge origin/main: resolve VAE autoencoder and imports conflicts
eltsai Apr 6, 2026
0cc19c7
push all the changes
eltsai Apr 6, 2026
6fd09fe
downgraded pylink version
eltsai Apr 6, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/UnitTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ jobs:
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
# add_pull_ready:
# add_pull_ready:q
# if: github.ref != 'refs/heads/main'
# permissions:
# checks: read
# pull-requests: write
# needs: build
# uses: ./.github/workflows/AddLabel.yml
# uses: ./.github/workflows/AddLabel.yml
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,8 @@ To generate images, run the following command:
* For Wan2.2 T2V, use `base_wan_27b.yml`.
* For Wan2.2 I2V, use `base_wan_i2v_27b.yml`.

<<<<<<< HEAD
=======
### Caching Mechanisms

Wan 2.x pipelines support several caching strategies to accelerate inference by skipping redundant transformer forward passes. These are **mutually exclusive** — enable only one at a time.
Expand All @@ -597,6 +599,7 @@ To generate images, run the following command:
...
```

>>>>>>> origin/main
## Flux

First make sure you have permissions to access the Flux repos in Huggingface.
Expand Down
23 changes: 23 additions & 0 deletions docker_build_dependency_image.sh
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,28 @@ if [[ ${DEVICE} == "gpu" ]]; then
export BASEIMAGE=ghcr.io/nvidia/jax:base
fi
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxdiffusion_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
<<<<<<< HEAD
else
if [[ ${MODE} == "stable_stack" || ${MODE} == "jax_ai_image" ]]; then
if [[ ! -v BASEIMAGE ]]; then
echo "Erroring out because BASEIMAGE is unset, please set it!"
exit 1
fi
docker build --no-cache \
--build-arg JAX_AI_IMAGE_BASEIMAGE=${BASEIMAGE} \
--build-arg COMMIT_HASH=${COMMIT_HASH} \
--network=host \
-t ${LOCAL_IMAGE_NAME} \
-f maxdiffusion_jax_ai_image_tpu.Dockerfile .
else
docker build --no-cache \
--network=host \
--build-arg MODE=${MODE} \
--build-arg JAX_VERSION=${JAX_VERSION} \
-t ${LOCAL_IMAGE_NAME} \
-f maxdiffusion_dependencies.Dockerfile .
fi
=======
else
# Default to maxdiffusion_dependencies.Dockerfile for non-GPU builds
export BASEIMAGE=${BASEIMAGE:-python:3.12-slim-bullseye}
Expand All @@ -76,4 +98,5 @@ else
--build-arg BASEIMAGE=${BASEIMAGE} \
-t ${LOCAL_IMAGE_NAME} \
-f maxdiffusion_dependencies.Dockerfile .
>>>>>>> origin/main
fi
24 changes: 24 additions & 0 deletions maxdiffusion_dependencies.Dockerfile
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
<<<<<<< HEAD
# Use Python 3.12-slim-bullseye as the base image
FROM python:3.12-slim-bullseye
=======
# Use Python 3.12-slim-bullseye as the base image unless overridden
ARG BASEIMAGE=python:3.12-slim-bullseye
FROM $BASEIMAGE
>>>>>>> origin/main

# Environment variable for no-cache-dir and pip root user warning
ENV PIP_NO_CACHE_DIR=1
Expand All @@ -13,8 +18,13 @@ ENV CLOUD_SDK_VERSION=latest
# Set DEBIAN_FRONTEND to noninteractive to avoid frontend errors
ENV DEBIAN_FRONTEND=noninteractive

<<<<<<< HEAD
# Upgrade pip to the latest version
RUN python -m pip install --upgrade pip --no-warn-script-location
=======
# Upgrade pip to the latest version and install uv
RUN python -m pip install --upgrade pip uv --no-warn-script-location
>>>>>>> origin/main

# Install system dependencies
RUN apt-get update && apt-get install -y apt-utils git curl gnupg procps iproute2 ethtool && rm -rf /var/lib/apt/lists/*
Expand All @@ -26,12 +36,26 @@ RUN curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dea
# Install the Google Cloud SDK
RUN apt-get update && apt-get install -y google-cloud-sdk && rm -rf /var/lib/apt/lists/*

<<<<<<< HEAD
# Install cloud-accelerator-diagnostics
RUN pip install cloud-accelerator-diagnostics

# Install cloud-tpu-diagnostics
RUN pip install cloud-tpu-diagnostics

# Install gcsfs
RUN pip install gcsfs

# Install google-cloud-storage
RUN pip install google-cloud-storage
=======
# Install diagnostic and storage dependencies using uv
RUN python -m uv pip install --system \
cloud-accelerator-diagnostics \
cloud-tpu-diagnostics \
gcsfs \
google-cloud-storage
>>>>>>> origin/main

# Args
ARG MODE
Expand Down
31 changes: 31 additions & 0 deletions maxdiffusion_jax_ai_image_tpu.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
ARG JAX_AI_IMAGE_BASEIMAGE

# JAX AI Base Image
FROM $JAX_AI_IMAGE_BASEIMAGE

ARG JAX_AI_IMAGE_BASEIMAGE

ARG COMMIT_HASH

ENV COMMIT_HASH=$COMMIT_HASH

RUN mkdir -p /deps

# Set the working directory in the container
WORKDIR /deps

# Copy all files from local workspace into docker container
COPY . .

# Install Maxdiffusion Jax AI Image requirements
RUN pip install -r /deps/requirements_with_jax_ai_image.txt

# TODO: Remove the flax pin and fsspec overrides once flax stable version releases
RUN if echo "$JAX_AI_IMAGE_BASEIMAGE" | grep -q "nightly"; then \
echo "Nightly build detected: Installing specific Flax commit and fsspec." && \
pip install --upgrade --force-reinstall git+https://github.com/google/flax.git@ef78d6584623511746be4824965cdef42b464583 && \
pip install "fsspec==2025.10.0"; \
fi

# Run the script available in JAX-AI-Image base image to generate the manifest file
RUN bash /jax-ai-image/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH
41 changes: 41 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
--extra-index-url https://download.pytorch.org/whl/cpu
jax>=0.7.2
jaxlib>=0.4.30
grain
google-cloud-storage>=2.17.0
absl-py
chex
datasets
flax>=0.12.0
optax>=0.2.3
torch>=2.6.0
torchvision>=0.20.1
ftfy
tensorboard>=2.17.0
tensorboardx>=2.6.2.2
tensorboard-plugin-profile>=2.15.2
tokamax
Jinja2
scikit-image
parameterized
Pillow
pylint
pyink
pytest==8.2.2
tensorflow>=2.17.0
tensorflow-datasets>=4.9.6
ruff>=0.1.5,<=0.2
git+https://github.com/Lightricks/LTX-Video
git+https://github.com/zmelumian972/xla@torchax/jittable_module_callable#subdirectory=torchax
opencv-python-headless==4.10.0.84
orbax-checkpoint
tokenizers==0.21.0
huggingface_hub>=0.30.2
transformers==4.51.0
einops==0.8.0
sentencepiece
aqtp
imageio==2.37.0
imageio-ffmpeg==0.6.0
hf_transfer>=0.1.9
qwix@git+https://github.com/google/qwix.git
41 changes: 41 additions & 0 deletions requirements_with_jax_ai_image.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Requirements for Building the MaxDifussion Docker Image
# These requirements are additional to the dependencies present in the JAX AI base image.
--extra-index-url https://download.pytorch.org/whl/cpu
jax>=0.7.2
jaxlib>=0.4.30
grain
google-cloud-storage>=2.17.0
absl-py
chex
datasets
flax>=0.12.0
optax>=0.2.3
torch>=2.6.0
torchvision>=0.20.1
ftfy
tensorboard>=2.17.0
tensorboardx>=2.6.2.2
tensorboard-plugin-profile>=2.15.2
Jinja2
scikit-image
parameterized
Pillow
pylint
pyink
pytest==8.2.2
tensorflow>=2.17.0
tensorflow-datasets>=4.9.6
ruff>=0.1.5,<=0.2
opencv-python-headless==4.10.0.84
orbax-checkpoint
tokenizers==0.21.0
huggingface_hub>=0.30.2
transformers==4.51.0
tokamax
einops==0.8.0
sentencepiece
aqtp
imageio==2.37.0
imageio-ffmpeg==0.6.0
hf_transfer>=0.1.9
qwix@git+https://github.com/google/qwix.git
20 changes: 20 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[isort]
default_section = FIRSTPARTY
ensure_newline_before_comments = True
force_grid_wrap = 0
include_trailing_comma = True
known_first_party = accelerate
known_third_party =
numpy
torch
torch_xla

line_length = 119
lines_after_imports = 2
multi_line_output = 3
use_parentheses = True

[flake8]
ignore = E203, E722, E501, E741, W503, W605
max-line-length = 119
per-file-ignores = __init__.py:F401
Loading
Loading