diff --git a/src/executorlib/task_scheduler/file/shared.py b/src/executorlib/task_scheduler/file/shared.py index 7c3183b9..491f34c9 100644 --- a/src/executorlib/task_scheduler/file/shared.py +++ b/src/executorlib/task_scheduler/file/shared.py @@ -5,7 +5,7 @@ from typing import Any, Callable, Optional from executorlib.standalone.command import get_cache_execute_command -from executorlib.standalone.hdf import get_cache_files, get_output +from executorlib.standalone.hdf import get_cache_files, get_output, get_queue_id from executorlib.standalone.serialize import serialize_funct from executorlib.task_scheduler.file.spawner_subprocess import terminate_subprocess @@ -130,13 +130,12 @@ def execute_tasks_h5( memory_dict=memory_dict, file_name_dict=file_name_dict, ) - task_resource_dict = task_dict["resource_dict"].copy() - task_resource_dict.update( - {k: v for k, v in resource_dict.items() if k not in task_resource_dict} + task_resource_dict, cache_key, cache_directory, error_log_file = ( + _get_task_input( + task_resource_dict=task_dict["resource_dict"].copy(), + resource_dict=resource_dict, + ) ) - cache_key = task_resource_dict.pop("cache_key", None) - cache_directory = os.path.abspath(task_resource_dict.pop("cache_directory")) - error_log_file = task_resource_dict.pop("error_log_file", None) task_key, data_dict = serialize_funct( fn=task_dict["fn"], fn_args=task_args, @@ -152,7 +151,9 @@ def execute_tasks_h5( file_name = os.path.join(cache_directory, task_key + "_i.h5") if not disable_dependencies: task_dependent_lst = [ - process_dict[k] for k in future_wait_key_lst + process_dict[k] + for k in future_wait_key_lst + if k in process_dict ] else: if len(future_wait_key_lst) > 0: @@ -181,9 +182,11 @@ def execute_tasks_h5( backend=backend, cache_directory=cache_directory, ) - file_name_dict[task_key] = os.path.join( - cache_directory, task_key + "_o.h5" - ) + file_name = os.path.join(cache_directory, task_key + "_o.h5") + file_name_dict[task_key] = file_name + queue_id = get_queue_id(file_name=file_name) + if queue_id is not None: + process_dict[task_key] = queue_id memory_dict[task_key] = task_dict["future"] cache_dir_dict[task_key] = cache_directory future_queue.task_done() @@ -354,3 +357,15 @@ def _cancel_processes( config_directory=pysqa_config_directory, backend=backend, ) + + +def _get_task_input( + task_resource_dict: dict, resource_dict: dict +) -> tuple[dict, Optional[str], str, Optional[str]]: + task_resource_dict.update( + {k: v for k, v in resource_dict.items() if k not in task_resource_dict} + ) + cache_key = task_resource_dict.pop("cache_key", None) + cache_directory = os.path.abspath(task_resource_dict.pop("cache_directory")) + error_log_file = task_resource_dict.pop("error_log_file", None) + return task_resource_dict, cache_key, cache_directory, error_log_file