Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 50 additions & 36 deletions datashare-python/datashare_python/cli/worker.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import asyncio
import logging
from pathlib import Path
from typing import Annotated

import typer
import yaml

from datashare_python.config import WorkerConfig
from datashare_python.constants import DEFAULT_NAMESPACE, DEFAULT_TEMPORAL_ADDRESS
from datashare_python.discovery import discover_activities, discover_workflows
from datashare_python.discovery import discover, discover_activities, discover_workflows
from datashare_python.types_ import TemporalClient
from datashare_python.worker import datashare_worker
from datashare_python.worker import bootstrap_worker, create_worker_id

from .utils import AsyncTyper

Expand All @@ -20,11 +24,13 @@

_START_WORKER_WORKFLOWS_HELP = "workflow names run by the worker (supports regexes)"
_START_WORKER_ACTIVITIES_HELP = "activity names run by the worker (supports regexes)"
_WORKER_QUEUE_HELP = "worker task queue"
_WORKER_MAX_ACTIVITIES_HELP = (
"maximum number of concurrent activities/tasks"
" concurrently run by the worker. Defaults to 1 to encourage horizontal scaling."
_START_WORKER_DEPS_HELP = "worker lifetime dependencies name in the registry"
_START_WORKER_WORKER_ID_PREFIX_HELP = "worker ID prefix"
_START_WORKER_CONFIG_PATH_HELP = (
"path to a worker config YAML file,"
" if not provided will load worker configuration from env variables"
)
_WORKER_QUEUE_HELP = "worker task queue"
_TEMPORAL_NAMESPACE_HELP = "worker temporal namespace"

_TEMPORAL_URL_HELP = "address for temporal server"
async def start(
    workflows: Annotated[list[str], typer.Option(help=_START_WORKER_WORKFLOWS_HELP)],
    activities: Annotated[list[str], typer.Option(help=_START_WORKER_ACTIVITIES_HELP)],
    queue: Annotated[str, typer.Option("--queue", "-q", help=_WORKER_QUEUE_HELP)],
    dependencies: Annotated[
        str | None, typer.Option(help=_START_WORKER_DEPS_HELP)
    ] = None,
    config_path: Annotated[
        Path | None,
        typer.Option(
            "--config-path", "--config", "-c", help=_START_WORKER_CONFIG_PATH_HELP
        ),
    ] = None,
    worker_id_prefix: Annotated[
        str | None, typer.Option(help=_START_WORKER_WORKER_ID_PREFIX_HELP)
    ] = None,
    temporal_address: Annotated[
        str, typer.Option("--temporal-address", "-a", help=_TEMPORAL_URL_HELP)
    ] = DEFAULT_TEMPORAL_ADDRESS,
    namespace: Annotated[
        str, typer.Option("--temporal-namespace", "-ns", help=_TEMPORAL_NAMESPACE_HELP)
    ] = DEFAULT_NAMESPACE,
) -> None:
    """Start a datashare worker running the discovered workflows and activities.

    The worker config is read from ``config_path`` when provided, otherwise
    built from environment variables. Raises ``ValueError`` (from ``discover``)
    when no workflow or activity matches.
    """
    if config_path is not None:
        # yaml.load also parses JSON configs. NOTE(review): this is the full
        # (unsafe) Loader — acceptable only because config files are trusted
        # operator input; consider yaml.SafeLoader.
        with config_path.open() as f:
            bootstrap_config = WorkerConfig.model_validate(
                yaml.load(f, Loader=yaml.Loader)
            )
    else:
        bootstrap_config = WorkerConfig()
    registered_wfs, registered_acts, registered_deps = discover(
        workflows, act_names=activities, deps_name=dependencies
    )
    worker_id = create_worker_id(worker_id_prefix or "worker")
    client = await TemporalClient.connect(temporal_address, namespace=namespace)
    # We are inside a coroutine, so a loop is guaranteed to be running;
    # get_event_loop() is deprecated in this situation.
    event_loop = asyncio.get_running_loop()
    async with bootstrap_worker(
        worker_id,
        activities=registered_acts,
        workflows=registered_wfs,
        dependencies=registered_deps,
        bootstrap_config=bootstrap_config,
        client=client,
        event_loop=event_loop,
        task_queue=queue,
    ) as worker:
        try:
            await worker.run()
        except Exception:  # noqa: BLE001
            await worker.shutdown()
            # Bare raise preserves the original traceback.
            raise
2 changes: 2 additions & 0 deletions datashare-python/datashare_python/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ class WorkerConfig(ICIJSettings, LogWithWorkerIDMixin, BaseModel):
# Client sub-configurations, instantiated with their library defaults.
elasticsearch: ESClientConfig = ESClientConfig()
temporal: TemporalClientConfig = TemporalClientConfig()

# NOTE(review): presumably caps the number of concurrently running I/O-bound
# activities per worker — confirm against the worker bootstrap code.
max_concurrent_io_activities: int = 5

def to_es_client(self) -> ESClient:
    """Build an Elasticsearch client authenticated with the Datashare API key."""
    api_key = self.datashare.api_key
    return self.elasticsearch.to_es_client(api_key)

Expand Down
8 changes: 8 additions & 0 deletions datashare-python/datashare_python/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from asyncio import AbstractEventLoop
from collections.abc import AsyncGenerator, Generator, Iterator, Sequence
from pathlib import Path

import aiohttp
import pytest
Expand Down Expand Up @@ -102,6 +103,13 @@ def test_worker_config() -> WorkerConfig:
)


@pytest.fixture
def test_worker_config_path(test_worker_config: WorkerConfig, tmpdir: Path) -> Path:
    """Dump the test worker config as JSON into a temp file and return its path."""
    dumped_config = test_worker_config.model_dump_json()
    target = Path(tmpdir) / "config.json"
    target.write_text(dumped_config)
    return target


@pytest.fixture(scope="session")
async def worker_lifetime_deps(
event_loop: AbstractEventLoop,
Expand Down
34 changes: 27 additions & 7 deletions datashare-python/datashare_python/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import inspect
import logging
from asyncio import AbstractEventLoop, iscoroutine
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable
from contextlib import AsyncExitStack, asynccontextmanager
from contextvars import ContextVar
from copy import deepcopy
from typing import Any

from icij_common.es import ESClient

Expand All @@ -20,7 +23,7 @@
TEMPORAL_CLIENT: ContextVar[TemporalClient] = ContextVar("temporal_client")


def set_event_loop(event_loop: AbstractEventLoop) -> None:
    """Publish *event_loop* in the EVENT_LOOP context variable for later lookup."""
    EVENT_LOOP.set(event_loop)


Expand All @@ -31,13 +34,13 @@ def lifespan_event_loop() -> AbstractEventLoop:
raise DependencyInjectionError("event loop") from e


def set_loggers(worker_config: WorkerConfig, worker_id: str) -> None:
    """Configure worker logging from *worker_config*, tagging records with *worker_id*.

    Also logs the effective configuration so it can be inspected post-mortem.
    """
    worker_config.setup_loggers(worker_id=worker_id)
    logger.info("worker loggers ready to log 💬")
    logger.info("app config: %s", worker_config.model_dump_json(indent=2))


async def set_es_client(worker_config: WorkerConfig) -> ESClient:
    """Build the Elasticsearch client from config, publish it in ES_CLIENT, return it."""
    client = worker_config.to_es_client()
    ES_CLIENT.set(client)
    return client
Expand All @@ -52,7 +55,7 @@ def lifespan_es_client() -> ESClient:


# Task client setup
async def set_task_client(worker_config: WorkerConfig) -> DatashareTaskClient:
    """Build the Datashare task client from config, publish it in TASK_CLIENT, return it."""
    task_client = worker_config.to_task_client()
    TASK_CLIENT.set(task_client)
    return task_client
Expand All @@ -67,7 +70,7 @@ def lifespan_task_client() -> DatashareTaskClient:


# Temporal client setup
async def set_temporal_client(worker_config: WorkerConfig) -> None:
    """Connect a Temporal client from config and publish it in TEMPORAL_CLIENT."""
    client = await worker_config.to_temporal_client()
    TEMPORAL_CLIENT.set(client)

Expand All @@ -86,11 +89,28 @@ async def with_dependencies(
) -> AsyncGenerator[None, None]:
async with AsyncExitStack() as stack:
for dep in dependencies:
cm = dep(**kwargs)
cm = dep(**add_missing_args(dep, kwargs))
if hasattr(cm, "__aenter__"):
await stack.enter_async_context(cm)
elif hasattr(cm, "__enter__"):
stack.enter_context(cm)
elif iscoroutine(cm):
await cm
yield


def add_missing_args(fn: Callable, args: dict[str, Any], **kwargs) -> dict[str, Any]:
    """Complete *args* with values from *kwargs* for parameters of *fn* not already set.

    Only names that appear in ``fn``'s signature, are absent from *args* and
    have a non-None value in *kwargs* are added. The caller's mapping is never
    mutated: a deep copy is returned when anything is added, otherwise *args*
    itself is returned unchanged.

    We make the choice not to raise in case of missing argument here, the error
    will be correctly raised when the function is called.
    """
    from_kwargs = {}
    sig = inspect.signature(fn)
    for param_name in sig.parameters:
        if param_name in args:
            continue
        # NOTE(review): a kwarg explicitly passed as None is deliberately not
        # forwarded — confirm no dependency relies on receiving an explicit None.
        kwargs_value = kwargs.get(param_name)
        if kwargs_value is not None:
            from_kwargs[param_name] = kwargs_value
    if from_kwargs:
        # Deep copy keeps the caller's mapping (and its values) isolated from
        # anything the callee might mutate.
        args = deepcopy(args)
        args.update(from_kwargs)
    return args
94 changes: 92 additions & 2 deletions datashare-python/datashare_python/discovery.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,69 @@
import logging
import re
from collections.abc import Callable, Iterable
from importlib.metadata import entry_points

from .types_ import ContextManagerFactory
from .utils import ActivityWithProgress

logger = logging.getLogger(__name__)

# Anything registrable as an activity: a progress-reporting wrapper,
# a plain callable, or a class.
Activity = ActivityWithProgress | Callable | type

# Entry-point name and groups used for plugin discovery.
_DEPENDENCIES = "dependencies"
_WORKFLOW_GROUPS = "datashare.workflows"
_ACTIVITIES_GROUPS = "datashare.activities"
_DEPENDENCIES_GROUPS = "datashare.dependencies"

# (name, implementation) pairs yielded by the discovery helpers.
_RegisteredWorkflow = tuple[str, type]
_RegisteredActivity = tuple[str, Activity]
_Dependencies = list[ContextManagerFactory]
# Aggregate result of discover(): workflows, activities, dependencies.
_Discovery = tuple[
    Iterable[_RegisteredWorkflow] | None,
    Iterable[_RegisteredActivity] | None,
    _Dependencies | None,
]


def discover(
    wf_names: list[str] | None, *, act_names: list[str] | None, deps_name: str | None
) -> _Discovery:
    """Discover registered workflows, activities and worker dependencies.

    Returns a ``(workflows, activities, dependencies)`` triple; each slot is
    None when the corresponding name list/name was None. Raises ValueError
    when no workflow nor activity matches.
    """
    discovered = ""
    wfs = None
    if wf_names is not None:
        # Materialize first: zip(*[]) would raise an opaque unpacking error
        # on empty discovery instead of the explicit ValueError below.
        found_wfs = list(discover_workflows(wf_names))
        wf_names, wfs = zip(*found_wfs, strict=True) if found_wfs else ((), ())
        if wf_names:
            n_wfs = len(wf_names)
            discovered += (
                f"- {n_wfs} workflow{'s' if n_wfs > 1 else ''}: {', '.join(wf_names)}"
            )
    acts = None
    if act_names is not None:
        found_acts = list(discover_activities(act_names))
        act_names, acts = zip(*found_acts, strict=True) if found_acts else ((), ())
        if act_names:
            if discovered:
                discovered += "\n"
            n_acts = len(act_names)
            discovered += (
                f"- {n_acts} activit{'ies' if n_acts > 1 else 'y'}:"
                f" {', '.join(act_names)}"
            )
    if not acts and not wfs:
        raise ValueError("Couldn't find any registered activity or workflow.")
    deps = discover_dependencies(deps_name)
    if deps:
        n_deps = len(deps)
        discovered += "\n"
        deps_names = (d.__name__ for d in deps)
        discovered += (
            f"- {n_deps} dependenc{'ies' if n_deps > 1 else 'y'}:"
            f" {', '.join(deps_names)}"
        )
    logger.info("discovered:\n%s", discovered)
    return wfs, acts, deps


def discover_workflows(names: list[str]) -> Iterable[_RegisteredWorkflow]:
pattern = None if not names else re.compile(rf"^{'|'.join(names)}$")
impls = entry_points(group=_WORKFLOW_GROUPS)
for wf_impls in impls:
Expand All @@ -24,7 +77,7 @@ def discover_workflows(names: list[str]) -> Iterable[tuple[str, type]]:
yield wf_name, wf_impl


def discover_activities(names: list[str]) -> Iterable[tuple[str, Activity]]:
def discover_activities(names: list[str]) -> Iterable[_RegisteredActivity]:
pattern = None if not names else re.compile(rf"^{'|'.join(names)}$")
impls = entry_points(group=_ACTIVITIES_GROUPS)
for act_impls in impls:
Expand All @@ -38,6 +91,43 @@ def discover_activities(names: list[str]) -> Iterable[tuple[str, Activity]]:
yield act_name, act_impl


def discover_dependencies(name: str | None) -> _Dependencies | None:
    """Resolve worker lifetime dependencies from the registry entry point.

    Returns None when no registry is installed and no *name* was requested.
    Raises LookupError for missing entries and ValueError for ambiguous or
    empty registries.
    """
    registry_eps = entry_points(name=_DEPENDENCIES, group=_DEPENDENCIES_GROUPS)
    if not registry_eps:
        # No registry installed at all: only an error if a name was requested.
        if name is None:
            return None
        available_impls = entry_points(group=_DEPENDENCIES_GROUPS)
        raise LookupError(
            f'failed to find dependency: "{name}", '
            f"available dependencies: {available_impls}"
        )
    if len(registry_eps) > 1:
        raise ValueError(
            f'found multiple dependencies for name "{name}": {registry_eps}'
        )
    deps_registry = registry_eps[_DEPENDENCIES].load()
    if name:
        try:
            return deps_registry[name]
        except KeyError as e:
            available = list(deps_registry)
            raise LookupError(
                f'failed to find dependency for name "{name}", available dependencies: '
                f"{available}"
            ) from e
    # No name given: the registry must contain exactly one entry.
    if not deps_registry:
        raise ValueError("empty dependency registry !")
    if len(deps_registry) > 1:
        available = ", ".join('"' + d + '"' for d in deps_registry)
        raise ValueError(
            f"dependency registry contains multiples entries {available},"
            f" please select one by providing a name"
        )
    return next(iter(deps_registry.values()))


def _parse_wf_name(wf_type: type) -> str:
if not isinstance(wf_type, type):
msg = (
Expand Down
Loading
Loading