diff --git a/README.md b/README.md index f3af4ecb..43f141a6 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,17 @@ True ``` -Generate a diagram: +Generate a diagram or get a text representation with f-strings: + +```py +>>> print(f"{sm:md}") +| State | Event | Guard | Target | +| ------ | ----- | ----- | ------ | +| Green | cycle | | Yellow | +| Yellow | cycle | | Red | +| Red | cycle | | Green | + +``` ```python sm._graph().write_png("traffic_light.png") @@ -341,7 +351,7 @@ There's a lot more to explore: - **`prepare_event`** callback — inject custom data into all callbacks - **Observer pattern** — register external listeners to watch events and state changes - **Django integration** — auto-discover state machines in Django apps with `MachineMixin` -- **Diagram generation** — from the CLI, at runtime, or in Jupyter notebooks +- **Diagram generation** — via f-strings (`f"{sm:mermaid}"`), CLI, Sphinx directive, or Jupyter - **Dictionary-based definitions** — create state machines from data structures - **Internationalization** — error messages in multiple languages diff --git a/docs/conf.py b/docs/conf.py index 18846738..59e518b5 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -52,6 +52,7 @@ "sphinx_gallery.gen_gallery", "sphinx_copybutton", "statemachine.contrib.diagram.sphinx_ext", + "sphinxcontrib.mermaid", ] autosectionlabel_prefix_document = True diff --git a/docs/diagram.md b/docs/diagram.md index 6b09b1ba..ff3962df 100644 --- a/docs/diagram.md +++ b/docs/diagram.md @@ -27,7 +27,6 @@ sudo apt install graphviz For other systems, see the [Graphviz downloads page](https://graphviz.org/download/). - ## Generating diagrams Every state machine instance exposes a `_graph()` method that returns a @@ -77,8 +76,7 @@ For higher resolution PNGs, set the DPI before exporting: ```python graph = sm._graph() -graph.set_dpi(300) -graph.write_png("order_control_300dpi.png") +graph.set_dpi(300).write_png("order_control_300dpi.png") ``` ```{note} @@ -89,7 +87,144 @@ complete list. ``` -## Command line +## Text representations + +State machines support multiple text-based output formats, all accessible +through Python's built-in `format()` protocol, the `formatter` API, or +the command line. + +| Format | Aliases | Description | Dependencies | +|--------|---------|-------------|--------------| +| `mermaid` | | [Mermaid stateDiagram-v2](https://mermaid.js.org/syntax/stateDiagram.html) source | None [^mermaid] | +| `md` | `markdown` | Transition table (pipe-delimited Markdown) | None | +| `rst` | | Transition table (RST grid table) | None | +| `dot` | | [Graphviz DOT](https://graphviz.org/doc/info/lang.html) language source | pydot | +| `svg` | | SVG markup (generated via DOT) | pydot, Graphviz | + +[^mermaid]: Mermaid has a known rendering bug + ([mermaid-js/mermaid#4052](https://github.com/mermaid-js/mermaid/issues/4052)) + where transitions targeting or originating from a compound state inside a + parallel region crash the renderer. As a workaround, the `MermaidRenderer` + redirects such transitions to the compound's initial child state. The + visual result is equivalent — Mermaid draws the arrow crossing into the + compound boundary — but the arrow points to the child rather than the + compound border. This workaround will be revisited when the upstream bug + is resolved. + + +### Using `format()` + +Use f-strings or the built-in `format()` function — no diagram imports needed: + +```py +>>> from tests.examples.traffic_light_machine import TrafficLightMachine +>>> sm = TrafficLightMachine() +>>> print(f"{sm:mermaid}") +stateDiagram-v2 + direction LR + state "Green" as green + state "Yellow" as yellow + state "Red" as red + [*] --> green + green --> yellow : cycle + yellow --> red : cycle + red --> green : cycle + + classDef active fill:#40E0D0,stroke:#333 + green:::active + + +>>> print(f"{sm:md}") +| State | Event | Guard | Target | +| ------ | ----- | ----- | ------ | +| Green | cycle | | Yellow | +| Yellow | cycle | | Red | +| Red | cycle | | Green | + + +``` + +Works on **classes** too (no active-state highlighting): + +```py +>>> print(f"{TrafficLightMachine:mermaid}") +stateDiagram-v2 + direction LR + state "Green" as green + state "Yellow" as yellow + state "Red" as red + [*] --> green + green --> yellow : cycle + yellow --> red : cycle + red --> green : cycle + + +``` + +The `dot` format returns the Graphviz DOT language source: + +```py +>>> print(f"{sm:dot}") # doctest: +ELLIPSIS +digraph TrafficLightMachine { +... +} + +``` + +An empty format spec (e.g., `f"{sm:}"`) falls back to `repr()`. + + +(formatter-api)= +### Using the `formatter` API + +The `formatter` object is the programmatic entry point for rendering +state machines in any registered text format: + +```py +>>> from statemachine.contrib.diagram import formatter +>>> from tests.examples.traffic_light_machine import TrafficLightMachine + +>>> print(formatter.render(TrafficLightMachine, "mermaid")) +stateDiagram-v2 + direction LR + state "Green" as green + state "Yellow" as yellow + state "Red" as red + [*] --> green + green --> yellow : cycle + yellow --> red : cycle + red --> green : cycle + + +>>> formatter.supported_formats() +['dot', 'markdown', 'md', 'mermaid', 'rst', 'svg'] + +``` + +Both `format()` and the Sphinx directive delegate to this same `formatter` +under the hood. + + +#### Registering custom formats + +The `formatter` is extensible — register your own format with a +decorator and it becomes available everywhere (`format()`, CLI, +Sphinx directive): + +```python +from statemachine.contrib.diagram import formatter + +@formatter.register_format("plantuml", "puml") +def _render_plantuml(machine_or_class): + # your PlantUML renderer here + ... +``` + +After registration, `f"{sm:plantuml}"` and `--format plantuml` work +immediately. + + +### Command line You can generate diagrams without writing Python code: @@ -110,6 +245,93 @@ send events before rendering: python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine diagram.png --events cycle cycle cycle ``` +Use `--format` to produce a text format instead of a Graphviz image: + +```bash +# Mermaid stateDiagram-v2 +python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine output.mmd --format mermaid + +# DOT source +python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine output.dot --format dot + +# Markdown transition table +python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine output.md --format md + +# RST transition table +python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine output.rst --format rst +``` + +Use `-` as the output file to write to stdout (handy for piping): + +```bash +python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine - --format mermaid +``` + + +## Auto-expanding docstrings + +Use `{statechart:FORMAT}` placeholders in your class docstring to embed +a live representation of the state machine. The placeholder is replaced +at class definition time, so the docstring always reflects the actual +states and transitions: + +```py +>>> from statemachine.statemachine import StateChart +>>> from statemachine.state import State + +>>> class TrafficLight(StateChart): +... """A traffic light. +... +... {statechart:md} +... """ +... green = State(initial=True) +... yellow = State() +... red = State() +... cycle = green.to(yellow) | yellow.to(red) | red.to(green) + +>>> print(TrafficLight.__doc__) +A traffic light. + +| State | Event | Guard | Target | +| ------ | ----- | ----- | ------ | +| Green | cycle | | Yellow | +| Yellow | cycle | | Red | +| Red | cycle | | Green | + + + +``` + +Any registered format works: `{statechart:rst}`, `{statechart:mermaid}`, +`{statechart:dot}`, etc. + +### Choosing the right format + +| Context | Recommended format | +|---------|-------------------| +| Sphinx with RST (autodoc default) | `{statechart:rst}` | +| Sphinx with MyST Markdown | `{statechart:md}` | +| `help()` in terminal / IDE | Either works; `md` reads more cleanly | + +### Sphinx autodoc integration + +Since the placeholder is expanded at class definition time, Sphinx `autodoc` +sees the final rendered text — no extra configuration needed. + +For example, this class uses `{statechart:rst}` in its docstring: + +```{literalinclude} ../tests/machines/showcase_simple.py +:pyobject: SimpleSC +:language: python +``` + +And here is the rendered autodoc output: + +```{eval-rst} +.. autoclass:: tests.machines.showcase_simple.SimpleSC + :noindex: +``` + ## Sphinx directive @@ -179,6 +401,26 @@ zoom and pan freely: :align: center ``` +### Mermaid format + +Use `:format: mermaid` to render via +[sphinxcontrib-mermaid](https://github.com/mgaitan/sphinxcontrib-mermaid) +instead of Graphviz SVG — useful when you don't want to install Graphviz +in your docs build environment: + +````markdown +```{statemachine-diagram} myproject.machines.TrafficLight +:format: mermaid +:caption: Rendered as Mermaid +``` +```` + +```{statemachine-diagram} tests.examples.traffic_light_machine.TrafficLightMachine +:format: mermaid +:caption: TrafficLightMachine (Mermaid) +:align: center +``` + ### Directive options The directive supports the same layout options as the standard `image` and @@ -190,6 +432,10 @@ The directive supports the same layout options as the standard `image` and : Events to send in sequence. When present, the machine is instantiated and each event is sent before rendering. +`:format:` *(string)* +: Output format. Use `mermaid` to render via sphinxcontrib-mermaid + instead of Graphviz SVG. Default: DOT/SVG. + **Image/figure options:** `:caption:` *(string)* @@ -299,9 +545,9 @@ dot.write_png("order_control_class.png") ## Visual showcase This section shows how each state machine feature is rendered in diagrams. -Each example includes the class definition, the **class** diagram (no -active state), and **instance** diagrams (with the current state -highlighted after sending events). +Each example includes the class definition, diagrams in both **Graphviz** +and **Mermaid** formats, and **instance** diagrams with the current state +highlighted after sending events. ### Simple states @@ -314,7 +560,12 @@ A minimal state machine with three atomic states and linear transitions. ``` ```{statemachine-diagram} tests.machines.showcase_simple.SimpleSC -:caption: Class +:caption: Class (Graphviz) +``` + +```{statemachine-diagram} tests.machines.showcase_simple.SimpleSC +:format: mermaid +:caption: Class (Mermaid) ``` ```{statemachine-diagram} tests.machines.showcase_simple.SimpleSC @@ -343,7 +594,12 @@ States can declare `entry` / `exit` callbacks, shown in the state label. ``` ```{statemachine-diagram} tests.machines.showcase_actions.ActionsSC -:caption: Class +:caption: Class (Graphviz) +``` + +```{statemachine-diagram} tests.machines.showcase_actions.ActionsSC +:format: mermaid +:caption: Class (Mermaid) ``` ```{statemachine-diagram} tests.machines.showcase_actions.ActionsSC @@ -362,7 +618,12 @@ Transitions can have `cond` guards, shown in brackets on the edge label. ``` ```{statemachine-diagram} tests.machines.showcase_guards.GuardSC -:caption: Class +:caption: Class (Graphviz) +``` + +```{statemachine-diagram} tests.machines.showcase_guards.GuardSC +:format: mermaid +:caption: Class (Mermaid) ``` ```{statemachine-diagram} tests.machines.showcase_guards.GuardSC @@ -381,7 +642,12 @@ A transition from a state back to itself. ``` ```{statemachine-diagram} tests.machines.showcase_self_transition.SelfTransitionSC -:caption: Class +:caption: Class (Graphviz) +``` + +```{statemachine-diagram} tests.machines.showcase_self_transition.SelfTransitionSC +:format: mermaid +:caption: Class (Mermaid) ``` ```{statemachine-diagram} tests.machines.showcase_self_transition.SelfTransitionSC @@ -400,7 +666,12 @@ Internal transitions execute actions without exiting/entering the state. ``` ```{statemachine-diagram} tests.machines.showcase_internal.InternalSC -:caption: Class +:caption: Class (Graphviz) +``` + +```{statemachine-diagram} tests.machines.showcase_internal.InternalSC +:format: mermaid +:caption: Class (Mermaid) ``` ```{statemachine-diagram} tests.machines.showcase_internal.InternalSC @@ -420,10 +691,15 @@ its initial child. ``` ```{statemachine-diagram} tests.machines.showcase_compound.CompoundSC -:caption: Class +:caption: Class (Graphviz) :target: ``` +```{statemachine-diagram} tests.machines.showcase_compound.CompoundSC +:format: mermaid +:caption: Class (Mermaid) +``` + ```{statemachine-diagram} tests.machines.showcase_compound.CompoundSC :events: :caption: Off @@ -453,10 +729,15 @@ A parallel state activates all its regions simultaneously. ``` ```{statemachine-diagram} tests.machines.showcase_parallel.ParallelSC -:caption: Class +:caption: Class (Graphviz) :target: ``` +```{statemachine-diagram} tests.machines.showcase_parallel.ParallelSC +:format: mermaid +:caption: Class (Mermaid) +``` + ```{statemachine-diagram} tests.machines.showcase_parallel.ParallelSC :events: enter :caption: Both active @@ -470,6 +751,41 @@ A parallel state activates all its regions simultaneously. ``` +### Parallel with cross-boundary transitions + +A transition targeting a compound state **inside** a parallel region triggers a +rendering bug in Mermaid (`mermaid-js/mermaid#4052`). The Mermaid renderer works +around this by redirecting the arrow to the compound's initial child — compare the +``rebuild`` arrow in both diagrams below. + +```{literalinclude} ../tests/machines/showcase_parallel_compound.py +:pyobject: ParallelCompoundSC +:language: python +``` + +```{statemachine-diagram} tests.machines.showcase_parallel_compound.ParallelCompoundSC +:caption: Class (Graphviz) — ``rebuild`` points to the Build compound border +:target: +``` + +```{statemachine-diagram} tests.machines.showcase_parallel_compound.ParallelCompoundSC +:format: mermaid +:caption: Class (Mermaid) — ``rebuild`` is redirected to Compile (initial child of Build) +``` + +```{statemachine-diagram} tests.machines.showcase_parallel_compound.ParallelCompoundSC +:events: start, do_build +:caption: Build done +:target: +``` + +```{statemachine-diagram} tests.machines.showcase_parallel_compound.ParallelCompoundSC +:events: start, do_build, do_test +:caption: Pipeline done → Review +:target: +``` + + ### History states (shallow) A history pseudo-state remembers the last active child of a compound state. @@ -480,10 +796,15 @@ A history pseudo-state remembers the last active child of a compound state. ``` ```{statemachine-diagram} tests.machines.showcase_history.HistorySC -:caption: Class +:caption: Class (Graphviz) :target: ``` +```{statemachine-diagram} tests.machines.showcase_history.HistorySC +:format: mermaid +:caption: Class (Mermaid) +``` + ```{statemachine-diagram} tests.machines.showcase_history.HistorySC :events: begin, advance :caption: Step2 @@ -513,10 +834,15 @@ Deep history remembers the exact leaf state across nested compounds. ``` ```{statemachine-diagram} tests.machines.showcase_deep_history.DeepHistorySC -:caption: Class +:caption: Class (Graphviz) :target: ``` +```{statemachine-diagram} tests.machines.showcase_deep_history.DeepHistorySC +:format: mermaid +:caption: Class (Mermaid) +``` + ```{statemachine-diagram} tests.machines.showcase_deep_history.DeepHistorySC :events: dive, enter_inner, go :caption: Inner/B diff --git a/docs/releases/3.1.0.md b/docs/releases/3.1.0.md index 34c6a3f5..e566ec62 100644 --- a/docs/releases/3.1.0.md +++ b/docs/releases/3.1.0.md @@ -4,6 +4,90 @@ ## What's new in 3.1.0 +### Text representations with `format()` + +State machines now support Python's built-in `format()` protocol. Use f-strings +or `format()` to get text representations — on both classes and instances: + +```python +f"{TrafficLightMachine:md}" +f"{sm:mermaid}" +format(sm, "rst") +``` + +Supported formats: + +| Format | Output | Requires | +|-----------|---------------------------|-----------------------| +| `dot` | Graphviz DOT source | `pydot` | +| `svg` | SVG markup (via Graphviz) | `pydot` + `graphviz` | +| `mermaid` | Mermaid stateDiagram-v2 | — | +| `md` | Markdown transition table | — | +| `rst` | RST transition table | — | + +See {ref}`diagram:Text representations` for details. + + +### Formatter facade + +A new `Formatter` facade with decorator-based registration unifies all text +format rendering behind a single API. Adding a new format requires only +registering a render function — no changes to `__format__`, the CLI, or the +Sphinx directive: + +```python +from statemachine.contrib.diagram import formatter + +formatter.render(sm, "mermaid") +formatter.supported_formats() + +@formatter.register_format("custom") +def _render_custom(machine_or_class): + ... +``` + +See {ref}`formatter-api` for details. + + +### Mermaid diagram support + +State machines can now be rendered as +[Mermaid `stateDiagram-v2`](https://mermaid.js.org/syntax/stateDiagram.html) +source text — no Graphviz installation required. Supports compound states, +parallel regions, history states, guards, and active-state highlighting. + +Three ways to use it: + +- **f-strings:** `f"{sm:mermaid}"` +- **CLI:** `python -m statemachine.contrib.diagram MyMachine - --format mermaid` +- **Sphinx directive:** `:format: mermaid` renders via `sphinxcontrib-mermaid`. + +See {ref}`diagram:Mermaid format` for details. + + +### Auto-expanding docstrings + +Use `{statechart:FORMAT}` placeholders in your class docstring to embed a +live representation of the state machine. The placeholder is replaced at +class definition time, so the docstring always stays in sync with the code: + +```python +class TrafficLight(StateChart): + """A traffic light. + + {statechart:md} + """ + green = State(initial=True) + yellow = State() + red = State() + cycle = green.to(yellow) | yellow.to(red) | red.to(green) +``` + +Any registered format works: `md`, `rst`, `mermaid`, `dot`, etc. +Works with Sphinx autodoc — the expanded docstring is what gets rendered. +See {ref}`diagram:Auto-expanding docstrings` for details. + + ### Sphinx directive for inline diagrams A new Sphinx extension renders state machine diagrams directly in your @@ -29,10 +113,27 @@ events before rendering (highlighting the current state). Using `:target:` without a value makes the diagram clickable, opening the full SVG in a new browser tab for zooming — useful for large statecharts. +The `:format: mermaid` option renders via `sphinxcontrib-mermaid` instead of +Graphviz. + See {ref}`diagram:Sphinx directive` for full documentation. [#589](https://github.com/fgmacedo/python-statemachine/pull/589). +### Diagram CLI `--events` and `--format` options + +The `python -m statemachine.contrib.diagram` command now accepts: + +- `--events` to instantiate the machine and send events before rendering, + highlighting the current active state. +- `--format` to choose the output format (`mermaid`, `md`, `rst`, `dot`, `svg`, + or image formats via Graphviz). Use `-` as the output path to write text + formats to stdout. + +See {ref}`diagram:Command line` for details. +[#593](https://github.com/fgmacedo/python-statemachine/pull/593). + + ### Performance: 5x–7x faster event processing The engine's hot paths have been systematically profiled and optimized, resulting in @@ -49,17 +150,10 @@ machine instance concurrently. This is now documented in the [#592](https://github.com/fgmacedo/python-statemachine/pull/592). -### Diagram CLI `--events` option - -The `python -m statemachine.contrib.diagram` command now accepts `--events` to -instantiate the machine and send events before rendering, highlighting the -current active state — matching the Sphinx directive's `:events:` option. -See {ref}`diagram:Command line` for details. - ### Bugfixes in 3.1.0 - Fixes silent misuse of `Event()` with multiple positional arguments. Passing more than one - transition to `Event()` (e.g., `Event(t1, t2)`) now raises {ref}`InvalidDefinition` with a + transition to `Event()` (e.g., `Event(t1, t2)`) now raises `InvalidDefinition` with a clear message suggesting the `|` operator. Previously, the second argument was silently interpreted as the event `id`, leaving the extra transitions eventless (auto-firing). [#588](https://github.com/fgmacedo/python-statemachine/pull/588). diff --git a/docs/tutorial.md b/docs/tutorial.md index e6b23d3b..d49526d7 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -364,16 +364,55 @@ Or from the command line: python -m statemachine.contrib.diagram my_module.CoffeeOrder order.png ``` +### Text representations with `format()` + +You can also get text representations of any state machine using Python's built-in +`format()` or f-strings — no Graphviz needed: + +```py +>>> from tests.machines.tutorial_coffee_order import CoffeeOrder + +>>> print(f"{CoffeeOrder:md}") +| State | Event | Guard | Target | +| --------- | ------- | ----- | --------- | +| Pending | start | | Preparing | +| Preparing | finish | | Ready | +| Ready | pick_up | | Picked up | + +``` + +Supported formats include `mermaid`, `md` (markdown table), `rst`, `dot`, and `svg`. +Works on both classes and instances: + +```py +>>> print(f"{CoffeeOrder:mermaid}") +stateDiagram-v2 + direction LR + state "Pending" as pending + state "Preparing" as preparing + state "Ready" as ready + state "Picked up" as picked_up + [*] --> pending + picked_up --> [*] + pending --> preparing : start + preparing --> ready : finish + ready --> picked_up : pick_up + + +``` + ```{tip} -Diagram generation requires [Graphviz](https://graphviz.org/) (`dot` command) +Graphviz diagram generation requires [Graphviz](https://graphviz.org/) (`dot` command) and the `diagrams` extra: pip install python-statemachine[diagrams] + +Text formats (`md`, `rst`, `mermaid`) work without any extra dependencies. ``` ```{seealso} -See [](diagram.md) for highlighting active states, Jupyter integration, -SVG output, DPI settings, Sphinx directive, and the `quickchart_write_svg` +See [](diagram.md) for all formats, highlighting active states, auto-expanding +docstrings, Jupyter integration, Sphinx directive, and the `quickchart_write_svg` alternative that doesn't require Graphviz. ``` diff --git a/pyproject.toml b/pyproject.toml index b05ccfb3..40b3b9c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ dev = [ "sphinx-autobuild; python_version >'3.8'", "furo >=2024.5.6; python_version >'3.8'", "sphinx-copybutton >=0.5.2; python_version >'3.8'", + "sphinxcontrib-mermaid; python_version >'3.8'", "pdbr>=0.8.9; python_version >'3.8'", "babel >=2.16.0; python_version >='3.8'", "pytest-xdist>=3.6.1", diff --git a/statemachine/contrib/diagram/__init__.py b/statemachine/contrib/diagram/__init__.py index a2e31a71..c6317304 100644 --- a/statemachine/contrib/diagram/__init__.py +++ b/statemachine/contrib/diagram/__init__.py @@ -3,8 +3,11 @@ from urllib.request import urlopen from .extract import extract +from .formatter import formatter as formatter from .renderers.dot import DotRenderer from .renderers.dot import DotRendererConfig +from .renderers.mermaid import MermaidRenderer +from .renderers.mermaid import MermaidRendererConfig class DotGraphMachine: @@ -56,6 +59,32 @@ def __call__(self): return self.get_graph() +class MermaidGraphMachine: + """Facade for generating Mermaid stateDiagram-v2 source from a state machine.""" + + direction = "LR" + active_fill = "#40E0D0" + active_stroke = "#333" + + def __init__(self, machine): + self.machine = machine + + def _build_config(self) -> MermaidRendererConfig: + return MermaidRendererConfig( + direction=self.direction, + active_fill=self.active_fill, + active_stroke=self.active_stroke, + ) + + def get_mermaid(self) -> str: + ir = extract(self.machine) + renderer = MermaidRenderer(config=self._build_config()) + return renderer.render(ir) + + def __call__(self) -> str: + return self.get_mermaid() + + def quickchart_write_svg(sm, path: str): """ If the default dependency of GraphViz installed locally doesn't work for you. As an option, @@ -135,7 +164,7 @@ def import_sm(qualname): return smclass -def write_image(qualname, out, events=None): +def write_image(qualname, out, events=None, fmt=None): """ Given a `qualname`, that is the fully qualified dotted path to a StateMachine classes, imports the class and generates a dot graph using the `pydot` lib. @@ -146,7 +175,13 @@ def write_image(qualname, out, events=None): If `events` is provided, the machine is instantiated and each event is sent before rendering, so the diagram highlights the current active state. + + If `fmt` is provided, it overrides the output format (any registered text + format such as ``"mermaid"``, ``"dot"``, ``"md"``, ``"rst"``). + Use ``out="-"`` to write to stdout. """ + import sys + smclass = import_sm(qualname) if events: @@ -156,9 +191,20 @@ def write_image(qualname, out, events=None): else: machine = smclass - graph = DotGraphMachine(machine).get_graph() - out_extension = out.rsplit(".", 1)[1] - graph.write(out, format=out_extension) + if fmt is not None: + text = formatter.render(machine, fmt) + if out == "-": + sys.stdout.write(text) + else: + with open(out, "w") as f: + f.write(text) + else: + graph = DotGraphMachine(machine).get_graph() + if out == "-": + sys.stdout.buffer.write(graph.create_svg()) # type: ignore[attr-defined] + else: + out_extension = out.rsplit(".", 1)[1] + graph.write(out, format=out_extension) def main(argv=None): @@ -180,6 +226,12 @@ def main(argv=None): nargs="+", help="Instantiate the machine and send these events before rendering.", ) + parser.add_argument( + "--format", + choices=formatter.supported_formats(), + default=None, + help="Output as text format instead of Graphviz image.", + ) args = parser.parse_args(argv) - write_image(qualname=args.class_path, out=args.out, events=args.events) + write_image(qualname=args.class_path, out=args.out, events=args.events, fmt=args.format) diff --git a/statemachine/contrib/diagram/extract.py b/statemachine/contrib/diagram/extract.py index 37f4fc88..15a1f2da 100644 --- a/statemachine/contrib/diagram/extract.py +++ b/statemachine/contrib/diagram/extract.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from statemachine.state import State from statemachine.statemachine import StateChart + from statemachine.transition import Transition # A StateChart class or instance — both expose the same structural metadata. MachineRef = Union["StateChart", "type[StateChart]"] @@ -101,6 +102,33 @@ def _extract_state( ) +def _format_event_names(transition: "Transition") -> str: + """Build a display string for the events that trigger a transition. + + ``_expand_event_id`` registers both the Python attribute name + (``done_invoke_X``) and the SCXML dot form (``done.invoke.X``) under the + same transition. For diagram display we only want unique *semantic* events, + keeping the Python attribute name when an alias pair exists. + """ + events = list(transition.events) + if not events: + return "" + + all_ids = {str(e) for e in events} + + display: List[str] = [] + for event in events: + eid = str(event) + # Skip dot-form aliases (e.g. "done.invoke.X") when the underscore + # form ("done_invoke_X") is also registered on this transition. + if "." in eid and eid.replace(".", "_") in all_ids: + continue + if eid not in display: # pragma: no branch + display.append(eid) + + return " ".join(display) + + def _extract_transitions_from_state(state: "State") -> List[DiagramTransition]: """Extract transitions from a single state (non-recursive).""" result: List[DiagramTransition] = [] @@ -114,7 +142,7 @@ def _extract_transitions_from_state(state: "State") -> List[DiagramTransition]: DiagramTransition( source=transition.source.id, targets=target_ids, - event=transition.event, + event=_format_event_names(transition), guards=cond_strs, is_internal=transition.internal, ) diff --git a/statemachine/contrib/diagram/formatter.py b/statemachine/contrib/diagram/formatter.py new file mode 100644 index 00000000..0ce8a1b0 --- /dev/null +++ b/statemachine/contrib/diagram/formatter.py @@ -0,0 +1,137 @@ +"""Unified facade for rendering state machines in multiple text formats. + +The :class:`Formatter` class provides a decorator-based registry where each +renderer declares the format names it handles. Adding a new format only +requires writing a renderer function and decorating it — no changes to +``__format__``, ``factory.py``, or ``statemachine.py``. + +A module-level :data:`formatter` instance is the single public entry point:: + + from statemachine.contrib.diagram import formatter + + print(formatter.render(sm, "mermaid")) + + @formatter.register_format("plantuml") + def _render_plantuml(machine): + ... +""" + +from typing import TYPE_CHECKING +from typing import Callable +from typing import Dict +from typing import List + +if TYPE_CHECKING: + from typing import Union + + from statemachine.statemachine import StateChart + + MachineRef = Union["StateChart", "type[StateChart]"] + + +class Formatter: + """Unified facade for rendering state machines in multiple text formats.""" + + def __init__(self) -> None: + self._formats: Dict[str, "Callable[[MachineRef], str]"] = {} + + def register_format( + self, *names: str + ) -> "Callable[[Callable[[MachineRef], str]], Callable[[MachineRef], str]]": + """Decorator factory that registers a renderer under one or more format names. + + Usage:: + + @formatter.register_format("md", "markdown") + def _render_md(machine_or_class): + ... + """ + + def decorator( + fn: "Callable[[MachineRef], str]", + ) -> "Callable[[MachineRef], str]": + for name in names: + self._formats[name] = fn + return fn + + return decorator + + def render(self, machine_or_class: "MachineRef", fmt: str) -> str: + """Render a state machine in the given text format. + + Args: + machine_or_class: A ``StateChart`` instance or class. + fmt: Format name (e.g., ``"mermaid"``, ``"dot"``, ``"md"``). + Empty string falls back to ``repr()``. + + Raises: + ValueError: If ``fmt`` is not registered. + """ + if fmt == "": + return repr(machine_or_class) + + renderer_fn = self._formats.get(fmt) + if renderer_fn is None: + primary = sorted({self._primary_name(fn) for fn in set(self._formats.values())}) + raise ValueError( + f"Unsupported format: {fmt!r}. Use {', '.join(repr(n) for n in primary)}." + ) + return renderer_fn(machine_or_class) + + def supported_formats(self) -> List[str]: + """Return sorted list of all registered format names (including aliases).""" + return sorted(self._formats) + + def _primary_name(self, fn: "Callable[[MachineRef], str]") -> str: + """Return the first registered name for a given renderer function.""" + for name, registered_fn in self._formats.items(): + if registered_fn is fn: + return name + return "?" # pragma: no cover + + +formatter = Formatter() +"""Module-level :class:`Formatter` instance — the single public entry point.""" + + +# --------------------------------------------------------------------------- +# Built-in format registrations +# --------------------------------------------------------------------------- + + +@formatter.register_format("dot") +def _render_dot(machine_or_class: "MachineRef") -> str: + from statemachine.contrib.diagram import DotGraphMachine + + return DotGraphMachine(machine_or_class).get_graph().to_string() # type: ignore[no-any-return] + + +@formatter.register_format("svg") +def _render_svg(machine_or_class: "MachineRef") -> str: + from statemachine.contrib.diagram import DotGraphMachine + + svg_bytes: bytes = DotGraphMachine(machine_or_class).get_graph().create_svg() # type: ignore[attr-defined] + return svg_bytes.decode("utf-8") + + +@formatter.register_format("mermaid") +def _render_mermaid(machine_or_class: "MachineRef") -> str: + from statemachine.contrib.diagram import MermaidGraphMachine + + return MermaidGraphMachine(machine_or_class).get_mermaid() + + +@formatter.register_format("md", "markdown") +def _render_md(machine_or_class: "MachineRef") -> str: + from statemachine.contrib.diagram.extract import extract + from statemachine.contrib.diagram.renderers.table import TransitionTableRenderer + + return TransitionTableRenderer().render(extract(machine_or_class), fmt="md") + + +@formatter.register_format("rst") +def _render_rst(machine_or_class: "MachineRef") -> str: + from statemachine.contrib.diagram.extract import extract + from statemachine.contrib.diagram.renderers.table import TransitionTableRenderer + + return TransitionTableRenderer().render(extract(machine_or_class), fmt="rst") diff --git a/statemachine/contrib/diagram/renderers/mermaid.py b/statemachine/contrib/diagram/renderers/mermaid.py new file mode 100644 index 00000000..15ba61d5 --- /dev/null +++ b/statemachine/contrib/diagram/renderers/mermaid.py @@ -0,0 +1,348 @@ +from dataclasses import dataclass +from typing import Dict +from typing import List +from typing import Optional +from typing import Set + +from ..model import ActionType +from ..model import DiagramAction +from ..model import DiagramGraph +from ..model import DiagramState +from ..model import DiagramTransition +from ..model import StateType + + +@dataclass +class MermaidRendererConfig: + """Configuration for the Mermaid renderer.""" + + direction: str = "LR" + active_fill: str = "#40E0D0" + active_stroke: str = "#333" + + +class MermaidRenderer: + """Renders a DiagramGraph into a Mermaid stateDiagram-v2 source string. + + Mermaid's stateDiagram-v2 has a rendering bug + (`mermaid-js/mermaid#4052 `_) + where transitions whose source or target is a compound state + (``state X { ... }``) **inside a parallel region** crash with + ``Cannot set properties of undefined (setting 'rank')``. To work around + this, the renderer rewrites compound-state endpoints that are descendants + of a parallel state, redirecting them to the compound's initial child. + Compound states outside parallel regions are left unchanged. + """ + + def __init__(self, config: Optional[MermaidRendererConfig] = None): + self.config = config or MermaidRendererConfig() + self._active_ids: List[str] = [] + self._rendered_transitions: Set[tuple] = set() + self._compound_ids: Set[str] = set() + self._initial_child_map: Dict[str, str] = {} + self._parallel_descendant_ids: Set[str] = set() + self._all_descendants_map: Dict[str, Set[str]] = {} + + def render(self, graph: DiagramGraph) -> str: + """Render a DiagramGraph to a Mermaid stateDiagram-v2 string.""" + self._active_ids = [] + self._rendered_transitions = set() + self._compound_ids = graph.compound_state_ids + self._initial_child_map = self._build_initial_child_map(graph.states) + self._parallel_descendant_ids = self._collect_parallel_descendants(graph.states) + self._all_descendants_map = self._build_all_descendants_map(graph.states) + + lines: List[str] = [] + lines.append("stateDiagram-v2") + lines.append(f" direction {self.config.direction}") + + top_ids = {s.id for s in graph.states} + self._render_states(graph.states, graph.transitions, lines, indent=1) + self._render_initial_and_final(graph.states, lines, indent=1) + self._render_scope_transitions(graph.transitions, top_ids, lines, indent=1) + + if self._active_ids: + cfg = self.config + lines.append("") + lines.append(f" classDef active fill:{cfg.active_fill},stroke:{cfg.active_stroke}") + for sid in self._active_ids: + lines.append(f" {sid}:::active") + + return "\n".join(lines) + "\n" + + def _build_initial_child_map(self, states: List[DiagramState]) -> Dict[str, str]: + """Build a map from compound state ID to its initial child ID (recursive).""" + result: Dict[str, str] = {} + for state in states: + if state.children: + initial = next((c for c in state.children if c.is_initial), None) + if initial: + result[state.id] = initial.id + result.update(self._build_initial_child_map(state.children)) + return result + + @staticmethod + def _collect_parallel_descendants( + states: List[DiagramState], + inside_parallel: bool = False, + ) -> Set[str]: + """Collect IDs of all states that are descendants of a parallel state.""" + result: Set[str] = set() + for state in states: + if inside_parallel: + result.add(state.id) + child_inside = inside_parallel or state.type == StateType.PARALLEL + result.update( + MermaidRenderer._collect_parallel_descendants(state.children, child_inside) + ) + return result + + def _build_all_descendants_map(self, states: List[DiagramState]) -> Dict[str, Set[str]]: + """Map each compound state ID to the set of all its descendant IDs.""" + result: Dict[str, Set[str]] = {} + for state in states: + if state.children: + result[state.id] = self._collect_recursive_descendants(state.children) + result.update(self._build_all_descendants_map(state.children)) + return result + + @staticmethod + def _collect_recursive_descendants(states: List[DiagramState]) -> Set[str]: + """Collect all state IDs in a subtree recursively.""" + ids: Set[str] = set() + for s in states: + ids.add(s.id) + ids.update(MermaidRenderer._collect_recursive_descendants(s.children)) + return ids + + def _resolve_endpoint(self, state_id: str) -> str: + """Resolve a transition endpoint for Mermaid compatibility. + + Only redirects compound states that are inside a parallel region — + this is where Mermaid's rendering bug (mermaid-js/mermaid#4052) occurs. + Compound states outside parallel regions are left unchanged. + """ + if ( + state_id in self._compound_ids + and state_id in self._parallel_descendant_ids + and state_id in self._initial_child_map + ): + return self._initial_child_map[state_id] + return state_id + + def _render_states( + self, + states: List[DiagramState], + transitions: List[DiagramTransition], + lines: List[str], + indent: int, + ) -> None: + for state in states: + if state.type in (StateType.HISTORY_SHALLOW, StateType.HISTORY_DEEP): + label = "H*" if state.type == StateType.HISTORY_DEEP else "H" + pad = " " * indent + lines.append(f'{pad}state "{label}" as {state.id}') + continue + + if state.type == StateType.CHOICE: + pad = " " * indent + lines.append(f"{pad}state {state.id} <>") + continue + + if state.type == StateType.FORK: + pad = " " * indent + lines.append(f"{pad}state {state.id} <>") + continue + + if state.type == StateType.JOIN: + pad = " " * indent + lines.append(f"{pad}state {state.id} <>") + continue + + if state.children: + self._render_compound_state(state, transitions, lines, indent) + else: + self._render_atomic_state(state, lines, indent) + + def _render_atomic_state( + self, + state: DiagramState, + lines: List[str], + indent: int, + ) -> None: + pad = " " * indent + + if state.name != state.id: + lines.append(f'{pad}state "{state.name}" as {state.id}') + + actions = [a for a in state.actions if a.type != ActionType.INTERNAL or a.body] + if actions: + for action in actions: + lines.append(f"{pad}{state.id} : {self._format_action(action)}") + + if state.is_active: + self._active_ids.append(state.id) + + def _render_compound_state( + self, + state: DiagramState, + transitions: List[DiagramTransition], + lines: List[str], + indent: int, + ) -> None: + pad = " " * indent + + if state.type == StateType.PARALLEL: + lines.append(f'{pad}state "{state.name}" as {state.id} {{') + regions = [c for c in state.children if c.is_parallel_area or c.children] + for i, region in enumerate(regions): + if i > 0: + lines.append(f"{pad} --") + self._render_compound_state(region, transitions, lines, indent + 1) + lines.append(f"{pad}}}") + else: + label = state.name if state.name != state.id else "" + if label: + lines.append(f'{pad}state "{label}" as {state.id} {{') + else: + lines.append(f"{pad}state {state.id} {{") + + initial_child = next((c for c in state.children if c.is_initial), None) + if initial_child: + lines.append(f"{pad} [*] --> {initial_child.id}") + + self._render_states(state.children, transitions, lines, indent + 1) + + # Render transitions scoped to this compound + child_ids = self._collect_all_descendant_ids(state.children) + self._render_scope_transitions(transitions, child_ids, lines, indent + 1) + + # Final state transitions + for child in state.children: + if child.type == StateType.FINAL: + lines.append(f"{pad} {child.id} --> [*]") + + lines.append(f"{pad}}}") + + if state.is_active: + self._active_ids.append(state.id) + + def _collect_all_descendant_ids(self, states: List[DiagramState]) -> Set[str]: + """Collect all state IDs in a subtree (direct children only for scope).""" + ids: Set[str] = set() + for s in states: + ids.add(s.id) + return ids + + def _render_scope_transitions( + self, + transitions: List[DiagramTransition], + scope_ids: Set[str], + lines: List[str], + indent: int, + ) -> None: + """Render transitions that belong to this scope level. + + A transition belongs to scope S if all its endpoints are *reachable* + from S (either directly in S or descendants of a compound in S) **and** + the transition is not fully internal to a single compound in S (those + are rendered by the compound's inner scope). + + This allows cross-boundary transitions (e.g., an outer state targeting + a history pseudo-state inside a compound) to be rendered at the correct + scope level — Mermaid draws the arrow crossing the compound border. + + Mermaid crashes when the source or target is a compound state inside a + parallel region (mermaid-js/mermaid#4052). For those cases, endpoints + are redirected to the compound's initial child via ``_resolve_endpoint``. + """ + # Build the descendant sets for compounds in this scope + compound_descendants: Dict[str, Set[str]] = {} + expanded: Set[str] = set(scope_ids) + for sid in scope_ids: + if sid in self._all_descendants_map: + compound_descendants[sid] = self._all_descendants_map[sid] + expanded |= self._all_descendants_map[sid] + + for t in transitions: + if t.is_initial or t.is_internal: + continue + + targets = t.targets if t.targets else [t.source] + + # All endpoints must be reachable from this scope + if t.source not in expanded: + continue + if not all(target in expanded for target in targets): + continue + + # Skip transitions fully internal to a single compound — + # those will be rendered by the compound's inner scope. + if self._is_fully_internal(t.source, targets, compound_descendants): + continue + + # Resolve endpoints for rendering (redirect compound → initial child) + source = self._resolve_endpoint(t.source) + resolved_targets = [self._resolve_endpoint(tid) for tid in targets] + + for target in resolved_targets: + key = (source, target, t.event) + if key in self._rendered_transitions: + continue + self._rendered_transitions.add(key) + self._render_single_transition(t, source, target, lines, indent) + + @staticmethod + def _is_fully_internal( + source: str, + targets: List[str], + compound_descendants: Dict[str, Set[str]], + ) -> bool: + """Check if all endpoints belong to the same compound's descendants.""" + for descendants in compound_descendants.values(): + if source in descendants and all(tgt in descendants for tgt in targets): + return True + return False + + def _render_single_transition( + self, + transition: DiagramTransition, + source: str, + target: str, + lines: List[str], + indent: int, + ) -> None: + pad = " " * indent + label_parts: List[str] = [] + if transition.event: + label_parts.append(transition.event) + if transition.guards: + label_parts.append(f"[{', '.join(transition.guards)}]") + + label = " ".join(label_parts) + if label: + lines.append(f"{pad}{source} --> {target} : {label}") + else: + lines.append(f"{pad}{source} --> {target}") + + @staticmethod + def _format_action(action: DiagramAction) -> str: + if action.type == ActionType.INTERNAL: + return action.body + return f"{action.type.value} / {action.body}" + + def _render_initial_and_final( + self, + states: List[DiagramState], + lines: List[str], + indent: int, + ) -> None: + """Render top-level [*] --> initial and final --> [*] arrows.""" + pad = " " * indent + initial = next((s for s in states if s.is_initial), None) + if initial: + lines.append(f"{pad}[*] --> {initial.id}") + + for state in states: + if state.type == StateType.FINAL: + lines.append(f"{pad}{state.id} --> [*]") diff --git a/statemachine/contrib/diagram/renderers/table.py b/statemachine/contrib/diagram/renderers/table.py new file mode 100644 index 00000000..eeaa18ec --- /dev/null +++ b/statemachine/contrib/diagram/renderers/table.py @@ -0,0 +1,105 @@ +from typing import List + +from ..model import DiagramGraph +from ..model import DiagramState +from ..model import DiagramTransition + + +class TransitionTableRenderer: + """Renders a DiagramGraph as a transition table in markdown or RST format.""" + + def render(self, graph: DiagramGraph, fmt: str = "md") -> str: + """Render the transition table. + + Args: + graph: The diagram IR to render. + fmt: Output format — ``"md"`` for markdown, ``"rst"`` for reStructuredText. + + Returns: + The formatted transition table as a string. + """ + rows = self._collect_rows(graph.states, graph.transitions) + + if fmt == "rst": + return self._render_rst(rows) + return self._render_md(rows) + + def _collect_rows( + self, + states: List[DiagramState], + transitions: List[DiagramTransition], + ) -> "List[tuple[str, str, str, str]]": + """Collect (State, Event, Guard, Target) tuples from the IR.""" + rows: List[tuple[str, str, str, str]] = [] + state_names = self._build_state_name_map(states) + + for t in transitions: + if t.is_initial or t.is_internal: + continue + + source_name = state_names.get(t.source, t.source) + guard = ", ".join(t.guards) if t.guards else "" + event = t.event or "" + + if t.targets: + for target_id in t.targets: + target_name = state_names.get(target_id, target_id) + rows.append((source_name, event, guard, target_name)) + else: + rows.append((source_name, event, guard, source_name)) + + return rows + + def _build_state_name_map(self, states: List[DiagramState]) -> dict: + """Build a mapping from state ID to display name, recursively.""" + result: dict = {} + for state in states: + result[state.id] = state.name + if state.children: + result.update(self._build_state_name_map(state.children)) + return result + + def _render_md(self, rows: "List[tuple[str, str, str, str]]") -> str: + """Render as a markdown table.""" + headers = ("State", "Event", "Guard", "Target") + col_widths = [len(h) for h in headers] + + for row in rows: + for i, cell in enumerate(row): + col_widths[i] = max(col_widths[i], len(cell)) + + def _fmt_row(cells: "tuple[str, ...]") -> str: + parts = [cell.ljust(col_widths[i]) for i, cell in enumerate(cells)] + return "| " + " | ".join(parts) + " |" + + lines = [_fmt_row(headers)] + lines.append("| " + " | ".join("-" * w for w in col_widths) + " |") + for row in rows: + lines.append(_fmt_row(row)) + + return "\n".join(lines) + "\n" + + def _render_rst(self, rows: "List[tuple[str, str, str, str]]") -> str: + """Render as an RST grid table.""" + headers = ("State", "Event", "Guard", "Target") + col_widths = [len(h) for h in headers] + + for row in rows: + for i, cell in enumerate(row): + col_widths[i] = max(col_widths[i], len(cell)) + + def _border(char: str = "-") -> str: + return "+" + "+".join(char * (w + 2) for w in col_widths) + "+" + + def _data_row(cells: "tuple[str, ...]") -> str: + parts = [f" {cell.ljust(col_widths[i])} " for i, cell in enumerate(cells)] + return "|" + "|".join(parts) + "|" + + lines = [_border("-")] + lines.append(_data_row(headers)) + lines.append(_border("=")) + for row in rows: + lines.append(_data_row(row)) + lines.append(_border("-")) + + return "\n".join(lines) + "\n" diff --git a/statemachine/contrib/diagram/sphinx_ext.py b/statemachine/contrib/diagram/sphinx_ext.py index bbc9a8ac..84ab50f3 100644 --- a/statemachine/contrib/diagram/sphinx_ext.py +++ b/statemachine/contrib/diagram/sphinx_ext.py @@ -39,7 +39,7 @@ def _parse_events(value: str) -> list[str]: # Match the outer ... element, stripping XML prologue/DOCTYPE. -_SVG_TAG_RE = re.compile(rb"()", re.DOTALL) +_SVG_TAG_RE = re.compile(r"()", re.DOTALL) # Match fixed width/height attributes (e.g. width="702pt" height="170pt"). _SVG_WIDTH_RE = re.compile(r'\bwidth="([^"]*(?:pt|px))"') @@ -61,6 +61,7 @@ class StateMachineDiagram(SphinxDirective): option_spec: ClassVar[dict[str, Any]] = { # State-machine options "events": directives.unchanged, + "format": directives.unchanged, # Standard image/figure options "caption": directives.unchanged, "alt": directives.unchanged, @@ -78,7 +79,7 @@ def run(self) -> list[nodes.Node]: qualname = self.arguments[0] try: - from statemachine.contrib.diagram import DotGraphMachine + from statemachine.contrib.diagram import formatter from statemachine.contrib.diagram import import_sm sm_class = import_sm(qualname) @@ -97,9 +98,13 @@ def run(self) -> list[nodes.Node]: else: machine = sm_class + output_format = self.options.get("format", "").strip().lower() + + if output_format == "mermaid": + return self._run_mermaid(machine, formatter, qualname) + try: - graph = DotGraphMachine(machine).get_graph() - svg_bytes: bytes = graph.create_svg() # type: ignore[attr-defined] + svg_text = formatter.render(machine, "svg") except Exception as exc: return [ self.state_machine.reporter.warning( @@ -108,12 +113,12 @@ def run(self) -> list[nodes.Node]: ) ] - svg_tag, intrinsic_width, intrinsic_height = self._prepare_svg(svg_bytes) + svg_tag, intrinsic_width, intrinsic_height = self._prepare_svg(svg_text) svg_styles = self._build_svg_styles(intrinsic_width, intrinsic_height) svg_tag = svg_tag.replace("{svg_tag}' if target: @@ -143,10 +148,49 @@ def run(self) -> list[nodes.Node]: return [raw_node] - def _prepare_svg(self, svg_bytes: bytes) -> tuple[str, str, str]: + def _run_mermaid(self, machine: object, formatter: Any, qualname: str) -> list[nodes.Node]: + """Render a Mermaid diagram using sphinxcontrib-mermaid's node type.""" + try: + mermaid_src = formatter.render(machine, "mermaid") + except Exception as exc: + return [ + self.state_machine.reporter.warning( + f"statemachine-diagram: failed to generate mermaid for {qualname!r}: {exc}", + line=self.lineno, + ) + ] + + try: + from sphinxcontrib.mermaid import ( # type: ignore[import-untyped] + mermaid as MermaidNode, + ) + except ImportError: + # Fallback: emit a raw code block if sphinxcontrib-mermaid is not installed + code_node = nodes.literal_block(mermaid_src, mermaid_src) + code_node["language"] = "mermaid" + return [code_node] + + node = MermaidNode() + node["code"] = mermaid_src + node["options"] = {} + + caption = self.options.get("caption") + if caption: + figure_node = nodes.figure() + figure_node += node + figure_node += nodes.caption(caption, caption) + if "name" in self.options: + self.add_name(figure_node) + return [figure_node] + + if "name" in self.options: + self.add_name(node) + return [node] + + def _prepare_svg(self, svg_text: str) -> tuple[str, str, str]: """Extract the ```` element and its intrinsic dimensions.""" - match = _SVG_TAG_RE.search(svg_bytes) - svg_tag = match.group(1).decode("utf-8") if match else svg_bytes.decode("utf-8") + match = _SVG_TAG_RE.search(svg_text) + svg_tag = match.group(1) if match else svg_text width_match = _SVG_WIDTH_RE.search(svg_tag) height_match = _SVG_HEIGHT_RE.search(svg_tag) @@ -188,7 +232,7 @@ def _build_svg_styles(self, intrinsic_width: str, intrinsic_height: str) -> str: return f'style="{"; ".join(parts)}"' - def _resolve_target(self, svg_bytes: bytes) -> str: + def _resolve_target(self, svg_text: str) -> str: """Return the href for the wrapper ```` tag, if any. When ``:target:`` is given without a value (or as empty string), the @@ -211,8 +255,8 @@ def _resolve_target(self, svg_bytes: bytes) -> str: outdir = os.path.join(self.env.app.outdir, "_images") os.makedirs(outdir, exist_ok=True) outpath = os.path.join(outdir, filename) - with open(outpath, "wb") as f: - f.write(svg_bytes) + with open(outpath, "w", encoding="utf-8") as f: + f.write(svg_text) return f"/_images/{filename}" diff --git a/statemachine/factory.py b/statemachine/factory.py index d470e3bd..c29825f7 100644 --- a/statemachine/factory.py +++ b/statemachine/factory.py @@ -1,3 +1,4 @@ +import re from typing import Any from typing import Dict from typing import List @@ -91,6 +92,42 @@ def __init__( cls._check() cls._setup() + cls._expand_docstring() + + _STATECHART_RE = re.compile(r"\{statechart:(\w+)\}") + + def _expand_docstring(cls) -> None: + """Replace ``{statechart:FORMAT}`` placeholders in the class docstring.""" + doc = cls.__doc__ + if not doc: + return + + from .contrib.diagram.formatter import formatter + + def _replace(match: "re.Match[str]") -> str: + fmt = match.group(1) + rendered = formatter.render(cls, fmt) # type: ignore[arg-type] + + # Respect the indentation of the placeholder line. + line_start = doc.rfind("\n", 0, match.start()) + if line_start == -1: + indent = "" + else: + indent_match = re.match(r"[ \t]*", doc[line_start + 1 : match.start()]) + indent = indent_match.group() if indent_match else "" + + if indent: + lines = rendered.split("\n") + rendered = lines[0] + "\n" + "\n".join(indent + line for line in lines[1:]) + + return rendered + + cls.__doc__ = cls._STATECHART_RE.sub(_replace, doc) + + def __format__(cls, fmt: str) -> str: + from .contrib.diagram.formatter import formatter + + return formatter.render(cls, fmt) # type: ignore[arg-type] def _initials_by_document_order( # noqa: C901 cls, states: List[State], parent: "State | None" = None, order: int = 1 diff --git a/statemachine/statemachine.py b/statemachine/statemachine.py index c3143a84..d33ea122 100644 --- a/statemachine/statemachine.py +++ b/statemachine/statemachine.py @@ -239,6 +239,11 @@ def __repr__(self): f"configuration={configuration_ids!r})" ) + def __format__(self, fmt: str) -> str: + from .contrib.diagram.formatter import formatter + + return formatter.render(self, fmt) + def __getstate__(self): state = {k: v for k, v in self.__dict__.items() if not isinstance(v, InstanceState)} del state["_callbacks"] diff --git a/tests/machines/showcase_parallel_compound.py b/tests/machines/showcase_parallel_compound.py new file mode 100644 index 00000000..049def76 --- /dev/null +++ b/tests/machines/showcase_parallel_compound.py @@ -0,0 +1,34 @@ +from statemachine import State +from statemachine import StateChart + + +class ParallelCompoundSC(StateChart): + """Parallel regions with a cross-boundary transition into an inner compound. + + The ``rebuild`` transition targets ``pipeline.build`` — a compound state + inside a parallel region. This is the exact pattern that triggers + `mermaid-js/mermaid#4052 `_; + the Mermaid renderer works around it by redirecting the arrow to the + compound's initial child. + + {statechart:rst} + """ + + class pipeline(State.Parallel, name="Pipeline"): + class build(State.Compound, name="Build"): + compile = State(initial=True) + link = State(final=True) + do_build = compile.to(link) + + class test(State.Compound, name="Test"): + unit = State(initial=True) + e2e = State(final=True) + do_test = unit.to(e2e) + + idle = State(initial=True) + review = State() + + start = idle.to(pipeline) + done_state_pipeline = pipeline.to(review) + rebuild = review.to(pipeline.build) + accept = review.to(idle) diff --git a/tests/machines/showcase_simple.py b/tests/machines/showcase_simple.py index affc1ce1..ca99839d 100644 --- a/tests/machines/showcase_simple.py +++ b/tests/machines/showcase_simple.py @@ -3,6 +3,11 @@ class SimpleSC(StateChart): + """A simple three-state machine. + + {statechart:rst} + """ + idle = State(initial=True) running = State() done = State(final=True) diff --git a/tests/test_contrib_diagram.py b/tests/test_contrib_diagram.py index 3d3a8152..2da3cdd8 100644 --- a/tests/test_contrib_diagram.py +++ b/tests/test_contrib_diagram.py @@ -8,6 +8,7 @@ from statemachine.contrib.diagram import DotGraphMachine from statemachine.contrib.diagram import main from statemachine.contrib.diagram import quickchart_write_svg +from statemachine.contrib.diagram.extract import _format_event_names from statemachine.contrib.diagram.model import ActionType from statemachine.contrib.diagram.model import StateType from statemachine.contrib.diagram.renderers.dot import DotRenderer @@ -161,6 +162,109 @@ def test_generate_complain_about_module_without_sm(self, tmp_path): with pytest.raises(ValueError, match=expected_error): main(["tests.examples", str(out)]) + def test_format_mermaid(self, tmp_path): + out = tmp_path / "sm.mmd" + + main( + [ + "tests.examples.traffic_light_machine.TrafficLightMachine", + str(out), + "--format", + "mermaid", + ] + ) + + content = out.read_text() + assert "stateDiagram-v2" in content + assert "green --> yellow : cycle" in content + + def test_format_md(self, tmp_path): + out = tmp_path / "sm.md" + + main( + [ + "tests.examples.traffic_light_machine.TrafficLightMachine", + str(out), + "--format", + "md", + ] + ) + + content = out.read_text() + assert "| State" in content + assert "cycle" in content + + def test_format_rst(self, tmp_path): + out = tmp_path / "sm.rst" + + main( + [ + "tests.examples.traffic_light_machine.TrafficLightMachine", + str(out), + "--format", + "rst", + ] + ) + + content = out.read_text() + assert "+---" in content + assert "cycle" in content + + def test_format_mermaid_stdout(self, capsys): + main( + [ + "tests.examples.traffic_light_machine.TrafficLightMachine", + "-", + "--format", + "mermaid", + ] + ) + + captured = capsys.readouterr() + assert "stateDiagram-v2" in captured.out + + def test_format_md_stdout(self, capsys): + main( + [ + "tests.examples.traffic_light_machine.TrafficLightMachine", + "-", + "--format", + "md", + ] + ) + + captured = capsys.readouterr() + assert "| State" in captured.out + + def test_stdout_default_svg(self, capsys): + """Default format to stdout writes SVG bytes.""" + main( + [ + "tests.examples.traffic_light_machine.TrafficLightMachine", + "-", + ] + ) + + captured = capsys.readouterr() + assert "\n\n' - b'' - b"" + svg_text = ( + '\n\n' + '' + "" ) directive = self._make_directive() - svg_tag, _, _ = directive._prepare_svg(svg_bytes) + svg_tag, _, _ = directive._prepare_svg(svg_text) assert not svg_tag.startswith("" in svg_tag def test_extracts_intrinsic_dimensions(self): - svg_bytes = b'' + svg_text = '' directive = self._make_directive() - _, w, h = directive._prepare_svg(svg_bytes) + _, w, h = directive._prepare_svg(svg_text) assert w == "702pt" assert h == "170pt" def test_removes_fixed_dimensions(self): - svg_bytes = b'' + svg_text = '' directive = self._make_directive() - svg_tag, _, _ = directive._prepare_svg(svg_bytes) + svg_tag, _, _ = directive._prepare_svg(svg_text) assert 'width="702pt"' not in svg_tag assert 'height="170pt"' not in svg_tag assert "viewBox" in svg_tag def test_handles_no_dimensions(self): - svg_bytes = b'' + svg_text = '' directive = self._make_directive() - _, w, h = directive._prepare_svg(svg_bytes) + _, w, h = directive._prepare_svg(svg_text) assert w == "" assert h == "" def test_handles_px_dimensions(self): - svg_bytes = b'' + svg_text = '' directive = self._make_directive() - _, w, h = directive._prepare_svg(svg_bytes) + _, w, h = directive._prepare_svg(svg_text) assert w == "200px" assert h == "100px" @@ -927,15 +1133,15 @@ def _make_directive(self, options=None, tmp_path=None): def test_no_target_option(self): directive = self._make_directive() - assert directive._resolve_target(b"") == "" + assert directive._resolve_target("") == "" def test_explicit_target_url(self): directive = self._make_directive({"target": "https://example.com/diagram.svg"}) - assert directive._resolve_target(b"") == "https://example.com/diagram.svg" + assert directive._resolve_target("") == "https://example.com/diagram.svg" def test_empty_target_generates_file(self, tmp_path): directive = self._make_directive({"target": ""}, tmp_path=tmp_path) - svg_data = b"" + svg_data = "" result = directive._resolve_target(svg_data) assert result.startswith("/_images/statemachine-") @@ -945,21 +1151,21 @@ def test_empty_target_generates_file(self, tmp_path): images_dir = tmp_path / "_images" svg_files = list(images_dir.glob("statemachine-*.svg")) assert len(svg_files) == 1 - assert svg_files[0].read_bytes() == svg_data + assert svg_files[0].read_text(encoding="utf-8") == svg_data def test_empty_target_deterministic_filename(self, tmp_path): """Same qualname + events produces the same filename.""" directive1 = self._make_directive({"target": "", "events": "go"}, tmp_path=tmp_path) directive2 = self._make_directive({"target": "", "events": "go"}, tmp_path=tmp_path) - result1 = directive1._resolve_target(b"1") - result2 = directive2._resolve_target(b"2") + result1 = directive1._resolve_target("1") + result2 = directive2._resolve_target("2") assert result1 == result2 def test_different_events_different_filename(self, tmp_path): """Different events produce different filenames.""" d1 = self._make_directive({"target": "", "events": "a"}, tmp_path=tmp_path) d2 = self._make_directive({"target": "", "events": "b"}, tmp_path=tmp_path) - assert d1._resolve_target(b"") != d2._resolve_target(b"") + assert d1._resolve_target("") != d2._resolve_target("") class TestDirectiveRun: @@ -1131,3 +1337,450 @@ def test_graph_reflects_active_state(self): svg_root = _parse_svg(sm._graph()) yellow_node = _find_state_node(svg_root, "yellow") assert yellow_node is not None + + +class TestFormat: + """Tests for StateChart.__format__ (instance and class).""" + + def test_format_mermaid_instance(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + sm = TrafficLightMachine() + result = f"{sm:mermaid}" + assert "stateDiagram-v2" in result + assert "green:::active" in result + + def test_format_mermaid_class(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = f"{TrafficLightMachine:mermaid}" + assert "stateDiagram-v2" in result + assert "green --> yellow : cycle" in result + + def test_format_md_instance(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + sm = TrafficLightMachine() + result = f"{sm:md}" + assert "| State" in result + assert "cycle" in result + + def test_format_md_class(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = f"{TrafficLightMachine:md}" + assert "| State" in result + + def test_format_markdown_alias(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = format(TrafficLightMachine, "markdown") + assert "| State" in result + + def test_format_rst_instance(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + sm = TrafficLightMachine() + result = f"{sm:rst}" + assert "+---" in result + + def test_format_rst_class(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = f"{TrafficLightMachine:rst}" + assert "+---" in result + + def test_format_dot_instance(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + sm = TrafficLightMachine() + result = f"{sm:dot}" + assert result.startswith("digraph TrafficLightMachine {") + assert "green" in result + + def test_format_dot_class(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = f"{TrafficLightMachine:dot}" + assert result.startswith("digraph TrafficLightMachine {") + + def test_format_empty_falls_back_to_repr(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + sm = TrafficLightMachine() + result = f"{sm:}" + assert "TrafficLightMachine(" in result + + def test_format_empty_class(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = f"{TrafficLightMachine:}" + assert "TrafficLightMachine" in result + + def test_format_invalid_raises(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + sm = TrafficLightMachine() + with pytest.raises(ValueError, match="Unsupported format"): + f"{sm:invalid}" + + def test_format_invalid_class_raises(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + with pytest.raises(ValueError, match="Unsupported format"): + f"{TrafficLightMachine:invalid}" + + +class TestDocstringExpansion: + """Tests for {statechart:FORMAT} placeholder expansion in docstrings.""" + + def test_md_placeholder(self): + from statemachine.state import State + from statemachine.statemachine import StateChart + + class MyMachine(StateChart): + """Machine. + + {statechart:md} + """ + + s1 = State(initial=True) + s2 = State(final=True) + + go = s1.to(s2) + + assert "| State" in MyMachine.__doc__ + assert "{statechart:md}" not in MyMachine.__doc__ + + def test_rst_placeholder(self): + from statemachine.state import State + from statemachine.statemachine import StateChart + + class MyMachine(StateChart): + """Machine. + + {statechart:rst} + """ + + s1 = State(initial=True) + s2 = State(final=True) + + go = s1.to(s2) + + assert "+---" in MyMachine.__doc__ + assert "{statechart:rst}" not in MyMachine.__doc__ + + def test_mermaid_placeholder(self): + from statemachine.state import State + from statemachine.statemachine import StateChart + + class MyMachine(StateChart): + """{statechart:mermaid}""" + + s1 = State(initial=True) + s2 = State(final=True) + + go = s1.to(s2) + + assert "stateDiagram-v2" in MyMachine.__doc__ + + def test_no_placeholder_unchanged(self): + from statemachine.state import State + from statemachine.statemachine import StateChart + + class MyMachine(StateChart): + """Just a plain docstring.""" + + s1 = State(initial=True) + s2 = State(final=True) + + go = s1.to(s2) + + assert MyMachine.__doc__ == "Just a plain docstring." + + def test_no_docstring(self): + from statemachine.state import State + from statemachine.statemachine import StateChart + + class MyMachine(StateChart): + s1 = State(initial=True) + s2 = State(final=True) + + go = s1.to(s2) + + assert MyMachine.__doc__ is None + + def test_indentation_preserved(self): + from statemachine.state import State + from statemachine.statemachine import StateChart + + class MyMachine(StateChart): + __doc__ = "Doc.\n\n Table:\n\n {statechart:md}\n\n End.\n" + + s1 = State(initial=True) + s2 = State(final=True) + + go = s1.to(s2) + + lines = MyMachine.__doc__.split("\n") + table_lines = [line for line in lines if "|" in line] + for line in table_lines: + assert line.startswith(" |") + assert "End." in MyMachine.__doc__ + + def test_multiple_placeholders(self): + from statemachine.state import State + from statemachine.statemachine import StateChart + + class MyMachine(StateChart): + """MD: {statechart:md} + + Mermaid: {statechart:mermaid} + """ + + s1 = State(initial=True) + s2 = State(final=True) + + go = s1.to(s2) + + assert "| State" in MyMachine.__doc__ + assert "stateDiagram-v2" in MyMachine.__doc__ + + +class TestFormatter: + """Tests for the Formatter facade (render, register_format, supported_formats).""" + + def test_render_mermaid(self): + from statemachine.contrib.diagram import formatter + + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = formatter.render(TrafficLightMachine, "mermaid") + assert "stateDiagram-v2" in result + + def test_render_dot(self): + from statemachine.contrib.diagram import formatter + + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = formatter.render(TrafficLightMachine, "dot") + assert result.startswith("digraph TrafficLightMachine {") + + def test_render_svg(self): + from statemachine.contrib.diagram import formatter + + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = formatter.render(TrafficLightMachine, "svg") + assert isinstance(result, str) + assert " s1" in result + assert "s1 --> s2 : go" in result + + def test_initial_and_final(self): + graph = DiagramGraph( + name="InitFinal", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.FINAL), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="finish"), + ], + ) + result = MermaidRenderer().render(graph) + assert "[*] --> s1" in result + assert "s2 --> [*]" in result + + def test_custom_direction(self): + config = MermaidRendererConfig(direction="TB") + graph = DiagramGraph( + name="TB", + states=[DiagramState(id="a", name="A", type=StateType.REGULAR, is_initial=True)], + ) + result = MermaidRenderer(config=config).render(graph) + assert "direction TB" in result + + def test_state_name_differs_from_id(self): + graph = DiagramGraph( + name="Named", + states=[ + DiagramState( + id="my_state", name="My State", type=StateType.REGULAR, is_initial=True + ), + ], + ) + result = MermaidRenderer().render(graph) + assert 'state "My State" as my_state' in result + + def test_state_name_equals_id_no_declaration(self): + """When name == id, no explicit state declaration is emitted.""" + graph = DiagramGraph( + name="NoDecl", + states=[ + DiagramState(id="s1", name="s1", type=StateType.REGULAR, is_initial=True), + ], + ) + result = MermaidRenderer().render(graph) + assert 'state "s1"' not in result + + +class TestMermaidRendererTransitions: + """Transition rendering tests.""" + + def test_transition_with_guards(self): + graph = DiagramGraph( + name="Guards", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="go", guards=["is_ready"]), + ], + ) + result = MermaidRenderer().render(graph) + assert "s1 --> s2 : go [is_ready]" in result + + def test_eventless_transition(self): + graph = DiagramGraph( + name="Eventless", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event=""), + ], + ) + result = MermaidRenderer().render(graph) + assert "s1 --> s2\n" in result + + def test_self_transition(self): + graph = DiagramGraph( + name="SelfLoop", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s1"], event="tick"), + ], + ) + result = MermaidRenderer().render(graph) + assert "s1 --> s1 : tick" in result + + def test_targetless_transition(self): + graph = DiagramGraph( + name="Targetless", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + ], + transitions=[ + DiagramTransition(source="s1", targets=[], event="tick"), + ], + ) + result = MermaidRenderer().render(graph) + assert "s1 --> s1 : tick" in result + + def test_multi_target_transition(self): + graph = DiagramGraph( + name="Multi", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + DiagramState(id="s3", name="S3", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2", "s3"], event="split"), + ], + ) + result = MermaidRenderer().render(graph) + assert "s1 --> s2 : split" in result + assert "s1 --> s3 : split" in result + + def test_internal_transitions_skipped(self): + graph = DiagramGraph( + name="Internal", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s1"], event="check", is_internal=True), + ], + ) + result = MermaidRenderer().render(graph) + assert "s1 --> s1" not in result + + def test_initial_transitions_skipped(self): + graph = DiagramGraph( + name="InitTrans", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="", is_initial=True), + ], + ) + result = MermaidRenderer().render(graph) + # Implicit initial transitions are NOT rendered as edges + assert "s1 --> s2" not in result + + +class TestMermaidRendererActiveState: + """Active state highlighting tests.""" + + def test_active_state_class(self): + graph = DiagramGraph( + name="Active", + states=[ + DiagramState( + id="s1", name="S1", type=StateType.REGULAR, is_initial=True, is_active=True + ), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="go"), + ], + ) + result = MermaidRenderer().render(graph) + assert "classDef active" in result + assert "s1:::active" in result + assert "s2:::active" not in result + + def test_no_active_state_no_classdef(self): + graph = DiagramGraph( + name="NoActive", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + ], + ) + result = MermaidRenderer().render(graph) + assert "classDef" not in result + + def test_active_fill_config(self): + config = MermaidRendererConfig(active_fill="#FF0000", active_stroke="#000") + graph = DiagramGraph( + name="CustomActive", + states=[ + DiagramState( + id="s1", name="S1", type=StateType.REGULAR, is_initial=True, is_active=True + ), + ], + ) + result = MermaidRenderer(config=config).render(graph) + assert "fill:#FF0000" in result + assert "stroke:#000" in result + + +class TestMermaidRendererCompound: + """Compound and parallel state tests.""" + + def test_compound_state(self): + class SM(StateChart): + class parent(State.Compound, name="Parent"): + child1 = State(initial=True) + child2 = State(final=True) + go = child1.to(child2) + + start = State(initial=True) + end = State(final=True) + + enter = start.to(parent) + finish = parent.to(end) + + result = MermaidGraphMachine(SM).get_mermaid() + assert 'state "Parent" as parent {' in result + assert "[*] --> child1" in result + assert "child1 --> child2 : go" in result + assert "child2 --> [*]" in result + assert "start --> parent : enter" in result + assert "parent --> end : finish" in result + + def test_compound_no_duplicate_transitions(self): + """Transitions inside compound states must not also appear at top level.""" + + class SM(StateChart): + class parent(State.Compound, name="Parent"): + child1 = State(initial=True) + child2 = State(final=True) + go = child1.to(child2) + + start = State(initial=True) + enter = start.to(parent) + + result = MermaidGraphMachine(SM).get_mermaid() + # "child1 --> child2 : go" should appear exactly once (inside compound) + assert result.count("child1 --> child2 : go") == 1 + + def test_parallel_state(self): + class SM(StateChart): + class p(State.Parallel, name="Parallel"): + class r1(State.Compound, name="Region1"): + a = State(initial=True) + a_done = State(final=True) + finish_a = a.to(a_done) + + class r2(State.Compound, name="Region2"): + b = State(initial=True) + b_done = State(final=True) + finish_b = b.to(b_done) + + start = State(initial=True) + begin = start.to(p) + + result = MermaidGraphMachine(SM).get_mermaid() + assert 'state "Parallel" as p {' in result + assert "--" in result # parallel separator + + def test_parallel_redirects_compound_endpoints(self): + """Transitions to/from compound states inside parallel regions are redirected + to the initial child (Mermaid workaround for mermaid-js/mermaid#4052).""" + + class SM(StateChart): + class p(State.Parallel, name="Parallel"): + class region1(State.Compound, name="Region1"): + idle = State(initial=True) + + class inner(State.Compound, name="Inner"): + working = State(initial=True) + + start = idle.to(inner) + + class region2(State.Compound, name="Region2"): + x = State(initial=True) + + begin = State(initial=True) + enter = begin.to(p) + + result = MermaidGraphMachine(SM).get_mermaid() + # Inside parallel: compound endpoint redirected to initial child + assert "idle --> working : start" in result + assert "idle --> inner" not in result + + def test_compound_outside_parallel_not_redirected(self): + """Compound states outside parallel regions keep direct transitions.""" + + class SM(StateChart): + class parent(State.Compound, name="Parent"): + child = State(initial=True) + + start = State(initial=True) + end = State(final=True) + enter = start.to(parent) + leave = parent.to(end) + + result = MermaidGraphMachine(SM).get_mermaid() + assert "start --> parent : enter" in result + assert "parent --> end : leave" in result + + def test_nested_compound(self): + class SM(StateChart): + class outer(State.Compound, name="Outer"): + class inner(State.Compound, name="Inner"): + deep = State(initial=True) + deep_final = State(final=True) + go_deep = deep.to(deep_final) + + start_inner = State(initial=True) + to_inner = start_inner.to(inner) + + begin = State(initial=True) + enter = begin.to(outer) + + result = MermaidGraphMachine(SM).get_mermaid() + assert 'state "Outer" as outer {' in result + assert 'state "Inner" as inner {' in result + + +class TestMermaidRendererPseudoStates: + """Pseudo-state rendering tests.""" + + def test_history_shallow(self): + graph = DiagramGraph( + name="History", + states=[ + DiagramState( + id="comp", + name="Comp", + type=StateType.REGULAR, + is_initial=True, + children=[ + DiagramState(id="h", name="H", type=StateType.HISTORY_SHALLOW), + DiagramState(id="c1", name="C1", type=StateType.REGULAR, is_initial=True), + ], + ), + ], + compound_state_ids={"comp"}, + ) + result = MermaidRenderer().render(graph) + assert 'state "H" as h' in result + + def test_history_deep(self): + graph = DiagramGraph( + name="DeepHistory", + states=[ + DiagramState( + id="comp", + name="Comp", + type=StateType.REGULAR, + is_initial=True, + children=[ + DiagramState(id="h", name="H*", type=StateType.HISTORY_DEEP), + DiagramState(id="c1", name="C1", type=StateType.REGULAR, is_initial=True), + ], + ), + ], + compound_state_ids={"comp"}, + ) + result = MermaidRenderer().render(graph) + assert 'state "H*" as h' in result + + def test_choice_state(self): + graph = DiagramGraph( + name="Choice", + states=[ + DiagramState(id="ch", name="ch", type=StateType.CHOICE, is_initial=True), + ], + ) + result = MermaidRenderer().render(graph) + assert "state ch <>" in result + + def test_fork_state(self): + graph = DiagramGraph( + name="Fork", + states=[ + DiagramState(id="fk", name="fk", type=StateType.FORK, is_initial=True), + ], + ) + result = MermaidRenderer().render(graph) + assert "state fk <>" in result + + def test_join_state(self): + graph = DiagramGraph( + name="Join", + states=[ + DiagramState(id="jn", name="jn", type=StateType.JOIN, is_initial=True), + ], + ) + result = MermaidRenderer().render(graph) + assert "state jn <>" in result + + +class TestMermaidRendererActions: + """State action rendering tests.""" + + def test_entry_exit_actions(self): + graph = DiagramGraph( + name="Actions", + states=[ + DiagramState( + id="s1", + name="S1", + type=StateType.REGULAR, + is_initial=True, + actions=[ + DiagramAction(type=ActionType.ENTRY, body="setup"), + DiagramAction(type=ActionType.EXIT, body="cleanup"), + ], + ), + ], + ) + result = MermaidRenderer().render(graph) + assert "s1 : entry / setup" in result + assert "s1 : exit / cleanup" in result + + def test_internal_action(self): + graph = DiagramGraph( + name="InternalAction", + states=[ + DiagramState( + id="s1", + name="S1", + type=StateType.REGULAR, + is_initial=True, + actions=[ + DiagramAction(type=ActionType.INTERNAL, body="tick / handle"), + ], + ), + ], + ) + result = MermaidRenderer().render(graph) + assert "s1 : tick / handle" in result + + def test_empty_internal_action_skipped(self): + graph = DiagramGraph( + name="EmptyInternal", + states=[ + DiagramState( + id="s1", + name="S1", + type=StateType.REGULAR, + is_initial=True, + actions=[ + DiagramAction(type=ActionType.INTERNAL, body=""), + ], + ), + ], + ) + result = MermaidRenderer().render(graph) + assert "s1 : " not in result + + +class TestMermaidGraphMachine: + """Tests for the MermaidGraphMachine facade.""" + + def test_facade_returns_string(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = MermaidGraphMachine(TrafficLightMachine).get_mermaid() + assert isinstance(result, str) + assert "stateDiagram-v2" in result + + def test_facade_callable(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + facade = MermaidGraphMachine(TrafficLightMachine) + assert facade() == facade.get_mermaid() + + def test_facade_with_instance(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + sm = TrafficLightMachine() + result = MermaidGraphMachine(sm).get_mermaid() + assert "green:::active" in result + + def test_facade_custom_config(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + class Custom(MermaidGraphMachine): + direction = "TB" + active_fill = "#FF0000" + + sm = TrafficLightMachine() + result = Custom(sm).get_mermaid() + assert "direction TB" in result + assert "fill:#FF0000" in result + + +class TestMermaidRendererEdgeCases: + """Edge case tests for coverage.""" + + def test_compound_state_name_equals_id(self): + """Compound state where name == id uses unquoted declaration.""" + graph = DiagramGraph( + name="NameId", + states=[ + DiagramState( + id="comp", + name="comp", + type=StateType.REGULAR, + is_initial=True, + children=[ + DiagramState(id="c1", name="C1", type=StateType.REGULAR, is_initial=True), + ], + ), + ], + compound_state_ids={"comp"}, + ) + result = MermaidRenderer().render(graph) + assert "state comp {" in result + assert '"comp"' not in result + + def test_active_compound_state(self): + """Compound state that is active gets classDef.""" + graph = DiagramGraph( + name="ActiveComp", + states=[ + DiagramState( + id="comp", + name="Comp", + type=StateType.REGULAR, + is_initial=True, + is_active=True, + children=[ + DiagramState(id="c1", name="C1", type=StateType.REGULAR, is_initial=True), + ], + ), + ], + compound_state_ids={"comp"}, + ) + result = MermaidRenderer().render(graph) + assert "comp:::active" in result + + def test_cross_scope_transition_rendered_at_parent(self): + """Transition crossing compound boundaries is rendered at the parent scope.""" + graph = DiagramGraph( + name="CrossScope", + states=[ + DiagramState( + id="comp", + name="Comp", + type=StateType.REGULAR, + is_initial=True, + children=[ + DiagramState(id="c1", name="C1", type=StateType.REGULAR, is_initial=True), + ], + ), + DiagramState(id="outside", name="Outside", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="c1", targets=["outside"], event="leave"), + ], + compound_state_ids={"comp"}, + ) + result = MermaidRenderer().render(graph) + # c1 is inside comp, outside is at top level — the transition + # crosses the compound boundary and is rendered at the top scope. + assert "c1 --> outside : leave" in result + # It should NOT appear inside the compound block + lines = result.split("\n") + for line in lines: + if "c1 --> outside" in line: + # Should be at indent level 1 (top scope), not deeper + assert line.startswith(" c1"), f"Expected top-level indent, got: {line!r}" + + def test_cross_scope_to_history_state(self): + """Transition from outside a compound to a history state inside it is rendered.""" + graph = DiagramGraph( + name="HistoryCross", + states=[ + DiagramState( + id="process", + name="Process", + type=StateType.REGULAR, + children=[ + DiagramState( + id="step1", name="Step1", type=StateType.REGULAR, is_initial=True + ), + DiagramState(id="step2", name="Step2", type=StateType.REGULAR), + DiagramState(id="h", name="H", type=StateType.HISTORY_SHALLOW), + ], + ), + DiagramState(id="paused", name="Paused", type=StateType.REGULAR, is_initial=True), + ], + transitions=[ + DiagramTransition(source="step1", targets=["step2"], event="advance"), + DiagramTransition(source="process", targets=["paused"], event="pause"), + DiagramTransition(source="paused", targets=["h"], event="resume"), + DiagramTransition(source="paused", targets=["process"], event="begin"), + ], + compound_state_ids={"process"}, + ) + result = MermaidRenderer().render(graph) + # The resume transition crosses the compound boundary + assert "paused --> h : resume" in result + # advance stays inside the compound + assert "step1 --> step2 : advance" in result + # pause and begin are at top level (both endpoints are top-level) + assert "process --> paused : pause" in result + assert "paused --> process : begin" in result + + def test_no_initial_state(self): + """Graph with no initial state omits [*] arrow.""" + graph = DiagramGraph( + name="NoInitial", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR), + ], + ) + result = MermaidRenderer().render(graph) + assert "[*]" not in result + + def test_duplicate_transition_rendered_once(self): + """Duplicate transitions in the IR are rendered only once.""" + graph = DiagramGraph( + name="Dedup", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="go"), + DiagramTransition(source="s1", targets=["s2"], event="go"), + ], + ) + result = MermaidRenderer().render(graph) + assert result.count("s1 --> s2 : go") == 1 + + def test_compound_no_initial_child(self): + """Compound state with no initial child omits internal [*] arrow.""" + graph = DiagramGraph( + name="NoInitChild", + states=[ + DiagramState( + id="comp", + name="Comp", + type=StateType.REGULAR, + is_initial=True, + children=[ + DiagramState(id="c1", name="C1", type=StateType.REGULAR), + ], + ), + ], + compound_state_ids={"comp"}, + ) + result = MermaidRenderer().render(graph) + # No [*] --> c1 inside the compound + lines = result.strip().split("\n") + inner_initial = [ln for ln in lines if "[*] --> c1" in ln] + assert len(inner_initial) == 0 + + +class TestMermaidRendererIntegration: + """Integration tests with real state machines.""" + + def test_traffic_light(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + result = MermaidGraphMachine(TrafficLightMachine).get_mermaid() + assert "green --> yellow : cycle" in result + assert "yellow --> red : cycle" in result + assert "red --> green : cycle" in result + + def test_traffic_light_with_events(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + sm = TrafficLightMachine() + sm.send("cycle") + result = MermaidGraphMachine(sm).get_mermaid() + assert "yellow:::active" in result diff --git a/tests/test_transition_table.py b/tests/test_transition_table.py new file mode 100644 index 00000000..198cf495 --- /dev/null +++ b/tests/test_transition_table.py @@ -0,0 +1,201 @@ +from statemachine.contrib.diagram.extract import extract +from statemachine.contrib.diagram.model import DiagramGraph +from statemachine.contrib.diagram.model import DiagramState +from statemachine.contrib.diagram.model import DiagramTransition +from statemachine.contrib.diagram.model import StateType +from statemachine.contrib.diagram.renderers.table import TransitionTableRenderer + +from statemachine import State +from statemachine import StateChart + + +class TestTransitionTableMarkdown: + """Markdown transition table tests.""" + + def test_simple_table(self): + graph = DiagramGraph( + name="Simple", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="go"), + ], + ) + result = TransitionTableRenderer().render(graph, fmt="md") + assert "| State" in result + assert "| Event" in result + assert "| Guard" in result + assert "| Target" in result + assert "| S1" in result + assert "go" in result + assert "| S2" in result + + def test_with_guards(self): + graph = DiagramGraph( + name="Guards", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="go", guards=["is_ready"]), + ], + ) + result = TransitionTableRenderer().render(graph, fmt="md") + assert "is_ready" in result + + def test_multiple_targets(self): + graph = DiagramGraph( + name="Multi", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + DiagramState(id="s3", name="S3", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2", "s3"], event="split"), + ], + ) + result = TransitionTableRenderer().render(graph, fmt="md") + lines = result.strip().split("\n") + # Header + separator + 2 data rows + assert len(lines) == 4 + + def test_skips_initial_transitions(self): + graph = DiagramGraph( + name="SkipInit", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="", is_initial=True), + DiagramTransition(source="s1", targets=["s2"], event="go"), + ], + ) + result = TransitionTableRenderer().render(graph, fmt="md") + lines = result.strip().split("\n") + # Header + separator + 1 data row (initial skipped) + assert len(lines) == 3 + + def test_skips_internal_transitions(self): + graph = DiagramGraph( + name="SkipInternal", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s1"], event="check", is_internal=True), + ], + ) + result = TransitionTableRenderer().render(graph, fmt="md") + lines = result.strip().split("\n") + # Header + separator only (no data rows) + assert len(lines) == 2 + + def test_targetless_transition(self): + graph = DiagramGraph( + name="Targetless", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + ], + transitions=[ + DiagramTransition(source="s1", targets=[], event="tick"), + ], + ) + result = TransitionTableRenderer().render(graph, fmt="md") + assert "tick" in result + # Target falls back to source name + assert "S1" in result + + +class TestTransitionTableRST: + """RST grid table tests.""" + + def test_rst_format(self): + graph = DiagramGraph( + name="RST", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="go"), + ], + ) + result = TransitionTableRenderer().render(graph, fmt="rst") + assert "+---" in result + assert "|" in result + assert "====" in result # header separator + assert "go" in result + + def test_rst_with_guards(self): + graph = DiagramGraph( + name="RSTGuards", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="go", guards=["is_ready"]), + ], + ) + result = TransitionTableRenderer().render(graph, fmt="rst") + assert "is_ready" in result + + +class TestTransitionTableIntegration: + """Integration tests with real state machines.""" + + def test_traffic_light_md(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + ir = extract(TrafficLightMachine) + result = TransitionTableRenderer().render(ir, fmt="md") + assert "Green" in result + assert "Yellow" in result + assert "Red" in result + assert "cycle" in result + + def test_traffic_light_rst(self): + from tests.examples.traffic_light_machine import TrafficLightMachine + + ir = extract(TrafficLightMachine) + result = TransitionTableRenderer().render(ir, fmt="rst") + assert "Green" in result + assert "cycle" in result + assert "+---" in result + + def test_compound_state_names(self): + """Child state names are properly resolved.""" + + class SM(StateChart): + class parent(State.Compound, name="Parent"): + child1 = State(initial=True) + child2 = State(final=True) + go = child1.to(child2) + + start = State(initial=True) + enter = start.to(parent) + + ir = extract(SM) + result = TransitionTableRenderer().render(ir, fmt="md") + assert "Child1" in result + assert "Child2" in result + + def test_default_format_is_md(self): + """render() without fmt defaults to markdown.""" + graph = DiagramGraph( + name="Default", + states=[ + DiagramState(id="s1", name="S1", type=StateType.REGULAR, is_initial=True), + DiagramState(id="s2", name="S2", type=StateType.REGULAR), + ], + transitions=[ + DiagramTransition(source="s1", targets=["s2"], event="go"), + ], + ) + result = TransitionTableRenderer().render(graph) + assert "| State" in result # markdown uses pipes diff --git a/uv.lock b/uv.lock index 4da05165..8ccdb666 100644 --- a/uv.lock +++ b/uv.lock @@ -1114,6 +1114,8 @@ dev = [ { name = "sphinx-autobuild" }, { name = "sphinx-copybutton" }, { name = "sphinx-gallery" }, + { name = "sphinxcontrib-mermaid", version = "1.2.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "sphinxcontrib-mermaid", version = "2.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, ] [package.metadata] @@ -1147,6 +1149,7 @@ dev = [ { name = "sphinx-autobuild", marker = "python_full_version >= '3.9'" }, { name = "sphinx-copybutton", marker = "python_full_version >= '3.9'", specifier = ">=0.5.2" }, { name = "sphinx-gallery", marker = "python_full_version >= '3.9'" }, + { name = "sphinxcontrib-mermaid", marker = "python_full_version >= '3.9'" }, ] [[package]] @@ -1389,6 +1392,39 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c2/42/4c8646762ee83602e3fb3fbe774c2fac12f317deb0b5dbeeedd2d3ba4b77/sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl", hash = "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178", size = 5071, upload-time = "2019-01-21T16:10:14.333Z" }, ] +[[package]] +name = "sphinxcontrib-mermaid" +version = "1.2.3" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.10'", +] +dependencies = [ + { name = "pyyaml", marker = "python_full_version < '3.10'" }, + { name = "sphinx", marker = "python_full_version < '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f5/49/c6ddfe709a4ab76ac6e5a00e696f73626b2c189dc1e1965a361ec102e6cc/sphinxcontrib_mermaid-1.2.3.tar.gz", hash = "sha256:358699d0ec924ef679b41873d9edd97d0773446daf9760c75e18dc0adfd91371", size = 18885, upload-time = "2025-11-26T04:18:32.43Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/39/8b54299ffa00e597d3b0b4d042241a0a0b22cb429ad007ccfb9c1745b4d1/sphinxcontrib_mermaid-1.2.3-py3-none-any.whl", hash = "sha256:5be782b27026bef97bfb15ccb2f7868b674a1afc0982b54cb149702cfc25aa02", size = 13413, upload-time = "2025-11-26T04:18:31.269Z" }, +] + +[[package]] +name = "sphinxcontrib-mermaid" +version = "2.0.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.10'", +] +dependencies = [ + { name = "jinja2", marker = "python_full_version >= '3.10'" }, + { name = "pyyaml", marker = "python_full_version >= '3.10'" }, + { name = "sphinx", marker = "python_full_version >= '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2b/ae/999891de292919b66ea34f2c22fc22c9be90ab3536fbc0fca95716277351/sphinxcontrib_mermaid-2.0.1.tar.gz", hash = "sha256:a21a385a059a6cafd192aa3a586b14bf5c42721e229db67b459dc825d7f0a497", size = 19839, upload-time = "2026-03-05T14:10:41.901Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/03/46/25d64bcd7821c8d6f1080e1c43d5fcdfc442a18f759a230b5ccdc891093e/sphinxcontrib_mermaid-2.0.1-py3-none-any.whl", hash = "sha256:9dca7fbe827bad5e7e2b97c4047682cfd26e3e07398cfdc96c7a8842ae7f06e7", size = 14064, upload-time = "2026-03-05T14:10:40.533Z" }, +] + [[package]] name = "sphinxcontrib-qthelp" version = "2.0.0"