Codecov Report: ✅ All modified and coverable lines are covered by tests.
vlad-karp reviewed Apr 2, 2026
src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py (review comments, since resolved)
…warding through the iterator
vlad-karp approved these changes Apr 6, 2026
Description

This PR introduces a fault-tolerant approach to saving top-k teacher logits. Previously, the top-k teacher logits were written to a single file (local or GCS) only after the teacher model had completed the configured number of steps, so if the job crashed or the script was ended abruptly, the top-k teacher logits had to be saved from scratch again. This PR adds fault tolerance by writing the data in chunks to a folder (local or GCS). You now specify a value for the command-line argument `--steps_per_file`, which saves one file with the logits of each chunk (that number of steps). For example, if you need to save the top-k teacher logits for 100 steps and `--steps_per_file=10`, this creates 10 chunk files. If the program crashes abruptly, the code looks at the output directory, checks how many chunk files were already written, and resumes saving the top-k teacher logits from where it left off. This allows for fault-tolerant data collection and is very beneficial for long-running experiments.

Tests
YAML file for testing: YAML
Ran the following command to save top-k teacher logits; you can see from the output that it saves a file every 10 steps when `--steps_per_file=10`.

Command:

```shell
python3 src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py \
  src/maxtext/configs/post_train/distillation.yml \
  --local_tmp_dir=/tmp/save_logits_dir \
  --steps_per_file=10
```

Output Logs: Logs showing successful saving of top-k teacher logits (every 10 steps in chunks)
Abruptly stopped the saving script on purpose with Ctrl+C at 140 steps, then ran the top-k logit saving script again to check whether saving resumes from the previous point. The output below includes the log line "Found existing data, resuming from step 140", which confirms that the fault tolerance works:
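The resume behaviour demonstrated above could be implemented roughly like the sketch below. The chunk-file naming scheme (`chunk_<index>.npz`) and the helper name `find_resume_step` are illustrative assumptions, not the actual implementation in this PR:

```python
import os
import re

# Hypothetical naming scheme for illustration; the real script's
# chunk-file names may differ.
_CHUNK_RE = re.compile(r"chunk_(\d+)\.npz$")

def find_resume_step(output_dir: str, steps_per_file: int) -> int:
    """Return the step to resume saving from, based on existing chunk files."""
    if not os.path.isdir(output_dir):
        return 0
    indices = sorted(
        int(m.group(1))
        for name in os.listdir(output_dir)
        if (m := _CHUNK_RE.match(name))
    )
    # Only trust a contiguous prefix of chunks (0, 1, 2, ...); a gap in
    # the indices means everything from the gap onward is rewritten.
    completed = 0
    for i in indices:
        if i == completed:
            completed += 1
        else:
            break
    return completed * steps_per_file
```

With `chunk_0.npz` and `chunk_1.npz` present and `steps_per_file=10`, this returns 20, so saving would resume at step 20.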
Next, I modified the script that verifies the correctness of the saved top-k teacher logits to take the chunked files into account. The output shows that verification succeeds with this new change and that the data is being written properly:
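The chunk-count part of such a verification could look roughly like this sketch. The file pattern `chunk_*.npz` and the function name `verify_chunk_count` are assumptions for illustration; the real verification script also checks the logit contents, not just the file count:

```python
import glob
import math
import os

def verify_chunk_count(output_dir: str, expected_steps: int,
                       steps_per_file: int) -> bool:
    """Check that enough chunk files exist to cover the expected steps.

    Assumes one chunk file per ``steps_per_file`` steps, named
    ``chunk_<index>.npz`` (illustrative; the real naming may differ).
    """
    expected_files = math.ceil(expected_steps / steps_per_file)
    found = len(glob.glob(os.path.join(output_dir, "chunk_*.npz")))
    return found == expected_files
```

For `--expected_steps=140` and `steps_per_file=10`, 14 chunk files would be expected.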
```shell
python3 src/maxtext/trainers/post_train/distillation/verify_saved_logits.py \
  --output_dir=/tmp/save_logits_dir \
  --expected_steps=140
```

Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.