diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager.py b/checkpoint/orbax/checkpoint/checkpoint_manager.py index 6ef2de477..ade52f910 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager.py @@ -925,6 +925,7 @@ def __init__( ) self._last_save_time = None + self._wait_for_prev_save_duration = 0.0 logging.info( '[process=%s][thread=%s] CheckpointManager created, primary_host=%s,' @@ -1443,14 +1444,8 @@ def save( step_stats.time_between_consecutive_saves_sec, ) self.wait_until_finished() - step_stats.wait_for_prev_duration_secs = ( - time.time() - step_stats.wait_for_prev_start_time - ) - - jax.monitoring.record_event_duration_secs( - '/jax/checkpoint/write/wait_for_prev_duration_secs', - step_stats.wait_for_prev_duration_secs, - ) + step_stats.wait_for_prev_duration_secs = self._wait_for_prev_save_duration + self._wait_for_prev_save_duration = 0.0 if ( step_stats.wait_for_prev_duration_secs > _WAIT_FOR_PREV_SAVE_WARNING_THRESHOLD_SECS @@ -2002,23 +1997,27 @@ def wait_until_finished(self): If some checkpointers are of type :py:class:`.AsyncCheckpointer`, however, this method will wait until each of these checkpointers is finished. """ + start_time = time.time() process_index = multihost.process_index() current_thread = threading.current_thread() + finalize_thread_name = None + step = None - if self._finalize_thread.map( - lambda t: t is None or (not t.is_alive() and t.exception is None) - ): - logging.info( - '[process=%s][thread=%s][wait_until_finished] No Save Finalize' - ' thread to wait for. Returning.', - process_index, - current_thread.name, - ) - return - - step = self._finalize_thread.get_not_none().step() - finalize_thread_name = self._finalize_thread.get_not_none().name try: + if self._finalize_thread.map( + lambda t: t is None or (not t.is_alive() and t.exception is None) + ): + logging.info( + '[process=%s][thread=%s][wait_until_finished] No Save Finalize' + ' thread to wait for. Returning.', + process_index, + current_thread.name, + ) + return + + step = self._finalize_thread.get_not_none().step() + finalize_thread_name = self._finalize_thread.get_not_none().name + logging.info( '[process=%s][thread=%s][step=%s][wait_until_finished] Waiting for' ' Save Finalize thread (%s) to complete.', @@ -2051,6 +2050,14 @@ def wait_until_finished(self): ) self._checkpoints.delete_if(lambda info: info.step == step) raise + finally: + duration = time.time() - start_time + if duration > 0: + jax.monitoring.record_event_duration_secs( + '/jax/checkpoint/write/wait_for_prev_duration_secs', + duration, + ) + self._wait_for_prev_save_duration += duration def is_saving_in_progress(self) -> bool: """Returns whether a checkpoint save is in progress.""" diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager_test.py b/checkpoint/orbax/checkpoint/checkpoint_manager_test.py index a15407370..7f8978632 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager_test.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager_test.py @@ -2432,6 +2432,7 @@ def test_save_and_restore_standard_logger(self): step_statistics['checkpoint_manager_duration_secs'] ) + def test_configure_atomicity(self): """Test case.""" with CheckpointManager(