Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 28 additions & 21 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,7 @@ def __init__(
)

self._last_save_time = None
self._wait_for_prev_save_duration = 0.0

logging.info(
'[process=%s][thread=%s] CheckpointManager created, primary_host=%s,'
Expand Down Expand Up @@ -1443,14 +1444,8 @@ def save(
step_stats.time_between_consecutive_saves_sec,
)
self.wait_until_finished()
step_stats.wait_for_prev_duration_secs = (
time.time() - step_stats.wait_for_prev_start_time
)

jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/wait_for_prev_duration_secs',
step_stats.wait_for_prev_duration_secs,
)
step_stats.wait_for_prev_duration_secs = self._wait_for_prev_save_duration
self._wait_for_prev_save_duration = 0.0
if (
step_stats.wait_for_prev_duration_secs
> _WAIT_FOR_PREV_SAVE_WARNING_THRESHOLD_SECS
Expand Down Expand Up @@ -2002,23 +1997,27 @@ def wait_until_finished(self):
If some checkpointers are of type :py:class:`.AsyncCheckpointer`, however,
this method will wait until each of these checkpointers is finished.
"""
start_time = time.time()
process_index = multihost.process_index()
current_thread = threading.current_thread()
finalize_thread_name = None
step = None

if self._finalize_thread.map(
lambda t: t is None or (not t.is_alive() and t.exception is None)
):
logging.info(
'[process=%s][thread=%s][wait_until_finished] No Save Finalize'
' thread to wait for. Returning.',
process_index,
current_thread.name,
)
return

step = self._finalize_thread.get_not_none().step()
finalize_thread_name = self._finalize_thread.get_not_none().name
try:
if self._finalize_thread.map(
lambda t: t is None or (not t.is_alive() and t.exception is None)
):
logging.info(
'[process=%s][thread=%s][wait_until_finished] No Save Finalize'
' thread to wait for. Returning.',
process_index,
current_thread.name,
)
return

step = self._finalize_thread.get_not_none().step()
finalize_thread_name = self._finalize_thread.get_not_none().name

logging.info(
'[process=%s][thread=%s][step=%s][wait_until_finished] Waiting for'
' Save Finalize thread (%s) to complete.',
Expand Down Expand Up @@ -2051,6 +2050,14 @@ def wait_until_finished(self):
)
self._checkpoints.delete_if(lambda info: info.step == step)
raise
finally:
duration = time.time() - start_time
if duration > 0:
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/wait_for_prev_duration_secs',
duration,
)
self._wait_for_prev_save_duration += duration

def is_saving_in_progress(self) -> bool:
"""Returns whether a checkpoint save is in progress."""
Expand Down
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/checkpoint_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2432,6 +2432,7 @@ def test_save_and_restore_standard_logger(self):
step_statistics['checkpoint_manager_duration_secs']
)


def test_configure_atomicity(self):
"""Test case."""
with CheckpointManager(
Expand Down
Loading