diff --git a/AGENTS.md b/AGENTS.md index 37402fde..3ddbe3f4 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -77,6 +77,16 @@ current event. - `on_error_execution()` works via naming convention but **only** when a transition for `error.execution` is declared — it is NOT a generic callback. +### Thread safety + +- The sync engine is **thread-safe**: multiple threads can send events to the same SM instance + concurrently. The processing loop uses a `threading.Lock` so at most one thread executes + transitions at a time. Event queues use `PriorityQueue` (stdlib, thread-safe). +- **Do not replace `PriorityQueue`** with non-thread-safe alternatives (e.g., `collections.deque`, + plain `list`) — this would break concurrent access guarantees. +- Stress tests in `tests/test_threading.py::TestThreadSafety` exercise real contention with + barriers and multiple sender threads. Any change to queue or locking internals must pass these. + ### Invoke (`<invoke>`) - `invoke.py` — `InvokeManager` on the engine manages the lifecycle: `mark_for_invoke()`, @@ -127,6 +137,16 @@ timeout 120 uv run pytest -n 4 Tests normally run under 60s (~40s on average), so take a closer look if they run longer; it may indicate a regression. +### Debug logging + +`log_cli_level` defaults to `WARNING` in `pyproject.toml`. The engine caches a no-op +for `logger.debug` at init time — running tests with `DEBUG` would bypass this +optimization and inflate benchmark numbers. To enable debug logs for a specific run: + +```bash +uv run pytest -o log_cli_level=DEBUG tests/test_something.py +``` + When analyzing warnings or extensive output, run the tests **once**, saving the output to a file (`> /tmp/pytest-output.txt 2>&1`), then analyze the file — instead of running the suite repeatedly with different greps.
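The enqueue-then-try-lock scheme these AGENTS.md notes describe (thread-safe `PriorityQueue` plus a `threading.Lock` acquired non-blockingly) can be sketched with stdlib primitives alone. `MiniEngine` and everything in it are illustrative stand-ins, not the library's actual internals:

```python
import queue
import threading


class MiniEngine:
    """Illustrative stdlib-only model of the enqueue-then-try-lock pattern.

    "Events" here are plain (priority, payload) tuples and "processing"
    just records the payload; the real engine runs transitions instead.
    """

    def __init__(self):
        self.external_queue = queue.PriorityQueue()  # thread-safe, ordered queue
        self._processing = threading.Lock()          # at most one processing thread
        self.processed = []                          # only mutated while holding the lock

    def send(self, priority, payload):
        self.external_queue.put((priority, payload))  # step 1: always enqueue first
        # Re-check after every release: this closes the window where another
        # thread enqueues just as the current lock holder finishes draining.
        while not self.external_queue.empty():
            if not self._processing.acquire(blocking=False):
                # Another thread is processing; its drain/re-check loop will
                # pick up our event — nothing is lost.
                return
            try:
                # We won the lock: drain everything currently queued.
                while True:
                    try:
                        _, item = self.external_queue.get_nowait()
                    except queue.Empty:
                        break
                    self.processed.append(item)  # stand-in for running transitions
            finally:
                self._processing.release()


# Stress it the way the thread-safety notes describe: several sender
# threads hammering the same instance concurrently.
engine = MiniEngine()
threads = [
    threading.Thread(target=lambda t=t: [engine.send(0, (t, i)) for i in range(50)])
    for t in range(4)
]
for th in threads:
    th.start()
for th in threads:
    th.join()
assert len(engine.processed) == 200  # every event processed exactly once
```

The outer `while not empty()` re-check after each release is what makes "no event is lost" hold: without it, an event enqueued between the holder's final empty check and its release could be stranded when the enqueuing thread's non-blocking acquire fails.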
diff --git a/docs/processing_model.md b/docs/processing_model.md index 50b6a998..d8181316 100644 --- a/docs/processing_model.md +++ b/docs/processing_model.md @@ -315,3 +315,50 @@ The machine starts, enters `trying` (attempt 1), and the eventless self-transition keeps firing as long as `can_retry()` returns `True`. Once the limit is reached, the second eventless transition fires — all within a single macrostep triggered by initialization. + + +(thread-safety)= + +## Thread safety + +State machines are **thread-safe** for concurrent event sending. Multiple threads +can call `send()` or trigger events on the **same state machine instance** +simultaneously — the engine guarantees correct behavior through its internal +locking mechanism. + +### How it works + +The processing loop uses a non-blocking lock (`threading.Lock`). When a thread +sends an event: + +1. The event is placed on the **external queue** (backed by a thread-safe + `PriorityQueue` from the standard library). +2. If no other thread is currently running the processing loop, the sending + thread acquires the lock and processes all queued events. +3. If another thread is already processing, the event is simply enqueued and + will be processed by the thread that holds the lock — no event is lost. + +This means that **at most one thread executes transitions at any time**, preserving +the run-to-completion (RTC) guarantee while allowing safe concurrent access. + +### What is safe + +- **Multiple threads sending events** to the same state machine instance. +- **Reading state** (`current_state_value`, `configuration`) from any thread + while events are being processed. Note that transient `None` values may be + observed for `current_state_value` during configuration updates when using + [`atomic_configuration_update`](behaviour.md#atomic_configuration_update) `= False` + (the default on `StateChart`, SCXML-compliant). 
With `atomic_configuration_update = True` + (the default on `StateMachine`), the configuration is updated atomically at + the end of the microstep, so `None` is not observed. +- **Invoke handlers** running in background threads or thread executors + communicate with the parent machine via the thread-safe event queue. + +### What to avoid + +- **Do not share a state machine instance across threads with the async engine** + unless you ensure only one event loop drives the machine. The async engine is + designed for `asyncio` concurrency, not thread-based concurrency. +- **Callbacks execute in the processing thread**, not in the thread that sent + the event. Design callbacks accordingly (e.g., use locks if they access + shared external state). diff --git a/docs/releases/3.1.0.md b/docs/releases/3.1.0.md index 69b194e7..77c702b2 100644 --- a/docs/releases/3.1.0.md +++ b/docs/releases/3.1.0.md @@ -33,6 +33,22 @@ See {ref}`diagram:Sphinx directive` for full documentation. [#589](https://github.com/fgmacedo/python-statemachine/pull/589). +### Performance: up to 7.7x faster event processing + +The engine's hot paths have been systematically profiled and optimized, resulting in +**4.7x–7.7x faster event throughput** and **1.9x–2.6x faster setup** across all +machine types. All optimizations are internal — no public API changes. +See [#592](https://github.com/fgmacedo/python-statemachine/pull/592) for details. + + +### Thread safety documentation + +The sync engine is thread-safe: multiple threads can send events to the same state +machine instance concurrently. This is now documented in the +{ref}`processing model <thread-safety>` and verified by stress tests. +[#592](https://github.com/fgmacedo/python-statemachine/pull/592). + + ### Bugfixes in 3.1.0 - Fixes silent misuse of `Event()` with multiple positional arguments.
Passing more than one diff --git a/pyproject.toml b/pyproject.toml index 50912c67..b05ccfb3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,11 @@ markers = [ ] python_files = ["tests.py", "test_*.py", "*_tests.py"] xfail_strict = true -log_cli_level = "DEBUG" +# Log level WARNING by default; the engine caches a no-op for logger.debug at +# init time, so DEBUG here would bypass that optimization and slow benchmarks. +# To enable DEBUG logging for a specific test run: +# uv run pytest -o log_cli_level=DEBUG +log_cli_level = "WARNING" log_cli_format = "%(relativeCreated)6.0fms %(threadName)-18s %(name)-35s %(message)s" log_cli_date_format = "%H:%M:%S" asyncio_default_fixture_loop_scope = "module" @@ -131,7 +135,14 @@ disable_error_code = "annotation-unchecked" mypy_path = "$MYPY_CONFIG_FILE_DIR/tests/django_project" [[tool.mypy.overrides]] -module = ['django.*', 'pytest.*', 'pydot.*', 'sphinx_gallery.*', 'docutils.*', 'sphinx.*'] +module = [ + 'django.*', + 'pytest.*', + 'pydot.*', + 'sphinx_gallery.*', + 'docutils.*', + 'sphinx.*', +] ignore_missing_imports = true [tool.ruff] diff --git a/statemachine/configuration.py b/statemachine/configuration.py new file mode 100644 index 00000000..e5fd79ed --- /dev/null +++ b/statemachine/configuration.py @@ -0,0 +1,159 @@ +from typing import TYPE_CHECKING +from typing import Any +from typing import Dict +from typing import Mapping +from typing import MutableSet + +from .exceptions import InvalidStateValue +from .i18n import _ +from .orderedset import OrderedSet + +_SENTINEL = object() + +if TYPE_CHECKING: + from .state import State + + +class Configuration: + """Encapsulates the dual representation of the active state configuration. + + Internally, ``current_state_value`` is either a scalar (single active state) + or an ``OrderedSet`` (parallel regions). This class hides that detail behind + a uniform interface for reading, mutating, and caching the resolved + ``OrderedSet[State]``. 
+ """ + + __slots__ = ( + "_instance_states", + "_model", + "_state_field", + "_states_map", + "_cached", + "_cached_value", + ) + + def __init__( + self, + instance_states: "Mapping[str, State]", + model: Any, + state_field: str, + states_map: "Dict[Any, State]", + ): + self._instance_states = instance_states + self._model = model + self._state_field = state_field + self._states_map = states_map + self._cached: "OrderedSet[State] | None" = None + self._cached_value: Any = _SENTINEL + + # -- Raw value (persisted on the model) ------------------------------------ + + @property + def value(self) -> Any: + """The raw state value stored on the model (scalar or ``OrderedSet``).""" + return getattr(self._model, self._state_field, None) + + @value.setter + def value(self, val: Any): + self._invalidate() + if val is not None and not isinstance(val, MutableSet) and val not in self._states_map: + raise InvalidStateValue(val) + setattr(self._model, self._state_field, val) + + @property + def values(self) -> OrderedSet[Any]: + """The set of raw state values currently active.""" + v = self.value + if isinstance(v, OrderedSet): + return v + return OrderedSet([v]) + + # -- Resolved states ------------------------------------------------------- + + @property + def states(self) -> "OrderedSet[State]": + """The set of currently active :class:`State` instances (cached).""" + csv = self.value + if self._cached is not None and self._cached_value is csv: + return self._cached + if csv is None: + return OrderedSet() + + instance_states = self._instance_states + if not isinstance(csv, MutableSet): + result = OrderedSet([instance_states[self._states_map[csv].id]]) + else: + result = OrderedSet([instance_states[self._states_map[v].id] for v in csv]) + + self._cached = result + self._cached_value = csv + return result + + @states.setter + def states(self, new_configuration: "OrderedSet[State]"): + if len(new_configuration) == 0: + self.value = None + elif len(new_configuration) == 1: + 
self.value = next(iter(new_configuration)).value + else: + self.value = OrderedSet(s.value for s in new_configuration) + + # -- Incremental mutation (used by the engine) ----------------------------- + + def add(self, state: "State"): + """Add *state* to the configuration, maintaining the dual representation.""" + csv = self.value + if csv is None: + self.value = state.value + elif isinstance(csv, MutableSet): + csv.add(state.value) + self._invalidate() + else: + self.value = OrderedSet([csv, state.value]) + + def discard(self, state: "State"): + """Remove *state* from the configuration, normalizing back to scalar.""" + csv = self.value + if isinstance(csv, MutableSet): + csv.discard(state.value) + self._invalidate() + if len(csv) == 1: + self.value = next(iter(csv)) + elif len(csv) == 0: + self.value = None + elif csv == state.value: + self.value = None + + # -- Deprecated v2 compat -------------------------------------------------- + + @property + def current_state(self) -> "State | OrderedSet[State]": + """Resolve the current state with validation. + + Unlike ``states`` (which returns an empty set for ``None``), this + raises ``InvalidStateValue`` when the value is ``None`` or not + found in ``states_map`` — matching the v2 ``current_state`` contract. + """ + csv = self.value + if csv is None: + raise InvalidStateValue( + csv, + _( + "There's no current state set. In async code, " + "did you activate the initial state? 
" + "(e.g., `await sm.activate_initial_state()`)" + ), + ) + try: + config = self.states + if len(config) == 1: + return next(iter(config)) + return config + except KeyError as err: + raise InvalidStateValue(csv) from err + + # -- Internal -------------------------------------------------------------- + + def _invalidate(self): + self._cached = None + self._cached_value = _SENTINEL diff --git a/statemachine/engines/async_.py b/statemachine/engines/async_.py index 67239fd8..9e055610 100644 --- a/statemachine/engines/async_.py +++ b/statemachine/engines/async_.py @@ -1,6 +1,5 @@ import asyncio import contextvars -import logging from itertools import chain from time import time from typing import TYPE_CHECKING @@ -17,12 +16,8 @@ from .base import BaseEngine if TYPE_CHECKING: - from ..event import Event from ..transition import Transition -logger = logging.getLogger(__name__) - - # ContextVar to distinguish reentrant calls (from within callbacks) from # concurrent external calls. asyncio propagates context to child tasks # (e.g., those created by asyncio.gather in the callback system), so a @@ -109,6 +104,23 @@ async def _conditions_match(self, transition: "Transition", trigger_data: Trigge transition.cond.key, *args, on_error=on_error, **kwargs ) + async def _first_transition_that_matches( # type: ignore[override] + self, + state: State, + trigger_data: TriggerData, + predicate: Callable, + ) -> "Transition | None": + for s in chain([state], state.ancestors()): + transition: "Transition" + for transition in s.transitions: + if ( + not transition.initial + and predicate(transition, trigger_data.event) + and await self._conditions_match(transition, trigger_data) + ): + return transition + return None + async def _select_transitions( # type: ignore[override] self, trigger_data: TriggerData, predicate: Callable ) -> "OrderedSet[Transition]": @@ -116,22 +128,8 @@ async def _select_transitions( # type: ignore[override] atomic_states = (state for state in 
self.sm.configuration if state.is_atomic) - async def first_transition_that_matches( - state: State, event: "Event | None" - ) -> "Transition | None": - for s in chain([state], state.ancestors()): - transition: "Transition" - for transition in s.transitions: - if ( - not transition.initial - and predicate(transition, event) - and await self._conditions_match(transition, trigger_data) - ): - return transition - return None - for state in atomic_states: - transition = await first_transition_that_matches(state, trigger_data.event) + transition = await self._first_transition_that_matches(state, trigger_data, predicate) if transition is not None: enabled_transitions.add(transition) @@ -179,7 +177,7 @@ async def _exit_states( # type: ignore[override] args, kwargs = await self._get_args_kwargs(info.transition, trigger_data) if info.state is not None: # pragma: no branch - logger.debug("%s Exiting state: %s", self._log_id, info.state) + self._debug("%s Exiting state: %s", self._log_id, info.state) await self.sm._callbacks.async_call( info.state.exit.key, *args, on_error=on_error, **kwargs ) @@ -234,7 +232,7 @@ async def _enter_states( # noqa: C901 target=target, ) - logger.debug("%s Entering state: %s", self._log_id, target) + self._debug("%s Entering state: %s", self._log_id, target) self._add_state_to_configuration(target) on_entry_result = await self.sm._callbacks.async_call( @@ -274,7 +272,7 @@ async def _enter_states( # noqa: C901 async def microstep(self, transitions: "List[Transition]", trigger_data: TriggerData): self._microstep_count += 1 - logger.debug( + self._debug( "%s macro:%d micro:%d transitions: %s", self._log_id, self._macrostep_count, @@ -366,7 +364,7 @@ async def processing_loop( # noqa: C901 return None _ctx_token = _in_processing_loop.set(True) - logger.debug("%s Processing loop started: %s", self._log_id, self.sm.current_state_value) + self._debug("%s Processing loop started: %s", self._log_id, self.sm.current_state_value) first_result = 
self._sentinel try: took_events = True @@ -378,7 +376,7 @@ async def processing_loop( # noqa: C901 # Phase 1: eventless transitions and internal events while not macrostep_done: self._microstep_count = 0 - logger.debug( + self._debug( "%s Macrostep %d: eventless/internal queue", self._log_id, self._macrostep_count, @@ -394,7 +392,7 @@ async def processing_loop( # noqa: C901 internal_event = self.internal_queue.pop() enabled_transitions = await self.select_transitions(internal_event) if enabled_transitions: - logger.debug( + self._debug( "%s Enabled transitions: %s", self._log_id, enabled_transitions ) took_events = True @@ -412,9 +410,7 @@ async def processing_loop( # noqa: C901 await self._run_microstep(enabled_transitions, internal_event) # Phase 3: external events - logger.debug( - "%s Macrostep %d: external queue", self._log_id, self._macrostep_count - ) + self._debug("%s Macrostep %d: external queue", self._log_id, self._macrostep_count) while not self.external_queue.is_empty(): self.clear_cache() took_events = True @@ -429,7 +425,7 @@ async def processing_loop( # noqa: C901 self._macrostep_count += 1 self._microstep_count = 0 - logger.debug( + self._debug( "%s macrostep %d: event=%s", self._log_id, self._macrostep_count, @@ -453,7 +449,7 @@ async def processing_loop( # noqa: C901 event_future = external_event.future try: enabled_transitions = await self.select_transitions(external_event) - logger.debug( + self._debug( "%s Enabled transitions: %s", self._log_id, enabled_transitions ) if enabled_transitions: @@ -494,7 +490,7 @@ async def processing_loop( # noqa: C901 _in_processing_loop.reset(_ctx_token) self._processing.release() - logger.debug("%s Processing loop ended", self._log_id) + self._debug("%s Processing loop ended", self._log_id) result = first_result if first_result is not self._sentinel else None # If the caller has a future, await it (already resolved by now). 
if caller_future is not None: diff --git a/statemachine/engines/base.py b/statemachine/engines/base.py index c197f83b..360398dd 100644 --- a/statemachine/engines/base.py +++ b/statemachine/engines/base.py @@ -11,11 +11,8 @@ from typing import Dict from typing import List from typing import cast -from weakref import ReferenceType -from weakref import ref from ..event import BoundEvent -from ..event import Event from ..event_data import EventData from ..event_data import TriggerData from ..exceptions import InvalidDefinition @@ -88,7 +85,7 @@ def remove(self, send_id: str): class BaseEngine: def __init__(self, sm: "StateChart"): - self._sm: ReferenceType["StateChart"] = ref(sm) + self.sm: "StateChart" = sm self.external_queue = EventQueue() self.internal_queue = EventQueue() self._sentinel = object() @@ -99,17 +96,12 @@ def __init__(self, sm: "StateChart"): self._macrostep_count: int = 0 self._microstep_count: int = 0 self._log_id = f"[{type(sm).__name__}]" + self._debug = logger.debug if logger.isEnabledFor(logging.DEBUG) else lambda *a, **k: None self._root_parallel_final_pending: "State | None" = None def empty(self): # pragma: no cover return self.external_queue.is_empty() - @property - def sm(self) -> "StateChart": - sm = self._sm() - assert sm, "StateMachine has been destroyed" - return sm - def clear_cache(self): """Clears the cache. Should be called at the start of each processing loop.""" self._cache.clear() @@ -125,7 +117,7 @@ def put(self, trigger_data: TriggerData, internal: bool = False, _delayed: bool self.external_queue.put(trigger_data) if not _delayed: - logger.debug( + self._debug( "%s New event '%s' put on the '%s' queue", self._log_id, trigger_data.event, @@ -180,7 +172,7 @@ def _send_error_execution(self, error: Exception, trigger_data: TriggerData): If already processing an error.execution event, ignore to avoid infinite loops. 
""" - logger.debug( + self._debug( "%s Error %s captured while executing event=%s", self._log_id, error, @@ -346,6 +338,23 @@ def select_transitions(self, trigger_data: TriggerData) -> OrderedSet[Transition """ return self._select_transitions(trigger_data, lambda t, e: t.match(e)) + def _first_transition_that_matches( + self, + state: State, + trigger_data: TriggerData, + predicate: Callable, + ) -> "Transition | None": + for s in chain([state], state.ancestors()): + transition: Transition + for transition in s.transitions: + if ( + not transition.initial + and predicate(transition, trigger_data.event) + and self._conditions_match(transition, trigger_data) + ): + return transition + return None + def _select_transitions( self, trigger_data: TriggerData, predicate: Callable ) -> OrderedSet[Transition]: @@ -355,23 +364,8 @@ def _select_transitions( # Get atomic states, TODO: sorted by document order atomic_states = (state for state in self.sm.configuration if state.is_atomic) - def first_transition_that_matches( - state: State, event: "Event | None" - ) -> "Transition | None": - for s in chain([state], state.ancestors()): - transition: Transition - for transition in s.transitions: - if ( - not transition.initial - and predicate(transition, event) - and self._conditions_match(transition, trigger_data) - ): - return transition - - return None - for state in atomic_states: - transition = first_transition_that_matches(state, trigger_data.event) + transition = self._first_transition_that_matches(state, trigger_data, predicate) if transition is not None: enabled_transitions.add(transition) @@ -382,7 +376,7 @@ def microstep(self, transitions: List[Transition], trigger_data: TriggerData): This includes exiting states, executing transition content, and entering states. 
""" self._microstep_count += 1 - logger.debug( + self._debug( "%s macro:%d micro:%d transitions: %s", self._log_id, self._macrostep_count, @@ -469,7 +463,7 @@ def _prepare_exit_states( states_to_exit, key=lambda x: x.state and x.state.document_order or 0, reverse=True ) result = OrderedSet([info.state for info in ordered_states if info.state]) - logger.debug("%s States to exit: %s", self._log_id, result) + self._debug("%s States to exit: %s", self._log_id, result) # Update history for info in ordered_states: @@ -480,7 +474,7 @@ def _prepare_exit_states( else: # shallow history history_value = [s for s in self.sm.configuration if s.parent == state] - logger.debug( + self._debug( "%s Saving '%s.%s' history state: '%s'", self._log_id, state, @@ -494,7 +488,7 @@ def _prepare_exit_states( def _remove_state_from_configuration(self, state: State): """Remove a state from the configuration if not using atomic updates.""" if not self.sm.atomic_configuration_update: - self.sm.configuration -= {state} + self.sm._config.discard(state) def _exit_states( self, enabled_transitions: List[Transition], trigger_data: TriggerData @@ -512,7 +506,7 @@ def _exit_states( # Execute `onexit` handlers — same per-block error isolation as onentry. if info.state is not None: # pragma: no branch - logger.debug("%s Exiting state: %s", self._log_id, info.state) + self._debug("%s Exiting state: %s", self._log_id, info.state) self.sm._callbacks.call(info.state.exit.key, *args, on_error=on_error, **kwargs) self._remove_state_from_configuration(info.state) @@ -566,17 +560,29 @@ def _prepare_entry_states( states_targets_to_enter = OrderedSet(info.state for info in ordered_states if info.state) - new_configuration = cast( - OrderedSet[State], (previous_configuration - states_to_exit) | states_targets_to_enter + # Build new configuration in a single pass instead of two set operations + # (- and |) that each allocate an intermediate OrderedSet. 
+ new_configuration = OrderedSet( + s for s in previous_configuration if s not in states_to_exit ) - logger.debug("%s States to enter: %s", self._log_id, states_targets_to_enter) + new_configuration.update(states_targets_to_enter) + self._debug("%s States to enter: %s", self._log_id, states_targets_to_enter) return ordered_states, states_for_default_entry, default_history_content, new_configuration def _add_state_to_configuration(self, target: State): """Add a state to the configuration if not using atomic updates.""" if not self.sm.atomic_configuration_update: - self.sm.configuration |= {target} + self.sm._config.add(target) + + def stop(self): + """Stop this engine externally (e.g. when a parent cancels a child invocation).""" + self._debug("%s Stopping engine", self._log_id) + self.running = False + try: + self._invoke_manager.cancel_all() + except Exception: # pragma: no cover + self._debug("%s Error stopping engine", self._log_id, exc_info=True) def __del__(self): try: @@ -586,7 +592,7 @@ def __del__(self): def _handle_final_state(self, target: State, on_entry_result: list): """Handle final state entry: queue done events. 
No direct callback dispatch.""" - logger.debug("%s Reached final state: %s", self._log_id, target) + self._debug("%s Reached final state: %s", self._log_id, target) if target.parent is None: self._invoke_manager.cancel_all() self.running = False @@ -665,7 +671,7 @@ def _enter_states( # noqa: C901 target=target, ) - logger.debug("%s Entering state: %s", self._log_id, target) + self._debug("%s Entering state: %s", self._log_id, target) self._add_state_to_configuration(target) # Execute `onentry` handlers — each handler is a separate block per @@ -765,7 +771,7 @@ def add_descendant_states_to_enter( # noqa: C901 parent_id = state.parent and state.parent.id default_history_content[parent_id] = [info] if state.id in self.sm.history_values: - logger.debug( + self._debug( "%s History state '%s.%s' %s restoring: '%s'", self._log_id, state.parent, @@ -795,7 +801,7 @@ def add_descendant_states_to_enter( # noqa: C901 ) else: # Handle default history content - logger.debug( + self._debug( "%s History state '%s.%s' default content: %s", self._log_id, state.parent, @@ -804,7 +810,8 @@ def add_descendant_states_to_enter( # noqa: C901 ) for transition in state.transitions: - info_history = StateTransition(transition=transition, state=transition.target) + target = cast(State, transition.target) + info_history = StateTransition(transition=transition, state=target) default_history_content[parent_id].append(info_history) self.add_descendant_states_to_enter( info_history, @@ -813,7 +820,8 @@ def add_descendant_states_to_enter( # noqa: C901 default_history_content, ) # noqa: E501 for transition in state.transitions: - info_history = StateTransition(transition=transition, state=transition.target) + target = cast(State, transition.target) + info_history = StateTransition(transition=transition, state=target) self.add_ancestor_states_to_enter( info_history, diff --git a/statemachine/engines/sync.py b/statemachine/engines/sync.py index 6c856505..627b51ae 100644 --- 
a/statemachine/engines/sync.py +++ b/statemachine/engines/sync.py @@ -1,4 +1,3 @@ -import logging from time import sleep from time import time from typing import TYPE_CHECKING @@ -14,8 +13,6 @@ if TYPE_CHECKING: from ..transition import Transition -logger = logging.getLogger(__name__) - class SyncEngine(BaseEngine): def _run_microstep(self, enabled_transitions, trigger_data): @@ -77,7 +74,7 @@ def processing_loop(self, caller_future=None): # noqa: C901 # We will collect the first result as the processing result to keep backwards compatibility # so we need to use a sentinel object instead of `None` because the first result may # be also `None`, and on this case the `first_result` may be overridden by another result. - logger.debug("%s Processing loop started: %s", self._log_id, self.sm.current_state_value) + self._debug("%s Processing loop started: %s", self._log_id, self.sm.current_state_value) first_result = self._sentinel try: took_events = True @@ -92,7 +89,7 @@ def processing_loop(self, caller_future=None): # noqa: C901 # handles eventless transitions and internal events while not macrostep_done: self._microstep_count = 0 - logger.debug( + self._debug( "%s Macrostep %d: eventless/internal queue", self._log_id, self._macrostep_count, @@ -110,7 +107,7 @@ def processing_loop(self, caller_future=None): # noqa: C901 internal_event = self.internal_queue.pop() enabled_transitions = self.select_transitions(internal_event) if enabled_transitions: - logger.debug( + self._debug( "%s Enabled transitions: %s", self._log_id, enabled_transitions ) took_events = True @@ -130,9 +127,7 @@ def processing_loop(self, caller_future=None): # noqa: C901 self._run_microstep(enabled_transitions, internal_event) # Process external events - logger.debug( - "%s Macrostep %d: external queue", self._log_id, self._macrostep_count - ) + self._debug("%s Macrostep %d: external queue", self._log_id, self._macrostep_count) while not self.external_queue.is_empty(): self.clear_cache() took_events = 
True @@ -147,7 +142,7 @@ def processing_loop(self, caller_future=None): # noqa: C901 self._macrostep_count += 1 self._microstep_count = 0 - logger.debug( + self._debug( "%s macrostep %d: event=%s", self._log_id, self._macrostep_count, @@ -158,7 +153,7 @@ def processing_loop(self, caller_future=None): # noqa: C901 self._invoke_manager.handle_external_event(external_event) enabled_transitions = self.select_transitions(external_event) - logger.debug("%s Enabled transitions: %s", self._log_id, enabled_transitions) + self._debug("%s Enabled transitions: %s", self._log_id, enabled_transitions) if enabled_transitions: try: result = self.microstep(list(enabled_transitions), external_event) @@ -177,7 +172,7 @@ def processing_loop(self, caller_future=None): # noqa: C901 finally: self._processing.release() - logger.debug("%s Processing loop ended", self._log_id) + self._debug("%s Processing loop ended", self._log_id) return first_result if first_result is not self._sentinel else None def enabled_events(self, *args, **kwargs): diff --git a/statemachine/event_data.py b/statemachine/event_data.py index a54c0cc0..9eebfe41 100644 --- a/statemachine/event_data.py +++ b/statemachine/event_data.py @@ -63,8 +63,8 @@ class EventData: source: "State" = field(init=False) """The :ref:`State` which :ref:`statemachine` was in when the Event started.""" - target: "State" = field(init=False) - """The destination :ref:`State` of the :ref:`transition`.""" + target: "State | None" = field(init=False) + """The destination :ref:`State` of the :ref:`transition`, or ``None`` for targetless.""" def __post_init__(self): self.state = self.transition.source diff --git a/statemachine/invoke.py b/statemachine/invoke.py index 9c775563..5d09c14d 100644 --- a/statemachine/invoke.py +++ b/statemachine/invoke.py @@ -7,7 +7,6 @@ """ import asyncio -import logging import threading import uuid from concurrent.futures import Future @@ -33,8 +32,6 @@ from .state import State from .statemachine import StateChart 
-logger = logging.getLogger(__name__) - @runtime_checkable class IInvoke(Protocol): @@ -51,12 +48,7 @@ def _stop_child_machine(child: "StateChart | None") -> None: """Stop a child state machine and cancel all its invocations.""" if child is None: return - logger.debug("invoke: stopping child machine %s", type(child).__name__) - try: - child._engine.running = False - child._engine._invoke_manager.cancel_all() - except Exception: - logger.debug("Error stopping child machine", exc_info=True) + child._engine.stop() class _InvokeCallableWrapper: @@ -282,6 +274,14 @@ def __init__(self, engine: "BaseEngine"): self._active: Dict[str, Invocation] = {} self._pending: "List[Tuple[State, dict]]" = [] + @property + def _debug(self): + return self._engine._debug + + @property + def _log_id(self): + return self._engine._log_id + @property def sm(self) -> "StateChart": return self._engine.sm @@ -302,7 +302,7 @@ def mark_for_invoke(self, state: "State", event_kwargs: "dict | None" = None): def cancel_for_state(self, state: "State"): """Called by ``_exit_states()`` before exiting a state.""" - logger.debug("invoke cancel_for_state: %s", state.id) + self._debug("%s invoke cancel_for_state: %s", self._log_id, state.id) for inv_id, inv in list(self._active.items()): if inv.state_id == state.id and not inv.ctx.cancelled.is_set(): self._cancel(inv_id) @@ -313,7 +313,7 @@ def cancel_for_state(self, state: "State"): def cancel_all(self): """Cancel all active invocations.""" - logger.debug("invoke cancel_all: %d active", len(self._active)) + self._debug("%s invoke cancel_all: %d active", self._log_id, len(self._active)) for inv_id in list(self._active.keys()): self._cancel(inv_id) self._cleanup_terminated() @@ -362,7 +362,7 @@ def _spawn_one_sync(self, callback: "CallbackWrapper", **kwargs): invocation._handler = handler self._active[ctx.invokeid] = invocation - logger.debug("invoke spawn sync: %s on state %s", ctx.invokeid, state.id) + self._debug("%s invoke spawn sync: %s on state %s", 
self._log_id, ctx.invokeid, state.id) thread = threading.Thread( target=self._run_sync_handler, @@ -400,8 +400,11 @@ def _run_sync_handler( self.sm.send("error.execution", error=e) finally: invocation.terminated = True - logger.debug( - "invoke %s: completed (cancelled=%s)", ctx.invokeid, ctx.cancelled.is_set() + self._debug( + "%s invoke %s: completed (cancelled=%s)", + self._log_id, + ctx.invokeid, + ctx.cancelled.is_set(), ) # --- Async spawning --- @@ -431,7 +434,7 @@ def _spawn_one_async(self, callback: "CallbackWrapper", **kwargs): invocation._handler = handler self._active[ctx.invokeid] = invocation - logger.debug("invoke spawn async: %s on state %s", ctx.invokeid, state.id) + self._debug("%s invoke spawn async: %s on state %s", self._log_id, ctx.invokeid, state.id) loop = asyncio.get_running_loop() task = loop.create_task(self._run_async_handler(callback, handler, ctx, invocation)) @@ -469,8 +472,11 @@ async def _run_async_handler( await self.sm.send("error.execution", error=e) finally: invocation.terminated = True - logger.debug( - "invoke %s: completed (cancelled=%s)", ctx.invokeid, ctx.cancelled.is_set() + self._debug( + "%s invoke %s: completed (cancelled=%s)", + self._log_id, + ctx.invokeid, + ctx.cancelled.is_set(), ) # --- Cancel --- @@ -480,7 +486,7 @@ def _cancel(self, invokeid: str): if not invocation or invocation.ctx.cancelled.is_set(): return - logger.debug("invoke cancel: %s", invokeid) + self._debug("%s invoke cancel: %s", self._log_id, invokeid) # 1) Signal cancellation so the handler can check and stop early. invocation.ctx.cancelled.set() @@ -490,7 +496,7 @@ def _cancel(self, invokeid: str): try: handler.on_cancel() except Exception: - logger.debug("Error in on_cancel for %s", invokeid, exc_info=True) + self._debug("%s Error in on_cancel for %s", self._log_id, invokeid, exc_info=True) # 3) Cancel the async task (raises CancelledError at next await). 
if invocation.task is not None and not invocation.task.done(): @@ -564,7 +570,9 @@ def handle_external_event(self, trigger_data) -> None: and handler.autoforward and hasattr(handler, "on_event") ): - logger.debug("invoke autoforward: %s -> %s", event_name, inv.invokeid) + self._debug( + "%s invoke autoforward: %s -> %s", self._log_id, event_name, inv.invokeid + ) handler.on_event(event_name, **trigger_data.kwargs) def _make_context( diff --git a/statemachine/io/scxml/actions.py b/statemachine/io/scxml/actions.py index 1da3cc62..c4e46cca 100644 --- a/statemachine/io/scxml/actions.py +++ b/statemachine/io/scxml/actions.py @@ -28,6 +28,7 @@ from .schema import ScriptAction logger = logging.getLogger(__name__) +_debug = logger.debug if logger.isEnabledFor(logging.DEBUG) else lambda *a, **k: None protected_attrs = _event_data_kwargs | {"_sessionid", "_ioprocessors", "_name", "_event"} @@ -220,7 +221,7 @@ def __init__(self, cond: str, processor=None): def __call__(self, *args, **kwargs): result = _eval(self.action, **kwargs) - logger.debug("Cond %s -> %s", self.action, result) + _debug("Cond %s -> %s", self.action, result) return result @staticmethod @@ -298,7 +299,7 @@ def __call__(self, *args, **kwargs): f"{self.action.location}" ) setattr(obj, attr, value) - logger.debug(f"Assign: {self.action.location} = {value!r}") + _debug("Assign: %s = %r", self.action.location, value) class Log(CallableAction): diff --git a/statemachine/state.py b/statemachine/state.py index 32c436ff..e8aa572a 100644 --- a/statemachine/state.py +++ b/statemachine/state.py @@ -1,7 +1,6 @@ from enum import Enum from typing import TYPE_CHECKING from typing import Any -from typing import Dict from typing import Generator from typing import List from typing import cast @@ -12,7 +11,6 @@ from .callbacks import CallbackSpecList from .event import _expand_event_id from .exceptions import InvalidDefinition -from .exceptions import StateMachineError from .i18n import _ from .invoke import 
normalize_invoke_callbacks from .transition import Transition @@ -246,6 +244,7 @@ def __init__( raise InvalidDefinition(_("'donedata' can only be specified on final states.")) self.enter.add(donedata, priority=CallbackPriority.INLINE) self.document_order = 0 + self._hash = id(self) self._init_states() def _init_states(self): @@ -267,7 +266,7 @@ def __eq__(self, other): ) def __hash__(self): - return hash(repr(self)) + return self._hash def _setup(self): self.enter.add("on_enter_state", priority=CallbackPriority.GENERIC, is_convention=True) @@ -294,22 +293,6 @@ def __repr__(self): def __str__(self): return self.name - def __get__(self, machine, owner): - if machine is None: - return self - return self.for_instance(machine=machine, cache=machine._states_for_instance) - - def __set__(self, instance, value): - raise StateMachineError( - _("State overriding is not allowed. Trying to add '{}' to {}").format(value, self.id) - ) - - def for_instance(self, machine: "StateChart", cache: Dict["State", "State"]) -> "State": - if self not in cache: - cache[self] = InstanceState(self, machine) - - return cache[self] - @property def id(self) -> str: return self._id @@ -320,6 +303,7 @@ def _set_id(self, id: str) -> "State": self.value = id if not self.name: self.name = self._id.replace("_", " ").capitalize() + self._hash = hash((self.name, self._id)) return self @@ -366,67 +350,52 @@ def is_descendant(self, state: "State") -> bool: class InstanceState(State): - """ """ + """Per-instance proxy for a State, delegating attribute access to the underlying State. + + Uses ``__getattr__`` for automatic delegation of instance attributes (name, value, + transitions, etc.) and explicit property overrides for attributes that access private + fields or have custom logic (id, initial, final, parallel, is_active). 
+ """ def __init__( self, state: State, machine: "StateChart", ): - self._state = ref(state) + self._state = state self._machine = ref(machine) + self._hash = hash(state) self._init_states() - def _ref(self) -> State: - """Dereference the weakref, raising if the referent has been collected.""" - state = self._state() - assert state is not None - return state - - @property - def name(self): - return self._ref().name - - @property - def value(self): - return self._ref().value - - @property - def transitions(self): - return self._ref().transitions - - @property - def enter(self): - return self._ref().enter - - @property - def exit(self): - return self._ref().exit - - @property - def invoke(self): - return self._ref().invoke + def __getattr__(self, name: str): + value = getattr(self._state, name) + self.__dict__[name] = value + return value def __eq__(self, other): - return self._ref() == other + return self._state == other def __hash__(self): - return hash(repr(self._ref())) + return self._hash def __repr__(self): - return repr(self._ref()) + return repr(self._state) + + @property + def id(self) -> str: + return self._state._id @property def initial(self): - return self._ref()._initial + return self._state._initial @property def final(self): - return self._ref()._final + return self._state._final @property - def id(self) -> str: - return (self._state() or self)._id # type: ignore[union-attr] + def parallel(self): + return self._state._parallel @property def is_active(self): @@ -434,34 +403,6 @@ def is_active(self): assert machine is not None return self.value in machine.configuration_values - @property - def is_atomic(self): - return self._ref().is_atomic - - @property - def parent(self): - return self._ref().parent - - @property - def states(self): - return self._ref().states - - @property - def history(self): - return self._ref().history - - @property - def parallel(self): - return self._ref().parallel - - @property - def is_compound(self): - return 
self._ref().is_compound - - @property - def document_order(self): - return self._ref().document_order - class AnyState(State): """A special state that works as a "ANY" placeholder. diff --git a/statemachine/statemachine.py b/statemachine/statemachine.py index 6a5fce59..c3143a84 100644 --- a/statemachine/statemachine.py +++ b/statemachine/statemachine.py @@ -16,6 +16,7 @@ from .callbacks import CallbacksRegistry from .callbacks import SpecListGrouper from .callbacks import SpecReference +from .configuration import Configuration from .dispatcher import Listener from .dispatcher import Listeners from .engines.async_ import AsyncEngine @@ -24,12 +25,14 @@ from .event_data import TriggerData from .exceptions import InvalidDefinition from .exceptions import InvalidStateValue +from .exceptions import StateMachineError from .exceptions import TransitionNotAllowed from .factory import StateMachineMetaclass from .graph import iterate_states_and_transitions from .i18n import _ from .model import Model from .signature import SignatureAdapter +from .state import InstanceState from .utils import run_async_from_sync if TYPE_CHECKING: @@ -150,7 +153,7 @@ def __init__( [start_value] if start_value is not None else list(self.start_configuration_values) ) self._callbacks = CallbacksRegistry() - self._states_for_instance: Dict[State, State] = {} + self._config = self._build_configuration() self._listeners: Dict[int, Any] = {} """Listeners that provides attributes to be used as callbacks.""" @@ -193,6 +196,22 @@ def _resolve_class_listeners(self, **kwargs: Any) -> List[object]: resolved.append(instance) return resolved + def _build_configuration(self) -> Configuration: + """Create InstanceState entries and return a new Configuration.""" + instance_states: Dict[str, Any] = {} + events = self.__class__._events + for state in self.states_map.values(): + ist = InstanceState(state, self) + instance_states[state.id] = ist + if state.id not in events: + vars(self)[state.id] = ist + return 
Configuration( + instance_states=instance_states, + model=self.model, + state_field=self.state_field, + states_map=self.states_map, + ) + def activate_initial_state(self) -> Any: result = self._engine.activate_initial_state() if not isawaitable(result): @@ -205,6 +224,14 @@ def _processing_loop(self, caller_future: "Any | None" = None) -> Any: return result return run_async_from_sync(result) + def __setattr__(self, name, value): + # Fast path: internal/private attributes are never state IDs. + if not name.startswith("_") and name in self.__class__.states_map: + raise StateMachineError( + _("State overriding is not allowed. Trying to add '{}' to {}").format(value, name) + ) + super().__setattr__(name, value) + def __repr__(self): configuration_ids = [s.id for s in self.configuration] return ( @@ -213,9 +240,9 @@ def __repr__(self): ) def __getstate__(self): - state = self.__dict__.copy() + state = {k: v for k, v in self.__dict__.items() if not isinstance(v, InstanceState)} del state["_callbacks"] - del state["_states_for_instance"] + del state["_config"] del state["_engine"] return state @@ -223,7 +250,7 @@ def __setstate__(self, state: Dict[str, Any]) -> None: listeners = state.pop("_listeners") self.__dict__.update(state) # type: ignore[attr-defined] self._callbacks = CallbacksRegistry() - self._states_for_instance = {} + self._config = self._build_configuration() self._listeners = {} # _listeners already contained both class-level and runtime listeners @@ -335,44 +362,16 @@ def _graph(self): def configuration_values(self) -> OrderedSet[Any]: """The state configuration values is the set of currently active states's values (or ids if no custom value is defined).""" - if isinstance(self.current_state_value, OrderedSet): - return self.current_state_value - return OrderedSet([self.current_state_value]) + return self._config.values @property def configuration(self) -> OrderedSet["State"]: """The set of currently active states.""" - if self.current_state_value is None: 
- return OrderedSet() - - if not isinstance(self.current_state_value, MutableSet): - return OrderedSet( - [ - self.states_map[self.current_state_value].for_instance( - machine=self, - cache=self._states_for_instance, - ) - ] - ) - - return OrderedSet( - [ - self.states_map[value].for_instance( - machine=self, - cache=self._states_for_instance, - ) - for value in self.current_state_value - ] - ) + return self._config.states @configuration.setter def configuration(self, new_configuration: OrderedSet["State"]): - if len(new_configuration) == 0: - self.current_state_value = None - elif len(new_configuration) == 1: - self.current_state_value = new_configuration.pop().value - else: - self.current_state_value = OrderedSet(s.value for s in new_configuration) + self._config.states = new_configuration @property def current_state_value(self): @@ -381,17 +380,11 @@ def current_state_value(self): This is a low level API, that can be used to assign any valid state value completely bypassing all the hooks and validations. 
""" - return getattr(self.model, self.state_field, None) + return self._config.value @current_state_value.setter def current_state_value(self, value): - if ( - value is not None - and not isinstance(value, MutableSet) - and value not in self.states_map - ): - raise InvalidStateValue(value) - setattr(self.model, self.state_field, value) + self._config.value = value @property def current_state(self) -> "State | MutableSet[State]": @@ -405,36 +398,7 @@ def current_state(self) -> "State | MutableSet[State]": DeprecationWarning, stacklevel=2, ) - current_value = self.current_state_value - - try: - if isinstance(current_value, list): - return OrderedSet( - [ - self.states_map[value].for_instance( - machine=self, - cache=self._states_for_instance, - ) - for value in current_value - ] - ) - - state: State = self.states_map[current_value].for_instance( - machine=self, - cache=self._states_for_instance, - ) - return state - except KeyError as err: - if self.current_state_value is None: - raise InvalidStateValue( - self.current_state_value, - _( - "There's no current state set. In async code, " - "did you activate the initial state? " - "(e.g., `await sm.activate_initial_state()`)" - ), - ) from err - raise InvalidStateValue(self.current_state_value) from err + return self._config.current_state @current_state.setter def current_state(self, value): # pragma: no cover diff --git a/tests/test_configuration.py b/tests/test_configuration.py new file mode 100644 index 00000000..a3b49985 --- /dev/null +++ b/tests/test_configuration.py @@ -0,0 +1,66 @@ +"""Tests for the Configuration class internals. + +These tests cover branches in statemachine/configuration.py that are not +exercised by the higher-level state machine tests. 
+""" + +import warnings + +from statemachine.orderedset import OrderedSet + +from statemachine import State +from statemachine import StateChart + + +class ParallelSM(StateChart): + """A parallel state chart for testing multi-element configuration.""" + + s1 = State(initial=True) + s2 = State() + s3 = State(final=True) + + go = s1.to(s2) + finish = s2.to(s3) + + +class TestConfigurationStatesSetter: + def test_set_empty_configuration(self): + sm = ParallelSM() + assert len(sm.configuration) > 0 + + sm.configuration = OrderedSet() + assert sm.current_state_value is None + + def test_set_multi_element_configuration(self): + sm = ParallelSM() + s1_inst = sm.s1 + s2_inst = sm.s2 + + sm.configuration = OrderedSet([s1_inst, s2_inst]) + assert isinstance(sm.current_state_value, OrderedSet) + assert sm.current_state_value == OrderedSet([ParallelSM.s1.value, ParallelSM.s2.value]) + + +class TestConfigurationDiscard: + def test_discard_nonmatching_scalar(self): + sm = ParallelSM() + # current value is s1 (scalar) + assert sm.current_state_value == ParallelSM.s1.value + + # discard s2 — should be a no-op since s2 is not active + sm._config.discard(ParallelSM.s2) + assert sm.current_state_value == ParallelSM.s1.value + + +class TestConfigurationCurrentState: + def test_current_state_with_multiple_active_states(self): + sm = ParallelSM() + s1_inst = sm.s1 + s2_inst = sm.s2 + sm.configuration = OrderedSet([s1_inst, s2_inst]) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + result = sm.current_state + assert isinstance(result, OrderedSet) + assert len(result) == 2 diff --git a/tests/test_profiling.py b/tests/test_profiling.py index 7da292ff..43863ebd 100644 --- a/tests/test_profiling.py +++ b/tests/test_profiling.py @@ -2,10 +2,16 @@ import pytest +from statemachine import HistoryState from statemachine import State from statemachine import StateChart +# --------------------------------------------------------------------------- +# 
Machines under test +# --------------------------------------------------------------------------- + +# 1. Flat machine with model, guards, and listener callbacks (v1-style) class OrderControl(StateChart): allow_event_without_transition = False catch_errors_as_events = False @@ -45,6 +51,111 @@ def after_receive_payment(self): self.payment_received = True +# 2. Compound (nested) states +class CompoundSC(StateChart): + class active(State.Compound, name="Active"): + idle = State(initial=True) + working = State() + begin = idle.to(working) + + off = State(initial=True) + done = State(final=True) + + turn_on = off.to(active) + turn_off = active.to(done) + + +# 3. Parallel regions +class ParallelSC(StateChart): + class both(State.Parallel, name="Both"): + class left(State.Compound, name="Left"): + l1 = State(initial=True) + l2 = State() + go_l = l1.to(l2) + back_l = l2.to(l1) + + class right(State.Compound, name="Right"): + r1 = State(initial=True) + r2 = State() + go_r = r1.to(r2) + back_r = r2.to(r1) + + start = State(initial=True) + enter = start.to(both) + + +# 4. Guards with boolean expressions +class GuardedSC(StateChart): + s1 = State(initial=True) + s2 = State() + s3 = State(final=True) + + def check_a(self): + return True + + def check_b(self): + return False + + go = s1.to(s2, cond="check_a") | s1.to(s3, cond="check_b") + back = s2.to(s1) + + +# 5. History states (shallow) +class HistoryShallowSC(StateChart): + class process(State.Compound, name="Process"): + step1 = State(initial=True) + step2 = State() + advance = step1.to(step2) + h = HistoryState() + + paused = State(initial=True) + + pause = process.to(paused) + resume = paused.to(process.h) + begin = paused.to(process) + + +# 6. 
Deep history with nested compound states +class DeepHistorySC(StateChart): + class outer(State.Compound, name="Outer"): + class inner(State.Compound, name="Inner"): + a = State(initial=True) + b = State() + go = a.to(b) + back = b.to(a) + + start = State(initial=True) + enter_inner = start.to(inner) + h = HistoryState(type="deep") + + away = State(initial=True) + + dive = away.to(outer) + leave = outer.to(away) + restore = away.to(outer.h) + + +# 7. Many-transition stress machine (wide, not deep) +class ManyTransitionsSC(StateChart): + s1 = State(initial=True) + s2 = State() + s3 = State() + s4 = State() + s5 = State() + + go_12 = s1.to(s2) + go_23 = s2.to(s3) + go_34 = s3.to(s4) + go_45 = s4.to(s5) + go_51 = s5.to(s1) + reset = s2.to(s1) | s3.to(s1) | s4.to(s1) | s5.to(s1) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + def create_order(): order = Order() assert order.state_machine.waiting_for_payment.is_active @@ -54,12 +165,129 @@ def add_to_order(sm, amount): sm.add_to_order(amount) +# --------------------------------------------------------------------------- +# Benchmark: instance creation +# --------------------------------------------------------------------------- + + @pytest.mark.slow() -def test_setup_performance(benchmark): - benchmark.pedantic(create_order, rounds=10, iterations=1000) +class TestSetupPerformance: + """Benchmark the cost of creating and activating state machine instances.""" + + def test_flat_machine(self, benchmark): + benchmark.pedantic(create_order, rounds=10, iterations=1000) + + def test_compound_machine(self, benchmark): + benchmark.pedantic(lambda: CompoundSC(), rounds=10, iterations=1000) + + def test_parallel_machine(self, benchmark): + benchmark.pedantic(lambda: ParallelSC(), rounds=10, iterations=1000) + + def test_guarded_machine(self, benchmark): + benchmark.pedantic(lambda: GuardedSC(), 
rounds=10, iterations=1000) + + def test_history_machine(self, benchmark): + benchmark.pedantic(lambda: HistoryShallowSC(), rounds=10, iterations=1000) + + def test_deep_history_machine(self, benchmark): + benchmark.pedantic(lambda: DeepHistorySC(), rounds=10, iterations=1000) + + +# --------------------------------------------------------------------------- +# Benchmark: event throughput +# --------------------------------------------------------------------------- @pytest.mark.slow() -def test_event_performance(benchmark): - order = Order() - benchmark.pedantic(add_to_order, args=(order.state_machine, 1), rounds=10, iterations=1000) +class TestEventPerformance: + """Benchmark event processing (self-transitions and state changes).""" + + def test_flat_self_transition(self, benchmark): + """Self-transition on a flat machine with model/listener.""" + order = Order() + sm = order.state_machine + benchmark.pedantic(add_to_order, args=(sm, 1), rounds=10, iterations=1000) + + def test_compound_enter_exit(self, benchmark): + """Enter and exit a compound state repeatedly.""" + + def cycle(): + sm = CompoundSC() + sm.turn_on() + sm.begin() + sm.turn_off() + + benchmark.pedantic(cycle, rounds=10, iterations=500) + + def test_parallel_region_events(self, benchmark): + """Send events within parallel regions.""" + sm = ParallelSC() + sm.enter() + + def cycle(): + sm.go_l() + sm.go_r() + sm.back_l() + sm.back_r() + + benchmark.pedantic(cycle, rounds=10, iterations=500) + + def test_guarded_transitions(self, benchmark): + """Guard evaluation + transition selection.""" + sm = GuardedSC() + + def cycle(): + sm.go() + sm.back() + + benchmark.pedantic(cycle, rounds=10, iterations=1000) + + def test_history_pause_resume(self, benchmark): + """Shallow history: pause and resume compound state.""" + sm = HistoryShallowSC() + sm.begin() + sm.advance() + + def cycle(): + sm.pause() + sm.resume() + + benchmark.pedantic(cycle, rounds=10, iterations=500) + + def test_deep_history_cycle(self, 
benchmark): + """Deep history: leave and restore nested compound state.""" + sm = DeepHistorySC() + sm.dive() + sm.enter_inner() + sm.go() + + def cycle(): + sm.leave() + sm.restore() + + benchmark.pedantic(cycle, rounds=10, iterations=500) + + def test_many_transitions_full_cycle(self, benchmark): + """Traverse a 5-state ring (s1→s2→s3→s4→s5→s1).""" + sm = ManyTransitionsSC() + + def cycle(): + sm.go_12() + sm.go_23() + sm.go_34() + sm.go_45() + sm.go_51() + + benchmark.pedantic(cycle, rounds=10, iterations=500) + + def test_many_transitions_reset(self, benchmark): + """Composite event (|) selecting among multiple source states.""" + sm = ManyTransitionsSC() + + def cycle(): + sm.go_12() + sm.go_23() + sm.go_34() + sm.reset() + + benchmark.pedantic(cycle, rounds=10, iterations=500) diff --git a/tests/test_statemachine_compat.py b/tests/test_statemachine_compat.py index edfda0b1..b163886e 100644 --- a/tests/test_statemachine_compat.py +++ b/tests/test_statemachine_compat.py @@ -356,19 +356,3 @@ class SM(StateMachine): sm = SM() with pytest.warns(DeprecationWarning, match="current_state"): _ = sm.current_state # noqa: F841 - - def test_current_state_with_list_value(self): - """current_state handles list current_state_value (backward compat).""" - - class SM(StateMachine): - s1 = State(initial=True) - s2 = State(final=True) - - go = s1.to(s2) - - sm = SM() - setattr(sm.model, sm.state_field, [sm.s1.value]) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - config = sm.current_state - assert sm.s1 in config diff --git a/tests/test_threading.py b/tests/test_threading.py index 5f6721a7..b2d305e8 100644 --- a/tests/test_threading.py +++ b/tests/test_threading.py @@ -1,6 +1,8 @@ import threading import time +from collections import Counter +import pytest from statemachine.state import State from statemachine.statemachine import StateChart @@ -115,6 +117,184 @@ def __init__(self, name): assert c3.fsm.statuses_history == ["c3.green", 
"c3.green", "c3.green", "c3.yellow"] +class TestThreadSafety: + """Stress tests for concurrent access to a single state machine instance. + + These tests exercise real contention: multiple threads sending events to the + same SM simultaneously, synchronized via barriers to maximize overlap. + """ + + @pytest.fixture() + def cycling_machine(self): + class CyclingMachine(StateChart): + s1 = State(initial=True) + s2 = State() + s3 = State() + cycle = s1.to(s2) | s2.to(s3) | s3.to(s1) + + return CyclingMachine() + + @pytest.mark.parametrize("num_threads", [4, 8]) + def test_concurrent_sends_no_lost_events(self, cycling_machine, num_threads): + """All events sent concurrently must be processed — none lost.""" + events_per_thread = 300 + total_events = num_threads * events_per_thread + barrier = threading.Barrier(num_threads) + errors = [] + + def sender(): + try: + barrier.wait(timeout=5) + for _ in range(events_per_thread): + cycling_machine.send("cycle") + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=sender) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=30) + + assert not errors, f"Thread errors: {errors}" + + # The machine cycles s1→s2→s3→s1. After N total cycle events starting + # from s1, the state is determined by (N % 3). 
+ expected_states = {0: "s1", 1: "s2", 2: "s3"} + expected = expected_states[total_events % 3] + assert cycling_machine.current_state_value == expected + + def test_concurrent_sends_state_consistency(self, cycling_machine): + """State must always be one of the valid states, never corrupted.""" + valid_values = {"s1", "s2", "s3"} + num_threads = 6 + events_per_thread = 500 + barrier = threading.Barrier(num_threads + 1) # +1 for observer + stop_event = threading.Event() + observed_values = [] + errors = [] + + def sender(): + try: + barrier.wait(timeout=5) + for _ in range(events_per_thread): + cycling_machine.send("cycle") + except Exception as e: + errors.append(e) + + def observer(): + barrier.wait(timeout=5) + while not stop_event.is_set(): + val = cycling_machine.current_state_value + observed_values.append(val) + + threads = [threading.Thread(target=sender) for _ in range(num_threads)] + obs_thread = threading.Thread(target=observer) + + for t in threads: + t.start() + obs_thread.start() + + for t in threads: + t.join(timeout=30) + + stop_event.set() + obs_thread.join(timeout=5) + + assert not errors, f"Thread errors: {errors}" + # None may appear transiently during configuration updates — that's expected. 
+ invalid = [v for v in observed_values if v not in valid_values and v is not None] + assert not invalid, f"Observed invalid state values: {set(invalid)}" + assert len(observed_values) > 100, "Observer didn't collect enough samples" + + def test_concurrent_sends_with_callbacks(self): + """Callbacks must execute exactly once per transition under contention.""" + call_log = [] + lock = threading.Lock() + + class CallbackMachine(StateChart): + s1 = State(initial=True) + s2 = State() + go = s1.to(s2) | s2.to(s1) + + def on_enter_s2(self): + with lock: + call_log.append("enter_s2") + + def on_enter_s1(self): + with lock: + call_log.append("enter_s1") + + sm = CallbackMachine() + num_threads = 4 + events_per_thread = 200 + total_events = num_threads * events_per_thread + barrier = threading.Barrier(num_threads) + errors = [] + + def sender(): + try: + barrier.wait(timeout=5) + for _ in range(events_per_thread): + sm.send("go") + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=sender) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=30) + + assert not errors, f"Thread errors: {errors}" + + # Each transition fires exactly one on_enter callback. + # +1 because initial activation also fires on_enter_s1. 
+ counts = Counter(call_log) + total_callbacks = counts["enter_s1"] + counts["enter_s2"] + assert total_callbacks == total_events + 1 + + def test_concurrent_send_and_read_configuration(self, cycling_machine): + """Reading configuration while events are being processed must not raise.""" + num_senders = 4 + events_per_sender = 300 + barrier = threading.Barrier(num_senders + 1) + stop_event = threading.Event() + errors = [] + + def sender(): + try: + barrier.wait(timeout=5) + for _ in range(events_per_sender): + cycling_machine.send("cycle") + except Exception as e: + errors.append(e) + + def reader(): + barrier.wait(timeout=5) + while not stop_event.is_set(): + try: + _ = cycling_machine.configuration + _ = cycling_machine.current_state_value + _ = list(cycling_machine.configuration) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=sender) for _ in range(num_senders)] + reader_thread = threading.Thread(target=reader) + + for t in threads: + t.start() + reader_thread.start() + + for t in threads: + t.join(timeout=30) + stop_event.set() + reader_thread.join(timeout=5) + + assert not errors, f"Thread errors: {errors}" + + async def test_regression_443_with_modifications_for_async_engine(): """ Test for https://github.com/fgmacedo/python-statemachine/issues/443