Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions src/executorlib/task_scheduler/file/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, Callable, Optional

from executorlib.standalone.command import get_cache_execute_command
from executorlib.standalone.hdf import get_cache_files, get_output
from executorlib.standalone.hdf import get_cache_files, get_output, get_queue_id
from executorlib.standalone.serialize import serialize_funct
from executorlib.task_scheduler.file.spawner_subprocess import terminate_subprocess

Expand Down Expand Up @@ -130,13 +130,12 @@ def execute_tasks_h5(
memory_dict=memory_dict,
file_name_dict=file_name_dict,
)
task_resource_dict = task_dict["resource_dict"].copy()
task_resource_dict.update(
{k: v for k, v in resource_dict.items() if k not in task_resource_dict}
task_resource_dict, cache_key, cache_directory, error_log_file = (
_get_task_input(
task_resource_dict=task_dict["resource_dict"].copy(),
resource_dict=resource_dict,
)
)
cache_key = task_resource_dict.pop("cache_key", None)
cache_directory = os.path.abspath(task_resource_dict.pop("cache_directory"))
error_log_file = task_resource_dict.pop("error_log_file", None)
task_key, data_dict = serialize_funct(
fn=task_dict["fn"],
fn_args=task_args,
Expand All @@ -152,7 +151,9 @@ def execute_tasks_h5(
file_name = os.path.join(cache_directory, task_key + "_i.h5")
if not disable_dependencies:
task_dependent_lst = [
process_dict[k] for k in future_wait_key_lst
process_dict[k]
for k in future_wait_key_lst
if k in process_dict
]
else:
if len(future_wait_key_lst) > 0:
Expand Down Expand Up @@ -181,9 +182,11 @@ def execute_tasks_h5(
backend=backend,
cache_directory=cache_directory,
)
file_name_dict[task_key] = os.path.join(
cache_directory, task_key + "_o.h5"
)
file_name = os.path.join(cache_directory, task_key + "_o.h5")
file_name_dict[task_key] = file_name
queue_id = get_queue_id(file_name=file_name)
if queue_id is not None:
process_dict[task_key] = queue_id
memory_dict[task_key] = task_dict["future"]
cache_dir_dict[task_key] = cache_directory
future_queue.task_done()
Expand Down Expand Up @@ -354,3 +357,15 @@ def _cancel_processes(
config_directory=pysqa_config_directory,
backend=backend,
)


def _get_task_input(
task_resource_dict: dict, resource_dict: dict
) -> tuple[dict, Optional[str], str, Optional[str]]:
task_resource_dict.update(
{k: v for k, v in resource_dict.items() if k not in task_resource_dict}
)
cache_key = task_resource_dict.pop("cache_key", None)
cache_directory = os.path.abspath(task_resource_dict.pop("cache_directory"))
error_log_file = task_resource_dict.pop("error_log_file", None)
return task_resource_dict, cache_key, cache_directory, error_log_file
Loading