diff --git a/cli/alora/commands.py b/cli/alora/commands.py index 4b2f537e3..742a26bc2 100644 --- a/cli/alora/commands.py +++ b/cli/alora/commands.py @@ -1,10 +1,10 @@ -"""Typer sub-application for the ``m alora`` command group. +"""Typer sub-application for the `m alora` command group. -Provides three commands: ``train`` (fine-tune a base causal language model on a JSONL -dataset to produce a LoRA or aLoRA adapter), ``upload`` (push adapter weights to +Provides three commands: `train` (fine-tune a base causal language model on a JSONL +dataset to produce a LoRA or aLoRA adapter), `upload` (push adapter weights to Hugging Face Hub, optionally packaging the adapter as an intrinsic with an -``io.yaml`` configuration), and ``add-readme`` (use an LLM to auto-generate and -upload an ``INTRINSIC_README.md`` for the trained adapter). +`io.yaml` configuration), and `add-readme` (use an LLM to auto-generate and +upload an `INTRINSIC_README.md` for the trained adapter). """ import json @@ -38,8 +38,8 @@ def alora_train( basemodel: Base model ID or path. outfile: Path to save adapter weights. promptfile: Path to load the prompt format file. - adapter: Adapter type; ``"alora"`` or ``"lora"``. - device: Device to train on: ``"auto"``, ``"cpu"``, ``"cuda"``, or ``"mps"``. + adapter: Adapter type; `"alora"` or `"lora"`. + device: Device to train on: `"auto"`, `"cpu"`, `"cuda"`, or `"mps"`. epochs: Number of training epochs. learning_rate: Learning rate for the optimizer. batch_size: Per-device training batch size. @@ -84,10 +84,10 @@ def alora_upload( Args: weight_path: Path to saved adapter weights directory. name: Destination model name on Hugging Face Hub - (e.g. ``"acme/carbchecker-alora"``). - intrinsic: If ``True``, the adapter implements an intrinsic and an - ``io.yaml`` file must also be provided. - io_yaml: Path to the ``io.yaml`` file configuring input/output processing + (e.g. `"acme/carbchecker-alora"`). 
+ intrinsic: If `True`, the adapter implements an intrinsic and an + `io.yaml` file must also be provided. + io_yaml: Path to the `io.yaml` file configuring input/output processing when the model is invoked as an intrinsic. """ from cli.alora.intrinsic_uploader import upload_intrinsic @@ -143,10 +143,10 @@ def alora_add_readme( Args: datafile: JSONL file with item/label pairs used to train the adapter. basemodel: Base model ID or path. - promptfile: Path to the prompt format file, or ``None``. + promptfile: Path to the prompt format file, or `None`. name: Destination model name on Hugging Face Hub. - hints: Path to a file containing additional domain hints, or ``None``. - io_yaml: Path to the ``io.yaml`` intrinsic configuration file, or ``None``. + hints: Path to a file containing additional domain hints, or `None`. + io_yaml: Path to the `io.yaml` intrinsic configuration file, or `None`. Raises: OSError: If no Hugging Face authentication token is found. diff --git a/cli/alora/intrinsic_uploader.py b/cli/alora/intrinsic_uploader.py index fe7c52a88..6eae2eb10 100644 --- a/cli/alora/intrinsic_uploader.py +++ b/cli/alora/intrinsic_uploader.py @@ -1,10 +1,10 @@ """Upload a trained adapter to Hugging Face Hub in the intrinsic directory layout. Creates or updates a private Hugging Face repository and uploads adapter weights -into a ``//`` sub-directory, together with -the required ``io.yaml`` configuration file. If an ``INTRINSIC_README.md`` exists in -the weight directory it is also uploaded as the repository's root ``README.md``. -Requires an authenticated Hugging Face token obtained via ``huggingface-cli login``. +into a `//` sub-directory, together with +the required `io.yaml` configuration file. If an `INTRINSIC_README.md` exists in +the weight directory it is also uploaded as the repository's root `README.md`. +Requires an authenticated Hugging Face token obtained via `huggingface-cli login`. 
""" import os @@ -27,31 +27,31 @@ def upload_intrinsic( """Upload an adapter to Hugging Face Hub using the intrinsic directory layout. Creates or updates a private Hugging Face repository and uploads adapter - weights into a ``//`` sub-directory, - together with the ``io.yaml`` configuration file. If an - ``INTRINSIC_README.md`` exists in the weight directory it is also uploaded - as the repository root ``README.md``. + weights into a `//` sub-directory, + together with the `io.yaml` configuration file. If an + `INTRINSIC_README.md` exists in the weight directory it is also uploaded + as the repository root `README.md`. Args: weight_path (str): Local directory containing the adapter weights - (output of ``save_pretrained``). + (output of `save_pretrained`). model_name (str): Target Hugging Face repository name in - ``"/"`` format (e.g. ``"acme/carbchecker-alora"``). + `"/"` format (e.g. `"acme/carbchecker-alora"`). base_model (str): Base model ID or path (e.g. - ``"ibm-granite/granite-3.3-2b-instruct"``). Must contain at most - one ``"/"`` separator. + `"ibm-granite/granite-3.3-2b-instruct"`). Must contain at most + one `"/"` separator. type (Literal["lora", "alora"]): Adapter type, used as the leaf directory name in the repository layout. - io_yaml (str): Path to the ``io.yaml`` configuration file for + io_yaml (str): Path to the `io.yaml` configuration file for intrinsic input/output processing. private (bool): Whether the repository should be private. Currently - only ``True`` is supported. + only `True` is supported. Raises: - AssertionError: If ``weight_path`` or ``io_yaml`` do not exist, if - ``private`` is ``False``, if ``base_model`` contains more than one - ``"/"`` separator, or if ``model_name`` does not contain exactly - one ``"/"`` separator. + AssertionError: If `weight_path` or `io_yaml` do not exist, if + `private` is `False`, if `base_model` contains more than one + `"/"` separator, or if `model_name` does not contain exactly + one `"/"` separator. 
OSError: If no Hugging Face authentication token is found. """ try: diff --git a/cli/alora/readme_generator.py b/cli/alora/readme_generator.py index de2643e20..bda96987a 100644 --- a/cli/alora/readme_generator.py +++ b/cli/alora/readme_generator.py @@ -1,10 +1,10 @@ """LLM-assisted generator for adapter intrinsic README files. -Uses a ``MelleaSession`` with rejection sampling to derive README template variables +Uses a `MelleaSession` with rejection sampling to derive README template variables from a JSONL training dataset — including a high-level description, the inferred Python argument list, and Jinja2-renderable sample rows. Validates the generated output with deterministic requirements (correct naming conventions, syntactically -valid argument lists) before rendering the final ``INTRINSIC_README.md`` via a +valid argument lists) before rendering the final `INTRINSIC_README.md` via a Jinja2 template. """ @@ -28,10 +28,10 @@ class ReadmeTemplateVars(BaseModel): high_level_description (str): A 2-3 sentence description of what the intrinsic adapter does. dataset_description (str): Brief description of the training dataset contents and format. userid (str): HuggingFace user ID (the namespace portion of the model name). - intrinsic_name (str): Short snake_case identifier for the intrinsic (e.g. ``"carbchecker"``). - intrinsic_name_camelcase (str): CamelCase version of ``intrinsic_name`` (e.g. ``"CarbChecker"``). - arglist (str): Python function argument list with type hints (e.g. ``"description: str"``). - arglist_without_type_annotations (str): Argument list without type hints (e.g. ``"description"``). + intrinsic_name (str): Short snake_case identifier for the intrinsic (e.g. `"carbchecker"`). + intrinsic_name_camelcase (str): CamelCase version of `intrinsic_name` (e.g. `"CarbChecker"`). + arglist (str): Python function argument list with type hints (e.g. `"description: str"`). + arglist_without_type_annotations (str): Argument list without type hints (e.g. 
`"description"`). """ high_level_description: str @@ -141,20 +141,20 @@ def make_readme_jinja_dict( """Generate all template variables for the intrinsic README using an LLM. Loads the first five lines of the JSONL dataset, determines the input structure, - and uses ``m.instruct`` with deterministic requirements and rejection sampling to + and uses `m.instruct` with deterministic requirements and rejection sampling to generate README template variables. Args: - m: Active ``MelleaSession`` to use for LLM generation. + m: Active `MelleaSession` to use for LLM generation. dataset_path: Path to the JSONL training dataset file. base_model: Base model ID or path used to train the adapter. prompt_file: Path to the prompt format file (empty string if not provided). name: Destination model name on Hugging Face Hub - (e.g. ``"acme/carbchecker-alora"``). + (e.g. `"acme/carbchecker-alora"`). hints: Optional string of additional domain hints to include in the prompt. Returns: - Dict of Jinja2 template variables for rendering the ``INTRINSIC_README.md``. + Dict of Jinja2 template variables for rendering the `INTRINSIC_README.md`. """ # Load first 5 lines of the dataset. samples = [] @@ -294,19 +294,19 @@ def generate_readme( ) -> str: """Generate an INTRINSIC_README.md file from the dataset and template. - Creates a ``MelleaSession``, uses the LLM to generate template variables, - renders the Jinja template, and writes the result to ``output_path``. + Creates a `MelleaSession`, uses the LLM to generate template variables, + renders the Jinja template, and writes the result to `output_path`. Args: dataset_path: Path to the JSONL training dataset file. base_model: Base model ID or path used to train the adapter. - prompt_file: Path to the prompt format file, or ``None``. + prompt_file: Path to the prompt format file, or `None`. output_path: Destination path for the generated README file. name: Destination model name on Hugging Face Hub. 
hints: Optional string of additional domain hints for the LLM. Returns: - The path to the written output file (same as ``output_path``). + The path to the written output file (same as `output_path`). """ from jinja2 import Environment, FileSystemLoader diff --git a/cli/alora/train.py b/cli/alora/train.py index bda54ae39..a6a9bb203 100644 --- a/cli/alora/train.py +++ b/cli/alora/train.py @@ -1,10 +1,10 @@ """Fine-tune a causal language model to produce a LoRA or aLoRA adapter. -Loads a JSONL dataset of ``item``/``label`` pairs, applies an 80/20 train/validation -split, and trains using HuggingFace PEFT and TRL's ``SFTTrainer`` — saving the +Loads a JSONL dataset of `item`/`label` pairs, applies an 80/20 train/validation +split, and trains using HuggingFace PEFT and TRL's `SFTTrainer` — saving the checkpoint with the lowest validation loss. Supports CUDA, MPS (macOS, PyTorch ≥ 2.8), and CPU device selection, and handles the -``alora_invocation_tokens`` configuration required for aLoRA training. +`alora_invocation_tokens` configuration required for aLoRA training. """ import json @@ -42,18 +42,18 @@ def load_dataset_from_json(json_path, tokenizer, invocation_prompt): """Load a JSONL dataset and format it for SFT training. - Reads ``item``/``label`` pairs from a JSONL file and builds a HuggingFace - ``Dataset`` with ``input`` and ``target`` columns. Each input is formatted as - ``"{item}\\nRequirement: <|end_of_text|>\\n{invocation_prompt}"``. + Reads `item`/`label` pairs from a JSONL file and builds a HuggingFace + `Dataset` with `input` and `target` columns. Each input is formatted as + `"{item}\\nRequirement: <|end_of_text|>\\n{invocation_prompt}"`. Args: - json_path: Path to the JSONL file containing ``item``/``label`` pairs. + json_path: Path to the JSONL file containing `item`/`label` pairs. tokenizer: HuggingFace tokenizer instance (currently unused, reserved for future tokenization steps). invocation_prompt: Invocation string appended to each input prompt. 
Returns: - A HuggingFace ``Dataset`` with ``"input"`` and ``"target"`` string columns. + A HuggingFace `Dataset` with `"input"` and `"target"` string columns. """ data = [] with open(json_path, encoding="utf-8") as f: @@ -77,12 +77,12 @@ def formatting_prompts_func(example): """Concatenate input and target columns for SFT prompt formatting. Args: - example: A batch dict with ``"input"`` and ``"target"`` list fields, as - produced by HuggingFace ``Dataset.map`` in batched mode. + example: A batch dict with `"input"` and `"target"` list fields, as + produced by HuggingFace `Dataset.map` in batched mode. Returns: - A list of strings, each formed by concatenating the ``input`` and - ``target`` values for a single example in the batch. + A list of strings, each formed by concatenating the `input` and + `target` values for a single example in the batch. """ return [ f"{example['input'][i]}{example['target'][i]}" @@ -95,7 +95,7 @@ class SaveBestModelCallback(TrainerCallback): Attributes: best_eval_loss (float): Lowest evaluation loss seen so far across all - evaluation steps. Initialised to ``float("inf")``. + evaluation steps. Initialised to `float("inf")`. """ def __init__(self): @@ -105,18 +105,18 @@ def on_evaluate(self, args, state, control, **kwargs): """Save the adapter weights if the current evaluation loss is a new best. Called automatically by the HuggingFace Trainer after each evaluation - step. Compares the current ``eval_loss`` from ``metrics`` against - ``best_eval_loss`` and, if lower, updates the stored best and saves the - model to ``args.output_dir``. + step. Compares the current `eval_loss` from `metrics` against + `best_eval_loss` and, if lower, updates the stored best and saves the + model to `args.output_dir`. Args: - args: ``TrainingArguments`` instance with training configuration, - including ``output_dir``. - state: ``TrainerState`` instance with the current training state. - control: ``TrainerControl`` instance for controlling training flow. 
+ args: `TrainingArguments` instance with training configuration, + including `output_dir`. + state: `TrainerState` instance with the current training state. + control: `TrainerControl` instance for controlling training flow. **kwargs: Additional keyword arguments provided by the Trainer, - including ``"model"`` (the current PEFT model) and - ``"metrics"`` (a dict containing ``"eval_loss"``). + including `"model"` (the current PEFT model) and + `"metrics"` (a dict containing `"eval_loss"`). """ model = kwargs["model"] metrics = kwargs["metrics"] @@ -132,13 +132,13 @@ class SafeSaveTrainer(SFTTrainer): def save_model(self, output_dir: str | None = None, _internal_call: bool = False): """Save the model and tokenizer with safe serialization always enabled. - Overrides ``SFTTrainer.save_model`` to call ``save_pretrained`` with - ``safe_serialization=True``, ensuring weights are saved in safetensors + Overrides `SFTTrainer.save_model` to call `save_pretrained` with + `safe_serialization=True`, ensuring weights are saved in safetensors format rather than the legacy pickle-based format. Args: output_dir (str | None): Directory to save the model into. If - ``None``, the trainer's configured ``output_dir`` is used. + `None`, the trainer's configured `output_dir` is used. _internal_call (bool): Internal flag passed through from the Trainer base class; not used by this override. """ @@ -165,8 +165,8 @@ def train_model( """Fine-tune a causal language model to produce a LoRA or aLoRA adapter. Loads and 80/20-splits the JSONL dataset, configures PEFT with the specified - adapter type, trains using ``SFTTrainer`` with a best-checkpoint callback, saves - the adapter weights, and removes the PEFT-generated ``README.md`` from the output + adapter type, trains using `SFTTrainer` with a best-checkpoint callback, saves + the adapter weights, and removes the PEFT-generated `README.md` from the output directory. 
Args: @@ -174,11 +174,11 @@ def train_model( base_model: Hugging Face model ID or local path to the base model. output_file: Destination path for the trained adapter weights. prompt_file: Optional path to a JSON config file with an - ``"invocation_prompt"`` key. Defaults to the aLoRA invocation token. - adapter: Adapter type to train -- ``"alora"`` (default) or ``"lora"``. - device: Device selection -- ``"auto"``, ``"cpu"``, ``"cuda"``, or - ``"mps"``. - run_name: Name of the training run (passed to ``SFTConfig``). + `"invocation_prompt"` key. Defaults to the aLoRA invocation token. + adapter: Adapter type to train -- `"alora"` (default) or `"lora"`. + device: Device selection -- `"auto"`, `"cpu"`, `"cuda"`, or + `"mps"`. + run_name: Name of the training run (passed to `SFTConfig`). epochs: Number of training epochs. learning_rate: Optimizer learning rate. batch_size: Per-device training batch size. @@ -186,10 +186,10 @@ def train_model( grad_accum: Gradient accumulation steps. Raises: - ValueError: If ``device`` is not one of ``"auto"``, ``"cpu"``, - ``"cuda"``, or ``"mps"``. + ValueError: If `device` is not one of `"auto"`, `"cpu"`, + `"cuda"`, or `"mps"`. RuntimeError: If the GPU has insufficient VRAM to load the model - (wraps ``NotImplementedError`` for meta tensor errors). + (wraps `NotImplementedError` for meta tensor errors). """ if prompt_file: # load the configurable variable invocation_prompt diff --git a/cli/alora/upload.py b/cli/alora/upload.py index 53425b582..58a9ca3b1 100644 --- a/cli/alora/upload.py +++ b/cli/alora/upload.py @@ -1,9 +1,9 @@ """Upload a trained LoRA or aLoRA adapter to Hugging Face Hub. Creates the target repository if it does not already exist and pushes the entire -adapter weights directory (output of ``save_pretrained``) to the repository root. -Requires an authenticated Hugging Face token set via the ``HF_TOKEN`` environment -variable or ``huggingface-cli login``. 
+adapter weights directory (output of `save_pretrained`) to the repository root. +Requires an authenticated Hugging Face token set via the `HF_TOKEN` environment +variable or `huggingface-cli login`. """ import os @@ -20,7 +20,7 @@ def upload_model(weight_path: str, model_name: str, private: bool = True): private (bool): Whether the repo should be private. Default: True. Raises: - FileNotFoundError: If ``weight_path`` does not exist on disk. + FileNotFoundError: If `weight_path` does not exist on disk. OSError: If no Hugging Face authentication token is found. RuntimeError: If creating or accessing the Hugging Face repository fails. """ diff --git a/cli/decompose/__init__.py b/cli/decompose/__init__.py index 76d37d5bf..f538fbd44 100644 --- a/cli/decompose/__init__.py +++ b/cli/decompose/__init__.py @@ -1,10 +1,10 @@ -"""Typer sub-application for the ``m decompose`` command group. +"""Typer sub-application for the `m decompose` command group. -Exposes a single ``run`` command that takes a task prompt (from a file or +Exposes a single `run` command that takes a task prompt (from a file or interactively), calls the LLM-based decomposition pipeline to break it into structured subtasks with constraints and dependency ordering, and writes the results as a JSON data file and a ready-to-run Python script. Invoke via -``m decompose run --help`` for full option documentation. +`m decompose run --help` for full option documentation. """ import typer diff --git a/cli/decompose/decompose.py b/cli/decompose/decompose.py index 4e3f0112f..561b53a23 100644 --- a/cli/decompose/decompose.py +++ b/cli/decompose/decompose.py @@ -1,4 +1,4 @@ -"""Implementation of the ``m decompose run`` CLI command. +"""Implementation of the `m decompose run` CLI command. 
Accepts a task prompt (from a text file or interactive input), calls the multi-step LLM decomposition pipeline to produce a structured list of subtasks each with @@ -25,7 +25,7 @@ class DecompVersion(StrEnum): """Available versions of the decomposition pipeline template. - Newer versions must be declared last to ensure ``latest`` always resolves to + Newer versions must be declared last to ensure `latest` always resolves to the most recent template. Attributes: @@ -267,32 +267,32 @@ def run( Reads the task prompt either from a file or interactively, runs the LLM decomposition pipeline to produce subtask descriptions, Jinja2 prompt templates, constraint lists, and dependency metadata, validates variable ordering, then - writes a ``{out_name}.json`` result file and a rendered ``{out_name}.py`` + writes a `{out_name}.json` result file and a rendered `{out_name}.py` Python script to the output directory. Args: out_dir: Path to an existing directory where output files are saved. out_name: Base name (no extension) for the output files. Defaults to - ``"m_decomp_result"``. + `"m_decomp_result"`. prompt_file: Optional path to a raw-text file containing the task prompt. If omitted, the prompt is collected interactively. model_id: Model name or ID used for all decomposition pipeline steps. - backend: Inference backend -- ``"ollama"`` or ``"openai"``. + backend: Inference backend -- `"ollama"` or `"openai"`. backend_req_timeout: Request timeout in seconds for model inference calls. backend_endpoint: Base URL of the OpenAI-compatible endpoint. Required - when ``backend="openai"``. + when `backend="openai"`. backend_api_key: API key for the configured endpoint. Required when - ``backend="openai"``. + `backend="openai"`. version: Version of the decomposition pipeline template to use. - input_var: Optional list of user-input variable names (e.g. ``"DOC"``). + input_var: Optional list of user-input variable names (e.g. `"DOC"`). Each name must be a valid Python identifier. 
Pass this option multiple times to define multiple variables. Raises: - AssertionError: If ``out_name`` contains invalid characters, if - ``out_dir`` does not exist or is not a directory, or if any - ``input_var`` name is not a valid Python identifier. - ValueError: If a required input variable is missing from ``input_var`` + AssertionError: If `out_name` contains invalid characters, if + `out_dir` does not exist or is not a directory, or if any + `input_var` name is not a valid Python identifier. + ValueError: If a required input variable is missing from `input_var` or if circular dependencies are detected among subtasks. Exception: Re-raised from the decomposition pipeline after cleaning up any partially written output files. diff --git a/cli/decompose/pipeline.py b/cli/decompose/pipeline.py index 452e16a5b..ed159ae83 100644 --- a/cli/decompose/pipeline.py +++ b/cli/decompose/pipeline.py @@ -1,8 +1,8 @@ """Core decomposition pipeline that breaks a task prompt into structured subtasks. -Provides the ``decompose()`` function, which orchestrates a series of LLM calls +Provides the `decompose()` function, which orchestrates a series of LLM calls (subtask listing, constraint extraction, validation strategy selection, prompt -generation, and constraint assignment) to produce a ``DecompPipelineResult`` +generation, and constraint assignment) to produce a `DecompPipelineResult` containing subtasks, per-subtask prompts, constraints, and dependency information. Supports Ollama, OpenAI-compatible, and RITS inference backends. """ @@ -35,7 +35,7 @@ class ConstraintResult(TypedDict): Attributes: constraint (str): Natural-language description of the constraint. validation_strategy (str): Strategy assigned to validate the constraint; - either ``"code"`` or ``"llm"``. + either `"code"` or `"llm"`. 
""" constraint: str @@ -52,11 +52,11 @@ class DecompSubtasksResult(TypedDict): constraints (list[ConstraintResult]): List of constraints assigned to this subtask, each with a validation strategy. prompt_template (str): Jinja2 prompt template string for this subtask, - with ``{{ variable }}`` placeholders for inputs and prior subtask results. + with `{{ variable }}` placeholders for inputs and prior subtask results. input_vars_required (list[str]): Ordered list of user-provided input - variable names referenced in ``prompt_template``. + variable names referenced in `prompt_template`. depends_on (list[str]): Ordered list of subtask tags whose results are - referenced in ``prompt_template``. + referenced in `prompt_template`. generated_response (str): Optional field holding the model response produced during execution; not present until the subtask runs. """ @@ -131,17 +131,17 @@ def decompose( task_prompt: Natural-language description of the task to decompose. user_input_variable: Optional list of variable names that will be templated into generated prompts as user-provided input data. Pass - ``None`` or an empty list if the task requires no input variables. + `None` or an empty list if the task requires no input variables. model_id: Model name or ID used for all pipeline steps. - backend: Inference backend -- ``"ollama"``, ``"openai"``, or ``"rits"``. + backend: Inference backend -- `"ollama"`, `"openai"`, or `"rits"`. backend_req_timeout: Request timeout in seconds for model inference calls. backend_endpoint: Base URL of the OpenAI-compatible endpoint. Required - when ``backend`` is ``"openai"`` or ``"rits"``. + when `backend` is `"openai"` or `"rits"`. backend_api_key: API key for the configured endpoint. Required when - ``backend`` is ``"openai"`` or ``"rits"``. + `backend` is `"openai"` or `"rits"`. 
Returns: - A ``DecompPipelineResult`` containing the original prompt, subtask list, + A `DecompPipelineResult` containing the original prompt, subtask list, identified constraints, and fully annotated subtask objects with prompt templates, constraint assignments, and dependency information. """ diff --git a/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py b/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py index 45f56df0a..6c32f21b0 100644 --- a/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py +++ b/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py @@ -55,7 +55,7 @@ def _default_parser(generated_str: str) -> list[SubtaskPromptConstraintsItem]: Note that the result "constraints" list can be empty. For example - ``` + ``` [ SubtaskPromptConstraintsItem( subtask=, tag=, prompt_template=, constraints=[, ] ), ... ] - ``` + ``` You can use dot notation to access the values. For example - ``` + ``` result: PromptModuleString = # Result of the subtask_constraint_assign.generate() method parsed_result: list[SubtaskPromptConstraintsItem] = result.parse() @@ -76,7 +76,7 @@ tag_0: str = result[0].tag prompt_template_0: str = result[0].prompt_template constraints_0: list[str] = result[0].constraints - ``` + ``` Raises: TagExtractionError: An error occurred trying to extract content from the @@ -171,12 +171,12 @@ def generate( # type: ignore[override] the subtask title/description in natural language, the second position is a tag/variable with a descriptive name related to its subtask, and the third position is the template prompt for an LLM to execute the subtask. e.g. - ``` + ``` subtasks_tags_and_prompts = [ ("1. Read the document and write a summary", "DOCUMENT_SUMMARY", ""), ("2. 
Write the 3 most important phrases as bullets", "IMPORTANT_PHRASES", "") ] - ``` + ``` constraint_list (`Sequence[str]`): A list of constraints written in natural language. This was designed to take in a list of constraints identified from the prompt diff --git a/cli/decompose/prompt_modules/subtask_constraint_assign/_types.py b/cli/decompose/prompt_modules/subtask_constraint_assign/_types.py index 97c5f7a63..bb8c4f4dd 100644 --- a/cli/decompose/prompt_modules/subtask_constraint_assign/_types.py +++ b/cli/decompose/prompt_modules/subtask_constraint_assign/_types.py @@ -5,13 +5,13 @@ class SubtaskPromptConstraintsItem(NamedTuple): """A `tuple` generated by the `subtask_prompt_generator` prompt module. Inherits from `NamedTuple`, so the attributes can be accessed with dot notation. e.g. - ``` + ``` # item: SubtaskPromptConstraintsItem subtask_title: str = item.subtask subtask_tag:str = item.tag subtask_prompt_template: str = item.prompt_template subtask_constraints: list[str] = item.constraints - ``` + ``` Attributes: subtask (`str`): The subtask title / brief description. diff --git a/cli/decompose/prompt_modules/subtask_list/_subtask_list.py b/cli/decompose/prompt_modules/subtask_list/_subtask_list.py index bf5eed389..4bc6c718c 100644 --- a/cli/decompose/prompt_modules/subtask_list/_subtask_list.py +++ b/cli/decompose/prompt_modules/subtask_list/_subtask_list.py @@ -58,18 +58,18 @@ def _default_parser(generated_str: str) -> list[SubtaskItem]: `tuple` contains the generated "subtask" (`str`) and its generated "tag" (`str`). For example - ``` + ``` [ SubtaskItem(subtask=, tag=), SubtaskItem(subtask=, tag=) ] - ``` + ``` You can use dot notation to access the values. 
For example - ``` + ``` result: PromptModuleString = subtask_list.generate(task_prompt, mellea_session) parsed_result: list[SubtaskItem] = result.parse() subtask_0: str = result[0].subtask tag_0: str = result[0].tag - ``` + ``` Raises: TagExtractionError: An error occurred trying to extract content from the diff --git a/cli/decompose/prompt_modules/subtask_list/_types.py b/cli/decompose/prompt_modules/subtask_list/_types.py index 0dc6f174e..218d09c0e 100644 --- a/cli/decompose/prompt_modules/subtask_list/_types.py +++ b/cli/decompose/prompt_modules/subtask_list/_types.py @@ -5,11 +5,11 @@ class SubtaskItem(NamedTuple): """A `tuple` representing a subtask generated by the `subtask_list_generator` prompt module. Inherits from `NamedTuple`, so the attributes can be accessed with dot notation. e.g. - ``` + ``` # item: SubtaskItem subtask_title: str = item.subtask subtask_tag:str = item.tag - ``` + ``` Attributes: subtask (`str`): The generated subtask title / brief description. diff --git a/cli/decompose/prompt_modules/subtask_prompt_generator/_subtask_prompt_generator.py b/cli/decompose/prompt_modules/subtask_prompt_generator/_subtask_prompt_generator.py index 05fa5ed71..e28bc6c18 100644 --- a/cli/decompose/prompt_modules/subtask_prompt_generator/_subtask_prompt_generator.py +++ b/cli/decompose/prompt_modules/subtask_prompt_generator/_subtask_prompt_generator.py @@ -50,13 +50,13 @@ def _default_parser(generated_str: str) -> list[SubtaskPromptItem]: its generated "prompt_template" (`str`). For example - ``` + ``` [ SubtaskPromptItem(subtask=, tag=, prompt_template=), SubtaskPromptItem(subtask=, tag=, prompt_template=) ] - ``` + ``` You can use dot notation to access the values. For example - ``` + ``` task_prompt = "..." # Original task prompt to be the reference when generating subtask prompts mellea_session = MelleaSession(...) # A mellea session with a backend subtasks = [ ("1. 
Read the document and write a summary", "DOCUMENT_SUMMARY"), @@ -74,7 +74,7 @@ def _default_parser(generated_str: str) -> list[SubtaskPromptItem]: subtask_0: str = result[0].subtask tag_0: str = result[0].tag prompt_template_0: str = result[0].prompt_template - ``` + ``` Raises: TagExtractionError: An error occurred trying to extract content from the @@ -150,9 +150,9 @@ def generate( # type: ignore[override] Let's say your task is for writing emails addressed to a prospect of a given company, then this task needs to ingest some variables, e.g. - ``` + ``` user_input_var_names = ["YOUR_NAME", "PROSPECT_NAME", "PROSPECT_COMPANY", "PRODUCT_DESCRIPTION"] - ``` + ``` subtasks_and_tags (`Sequence[tuple[str, str]]`): A list of subtasks and their respective tags. This was designed to receive the parsed result of the `subtask_list` @@ -161,12 +161,12 @@ The list is composed of `tuple[str, str]` objects where the first position is the subtask title/description in natural language and the second position is a tag/variable with a descriptive name related to its subtask. e.g. - ``` + ``` subtasks_and_tags = [ ("1. Read the document and write a summary", "DOCUMENT_SUMMARY"), ("2. Write the 3 most important phrases as bullets", "IMPORTANT_PHRASES"), ] - ``` + ``` Returns: PromptModuleString: A `PromptModuleString` class containing the generated output. diff --git a/cli/decompose/prompt_modules/subtask_prompt_generator/_types.py b/cli/decompose/prompt_modules/subtask_prompt_generator/_types.py index 713bdc27e..1c0e57502 100644 --- a/cli/decompose/prompt_modules/subtask_prompt_generator/_types.py +++ b/cli/decompose/prompt_modules/subtask_prompt_generator/_types.py @@ -5,12 +5,12 @@ class SubtaskPromptItem(NamedTuple): """A `tuple` generated by the `subtask_prompt_generator` prompt module. Inherits from `NamedTuple`, so the attributes can be accessed with dot notation. e.g. 
- ``` + ``` # item: SubtaskPromptItem subtask_title: str = item.subtask subtask_tag:str = item.tag subtask_prompt_template: str = item.prompt_template - ``` + ``` Attributes: subtask (`str`): The subtask title / brief description. diff --git a/cli/decompose/utils.py b/cli/decompose/utils.py index e170c447a..c5438476d 100644 --- a/cli/decompose/utils.py +++ b/cli/decompose/utils.py @@ -1,6 +1,6 @@ """Filename validation utilities for the decompose pipeline. -Provides ``validate_filename``, which checks that a candidate output filename +Provides `validate_filename`, which checks that a candidate output filename contains only safe characters (alphanumeric, underscores, hyphens, periods, and spaces) and falls within a reasonable length limit. Used to prevent path-traversal or shell-injection issues when writing decomposition output files. @@ -18,7 +18,7 @@ def validate_filename(candidate_str: str) -> bool: candidate_str: The filename candidate to validate. Returns: - ``True`` if the string is a safe, valid filename; ``False`` otherwise. + `True` if the string is a safe, valid filename; `False` otherwise. """ import re diff --git a/cli/eval/__init__.py b/cli/eval/__init__.py index cb57f3230..bad943576 100644 --- a/cli/eval/__init__.py +++ b/cli/eval/__init__.py @@ -1,6 +1,6 @@ """CLI package for test-based LLM evaluation. -Provides the ``m eval`` command group, which orchestrates running a generator model +Provides the `m eval` command group, which orchestrates running a generator model against structured test files and scoring each response with a judge model. 
Each test file specifies a set of instructions and input examples; results — including per-input pass/fail judgements and cumulative pass rates — are written to JSON or JSONL for diff --git a/cli/eval/commands.py b/cli/eval/commands.py index ec778a234..26029c44d 100644 --- a/cli/eval/commands.py +++ b/cli/eval/commands.py @@ -40,15 +40,15 @@ def eval_run( Args: test_files: Paths to JSON/JSONL files containing test cases. backend: Generation backend name. - model: Generation model name, or ``None`` for the default. + model: Generation model name, or `None` for the default. max_gen_tokens: Maximum tokens to generate for each response. - judge_backend: Judge backend name, or ``None`` to reuse the generation + judge_backend: Judge backend name, or `None` to reuse the generation backend. - judge_model: Judge model name, or ``None`` for the default. + judge_model: Judge model name, or `None` for the default. max_judge_tokens: Maximum tokens for the judge model's output. output_path: File path prefix for the results file. - output_format: Output format -- ``"json"`` or ``"jsonl"``. - continue_on_error: If ``True``, skip failed tests instead of raising. + output_format: Output format -- `"json"` or `"jsonl"`. + continue_on_error: If `True`, skip failed tests instead of raising. """ from cli.eval.runner import run_evaluations diff --git a/cli/eval/runner.py b/cli/eval/runner.py index 9fdf55ff3..f13baf1fd 100644 --- a/cli/eval/runner.py +++ b/cli/eval/runner.py @@ -1,8 +1,8 @@ """Execution engine for the test-based LLM evaluation pipeline. -Loads JSON test files into ``TestBasedEval`` objects and, for each test, runs a +Loads JSON test files into `TestBasedEval` objects and, for each test, runs a generator model to produce responses and a separate judge model to score them. 
Parses -the judge output for a ``{"score": ..., "justification": ...}`` JSON fragment, +the judge output for a `{"score": ..., "justification": ...}` JSON fragment, aggregates per-input pass/fail counts, and saves the full results to JSON or JSONL. """ @@ -30,7 +30,7 @@ class InputEvalResult: input_text (str): The raw input text sent to the generation model. model_output (str): The text response produced by the generation model. validation_passed (bool): Whether the judge scored this response as passing. - score (int): Numeric score assigned by the judge (``1`` for pass, ``0`` for fail). + score (int): Numeric score assigned by the judge (`1` for pass, `0` for fail). validation_reason (str): Justification text returned by the judge model. """ @@ -53,8 +53,8 @@ def to_dict(self): """Serialise the input evaluation result to a plain dictionary. Returns: - dict: A dictionary with keys ``"input"``, ``"model_output"``, - ``"passed"``, ``"score"``, and ``"justification"``. + dict: A dictionary with keys `"input"`, `"model_output"`, + `"passed"`, `"score"`, and `"justification"`. """ return { "input": self.input_text, @@ -77,7 +77,7 @@ class TestEvalResult: Attributes: passed_count (int): Number of inputs that received a passing score. total_count (int): Total number of inputs evaluated. - pass_rate (float): Fraction of inputs that passed (``passed_count / total_count``). + pass_rate (float): Fraction of inputs that passed (`passed_count / total_count`). """ def __init__(self, test_eval: TestBasedEval, input_results: list[InputEvalResult]): @@ -88,11 +88,11 @@ def to_dict(self): """Serialise the test evaluation result to a plain dictionary. Returns: - dict: A dictionary containing the test metadata (``"test_id"``, - ``"source"``, ``"name"``, ``"instructions"``), per-input results - under ``"input_results"``, expected targets under - ``"expected_targets"``, and summary counts (``"passed"``, - ``"total_count"``, ``"pass_rate"``). 
+ dict: A dictionary containing the test metadata (`"test_id"`, + `"source"`, `"name"`, `"instructions"`), per-input results + under `"input_results"`, expected targets under + `"expected_targets"`, and summary counts (`"passed"`, + `"total_count"`, `"pass_rate"`). """ return { "test_id": self.test_eval.test_id, @@ -125,18 +125,18 @@ def create_session( """Create a mellea session with the specified backend and model. Args: - backend: Backend name: ``"ollama"``, ``"openai"``, ``"hf"``, - ``"watsonx"``, or ``"litellm"``. - model: Model ID or ``ModelIdentifier`` attribute name, or ``None`` + backend: Backend name: `"ollama"`, `"openai"`, `"hf"`, + `"watsonx"`, or `"litellm"`. + model: Model ID or `ModelIdentifier` attribute name, or `None` to use the default model. - max_tokens: Maximum number of tokens to generate, or ``None`` for + max_tokens: Maximum number of tokens to generate, or `None` for the backend default. Returns: - A configured ``MelleaSession`` ready for generation. + A configured `MelleaSession` ready for generation. Raises: - ValueError: If ``backend`` is not one of the supported backend names. + ValueError: If `backend` is not one of the supported backend names. Exception: Re-raised from backend or session construction if initialisation fails. """ @@ -230,20 +230,20 @@ def run_evaluations( Args: test_files: List of paths to JSON test files. Each file should contain - ``"id"``, ``"source"``, ``"name"``, ``"instructions"``, and - ``"examples"`` fields. + `"id"`, `"source"`, `"name"`, `"instructions"`, and + `"examples"` fields. backend: Backend name for the generation model. - model: Model ID for the generator, or ``None`` for the default. - max_gen_tokens: Maximum tokens for the generator, or ``None`` for the + model: Model ID for the generator, or `None` for the default. + max_gen_tokens: Maximum tokens for the generator, or `None` for the backend default. 
- judge_backend: Backend name for the judge model, or ``None`` to reuse + judge_backend: Backend name for the judge model, or `None` to reuse the generation backend. - judge_model: Model ID for the judge, or ``None`` for the default. - max_judge_tokens: Maximum tokens for the judge, or ``None`` for the + judge_model: Model ID for the judge, or `None` for the default. + max_judge_tokens: Maximum tokens for the judge, or `None` for the backend default. output_path: File path prefix for saving results. - output_format: Output format: ``"json"`` or ``"jsonl"``. - continue_on_error: If ``True``, skip failed test evaluations instead of + output_format: Output format: `"json"` or `"jsonl"`. + continue_on_error: If `True`, skip failed test evaluations instead of raising. """ all_test_evals: list[TestBasedEval] = [] @@ -314,16 +314,16 @@ def execute_test_eval( ) -> TestEvalResult: """Execute a single test evaluation. - For each input in the test, generates a response using ``generation_session``, - then validates using ``judge_session``. + For each input in the test, generates a response using `generation_session`, + then validates using `judge_session`. Args: - test_eval: The ``TestBasedEval`` object containing inputs and targets. - generation_session: ``MelleaSession`` used to produce model responses. - judge_session: ``MelleaSession`` used to score model responses. + test_eval: The `TestBasedEval` object containing inputs and targets. + generation_session: `MelleaSession` used to produce model responses. + judge_session: `MelleaSession` used to score model responses. Returns: - A ``TestEvalResult`` with per-input pass/fail outcomes. + A `TestEvalResult` with per-input pass/fail outcomes. """ input_results = [] @@ -373,8 +373,8 @@ def parse_judge_output(judge_output: str): judge_output: Raw text output from the judge model. 
Returns: - A ``(score, justification)`` tuple where ``score`` is an integer (or - ``None`` if parsing failed) and ``justification`` is an explanatory + A `(score, justification)` tuple where `score` is an integer (or + `None` if parsing failed) and `justification` is an explanatory string. """ try: @@ -401,10 +401,10 @@ def save_results(results: list[TestEvalResult], output_path: str, output_format: """Persist evaluation results to disk in JSON or JSONL format. Args: - results: List of ``TestEvalResult`` objects to serialise. + results: List of `TestEvalResult` objects to serialise. output_path: Destination file path (extension may be appended if it - does not match ``output_format``). - output_format: Format string: ``"json"`` or ``"jsonl"``. + does not match `output_format`). + output_format: Format string: `"json"` or `"jsonl"`. """ output_path_obj = Path(output_path) if output_path_obj.suffix != f".{output_format}": @@ -441,7 +441,7 @@ def summary_stats(results: list[TestEvalResult]): """Print aggregated pass-rate statistics for a set of evaluation results. Args: - results: List of ``TestEvalResult`` objects to summarise. + results: List of `TestEvalResult` objects to summarise. """ total_inputs = sum(r.total_count for r in results) passed_inputs = sum(r.passed_count for r in results) diff --git a/cli/m.py b/cli/m.py index 4cdb4f7fa..7358cfba2 100644 --- a/cli/m.py +++ b/cli/m.py @@ -1,9 +1,9 @@ -"""Entrypoint for the ``m`` command-line tool. +"""Entrypoint for the `m` command-line tool. -Wires together all CLI sub-applications into a single Typer root command: ``m serve`` -(start a model-serving endpoint), ``m alora`` (train and upload LoRA/aLoRA adapters), -``m decompose`` (LLM-driven task decomposition), and ``m eval`` (test-based model -evaluation). Run ``m --help`` to see all available sub-commands. 
+Wires together all CLI sub-applications into a single Typer root command: `m serve` +(start a model-serving endpoint), `m alora` (train and upload LoRA/aLoRA adapters), +`m decompose` (LLM-driven task decomposition), and `m eval` (test-based model +evaluation). Run `m --help` to see all available sub-commands. """ import typer @@ -21,9 +21,9 @@ def callback() -> None: """Mellea command-line tool for LLM-powered workflows. - Provides sub-commands for serving models (``m serve``), training and uploading - adapters (``m alora``), decomposing tasks into subtasks (``m decompose``), and - running test-based evaluation pipelines (``m eval``). + Provides sub-commands for serving models (`m serve`), training and uploading + adapters (`m alora`), decomposing tasks into subtasks (`m decompose`), and + running test-based evaluation pipelines (`m eval`). """ diff --git a/mellea/backends/__init__.py b/mellea/backends/__init__.py index 65e8d1e63..672827818 100644 --- a/mellea/backends/__init__.py +++ b/mellea/backends/__init__.py @@ -1,10 +1,10 @@ """Backend implementations for the mellea inference layer. This package exposes the concrete machinery for connecting mellea to language model -servers. It bundles ``FormatterBackend`` (a prompt-engineering base class for legacy -models), ``ModelIdentifier`` (portable cross-platform model names), ``ModelOption`` -(generation parameters such as token limits), ``SimpleLRUCache`` (KV-cache -management), and ``MelleaTool`` / ``tool`` (LLM tool definitions). Reach for this +servers. It bundles `FormatterBackend` (a prompt-engineering base class for legacy +models), `ModelIdentifier` (portable cross-platform model names), `ModelOption` +(generation parameters such as token limits), `SimpleLRUCache` (KV-cache +management), and `MelleaTool` / `tool` (LLM tool definitions). Reach for this package when configuring a backend, declaring tools, or tuning inference options. 
""" diff --git a/mellea/backends/adapters/adapter.py b/mellea/backends/adapters/adapter.py index 80ac85837..059fe60b6 100644 --- a/mellea/backends/adapters/adapter.py +++ b/mellea/backends/adapters/adapter.py @@ -1,10 +1,10 @@ """Adapter classes for adding fine-tuned modules to inference backends. -Defines the abstract ``Adapter`` base class and its concrete subclasses -``LocalHFAdapter`` (for locally loaded HuggingFace models) and ``IntrinsicAdapter`` +Defines the abstract `Adapter` base class and its concrete subclasses +`LocalHFAdapter` (for locally loaded HuggingFace models) and `IntrinsicAdapter` (for adapters whose metadata is stored in Mellea's intrinsic catalog). Also provides -``get_adapter_for_intrinsic`` for resolving the right adapter class given an -intrinsic name, and ``AdapterMixin`` for backends that support runtime adapter +`get_adapter_for_intrinsic` for resolving the right adapter class given an +intrinsic name, and `AdapterMixin` for backends that support runtime adapter loading and unloading. """ @@ -25,18 +25,18 @@ class Adapter(abc.ABC): """An adapter that can be added to a single backend. An adapter can only be registered with one backend at a time. Use - ``adapter.qualified_name`` when referencing the adapter after adding it. + `adapter.qualified_name` when referencing the adapter after adding it. Args: name (str): Human-readable name of the adapter. adapter_type (AdapterType): Enum describing the adapter type (e.g. - ``AdapterType.LORA`` or ``AdapterType.ALORA``). + `AdapterType.LORA` or `AdapterType.ALORA`). Attributes: qualified_name (str): Unique name used for loading and lookup; formed - as ``"_"``. + as `"_"`. backend (Backend | None): The backend this adapter has been added to, - or ``None`` if not yet added. + or `None` if not yet added. path (str | None): Filesystem path to the adapter weights; set when the adapter is added to a backend. 
""" @@ -58,7 +58,7 @@ def __init__(self, name: str, adapter_type: AdapterType): class LocalHFAdapter(Adapter): """Abstract adapter subclass for locally loaded HuggingFace model backends. - Subclasses must implement ``get_local_hf_path`` to return the filesystem path + Subclasses must implement `get_local_hf_path` to return the filesystem path from which adapter weights should be loaded given a base model name. """ @@ -68,7 +68,7 @@ def get_local_hf_path(self, base_model_name: str) -> str: Args: base_model_name (str): The base model name; typically the last component - of the HuggingFace model ID (e.g. ``"granite-4.0-micro"``). + of the HuggingFace model ID (e.g. `"granite-4.0-micro"`). Returns: str: Filesystem path to the adapter weights directory. @@ -83,29 +83,29 @@ class IntrinsicAdapter(LocalHFAdapter): * implement intrinsic functions * are packaged as LoRA or aLoRA adapters on top of a base model - * use the shared model loading code in ``mellea.formatters.granite.intrinsics`` + * use the shared model loading code in `mellea.formatters.granite.intrinsics` * use the shared input and output processing code in - ``mellea.formatters.granite.intrinsics`` + `mellea.formatters.granite.intrinsics` Args: - intrinsic_name (str): Name of the intrinsic (e.g. ``"answerability"``); the - adapter's ``qualified_name`` will be derived from this. + intrinsic_name (str): Name of the intrinsic (e.g. `"answerability"`); the + adapter's `qualified_name` will be derived from this. adapter_type (AdapterType): Enum describing the adapter type; defaults to - ``AdapterType.ALORA``. + `AdapterType.ALORA`. config_file (str | pathlib.Path | None): Path to a YAML config file defining the intrinsic's I/O transformations; mutually exclusive with - ``config_dict``. + `config_dict`. config_dict (dict | None): Dict defining the intrinsic's I/O transformations; - mutually exclusive with ``config_file``. + mutually exclusive with `config_file`. 
base_model_name (str | None): Base model name used to look up the I/O - processing config when neither ``config_file`` nor ``config_dict`` are + processing config when neither `config_file` nor `config_dict` are provided. Attributes: intrinsic_name (str): Name of the intrinsic this adapter implements. intrinsic_metadata (IntriniscsCatalogEntry): Catalog metadata for the intrinsic. base_model_name (str | None): Base model name provided at construction, if any. - adapter_type (AdapterType): The adapter type (``LORA`` or ``ALORA``). + adapter_type (AdapterType): The adapter type (`LORA` or `ALORA`). config (dict): Parsed I/O transformation configuration for the intrinsic. """ @@ -178,7 +178,7 @@ def get_local_hf_path(self, base_model_name: str) -> str: Args: base_model_name (str): The base model name; typically the last component - of the HuggingFace model ID (e.g. ``"granite-3.3-8b-instruct"``). + of the HuggingFace model ID (e.g. `"granite-3.3-8b-instruct"`). Returns: str: Filesystem path to the downloaded adapter weights directory. @@ -217,15 +217,15 @@ def get_adapter_for_intrinsic( """Find an adapter from a dict of available adapters based on the intrinsic name and its allowed adapter types. Args: - intrinsic_name (str): The name of the intrinsic, e.g. ``"answerability"``. + intrinsic_name (str): The name of the intrinsic, e.g. `"answerability"`. intrinsic_adapter_types (list[AdapterType] | tuple[AdapterType, ...]): The adapter types allowed for this intrinsic, e.g. - ``[AdapterType.ALORA, AdapterType.LORA]``. + `[AdapterType.ALORA, AdapterType.LORA]`. available_adapters (dict[str, T]): The available adapters to choose from; - maps ``adapter.qualified_name`` to the adapter object. + maps `adapter.qualified_name` to the adapter object. Returns: - T | None: The first matching adapter found, or ``None`` if no match exists. + T | None: The first matching adapter found, or `None` if no match exists. 
""" adapter = None for adapter_type in intrinsic_adapter_types: @@ -242,8 +242,8 @@ class AdapterMixin(Backend, abc.ABC): Attributes: base_model_name (str): The short model name used to identify adapter - variants (e.g. ``"granite-3.3-8b-instruct"`` for - ``"ibm-granite/granite-3.3-8b-instruct"``). + variants (e.g. `"granite-3.3-8b-instruct"` for + `"ibm-granite/granite-3.3-8b-instruct"`). """ @property @@ -252,7 +252,7 @@ def base_model_name(self) -> str: """Return the short model name used for adapter variant lookup. Returns: - str: The base model name (e.g. ``"granite-3.3-8b-instruct"``). + str: The base model name (e.g. `"granite-3.3-8b-instruct"`). """ @abc.abstractmethod @@ -270,11 +270,11 @@ def add_adapter(self, *args, **kwargs): def load_adapter(self, adapter_qualified_name: str): """Load a previously registered adapter into the underlying model. - The adapter must have been registered via ``add_adapter`` before calling + The adapter must have been registered via `add_adapter` before calling this method. Args: - adapter_qualified_name (str): The ``adapter.qualified_name`` of the + adapter_qualified_name (str): The `adapter.qualified_name` of the adapter to load. """ @@ -283,7 +283,7 @@ def unload_adapter(self, adapter_qualified_name: str): """Unload a previously loaded adapter from the underlying model. Args: - adapter_qualified_name (str): The ``adapter.qualified_name`` of the + adapter_qualified_name (str): The `adapter.qualified_name` of the adapter to unload. """ @@ -293,7 +293,7 @@ def list_adapters(self) -> list[str]: Returns: list[str]: Qualified adapter names for all adapters that have been - loaded via ``load_adapter``. + loaded via `load_adapter`. Raises: NotImplementedError: If the concrete backend subclass has not @@ -311,16 +311,16 @@ class CustomIntrinsicAdapter(IntrinsicAdapter): a subclass of this class. 
Creating a subclass of this class appears to be a cosmetic boilerplate development task that isn't actually necessary for any existing use case. - This class has the same functionality as ``IntrinsicAdapter``, except that its + This class has the same functionality as `IntrinsicAdapter`, except that its constructor monkey-patches Mellea global variables to enable the backend to load the user's adapter. The code that performs this monkey-patching is marked as a temporary hack. Args: model_id (str): The HuggingFace model ID used for downloading model weights; - expected format is ``"/"``. + expected format is `"/"`. intrinsic_name (str | None): Catalog name for the intrinsic; defaults to the - repository name portion of ``model_id`` if not provided. + repository name portion of `model_id` if not provided. base_model_name (str): The short name of the base model (NOT its repo ID). """ diff --git a/mellea/backends/adapters/catalog.py b/mellea/backends/adapters/catalog.py index e60223493..61211e1b5 100644 --- a/mellea/backends/adapters/catalog.py +++ b/mellea/backends/adapters/catalog.py @@ -13,8 +13,8 @@ class AdapterType(enum.Enum): """Possible types of adapters for a backend. Attributes: - LORA (str): Standard LoRA adapter; value ``"lora"``. - ALORA (str): Activated LoRA adapter; value ``"alora"``. + LORA (str): Standard LoRA adapter; value `"lora"`. + ALORA (str): Activated LoRA adapter; value `"alora"`. """ LORA = "lora" @@ -29,12 +29,12 @@ class IntriniscsCatalogEntry(pydantic.BaseModel): Attributes: name (str): User-visible name of the intrinsic. internal_name (str | None): Internal name used for adapter loading, or - ``None`` if the same as ``name``. + `None` if the same as `name`. repo_id (str): HuggingFace repository where adapters for the intrinsic are located. adapter_types (tuple[AdapterType, ...]): Adapter types known to be available for this intrinsic; defaults to - ``(AdapterType.LORA, AdapterType.ALORA)``. + `(AdapterType.LORA, AdapterType.ALORA)`. 
""" name: str = pydantic.Field(description="User-visible name of the intrinsic.") @@ -108,7 +108,7 @@ def fetch_intrinsic_metadata(intrinsic_name: str) -> IntriniscsCatalogEntry: intrinsic. Raises: - ValueError: If ``intrinsic_name`` is not a known intrinsic name. + ValueError: If `intrinsic_name` is not a known intrinsic name. """ if intrinsic_name not in _INTRINSICS_CATALOG: raise ValueError( diff --git a/mellea/backends/backend.py b/mellea/backends/backend.py index b8329392d..3d60c4be7 100644 --- a/mellea/backends/backend.py +++ b/mellea/backends/backend.py @@ -1,10 +1,10 @@ -"""``FormatterBackend``: base class for prompt-engineering backends. +"""`FormatterBackend`: base class for prompt-engineering backends. -``FormatterBackend`` extends the abstract ``Backend`` with a ``ChatFormatter`` and -a ``ModelIdentifier``, bridging mellea's generative programming primitives to models +`FormatterBackend` extends the abstract `Backend` with a `ChatFormatter` and +a `ModelIdentifier`, bridging mellea's generative programming primitives to models that do not yet natively support spans or structured fine-tuning. Concrete backend -implementations (e.g. Ollama, HuggingFace, OpenAI) subclass ``FormatterBackend`` and -supply the model-specific ``generate_from_context`` logic. +implementations (e.g. Ollama, HuggingFace, OpenAI) subclass `FormatterBackend` and +supply the model-specific `generate_from_context` logic. """ import abc @@ -31,7 +31,7 @@ class FormatterBackend(Backend, abc.ABC): Args: model_id (str | ModelIdentifier): The model identifier to use for generation. formatter (ChatFormatter): The formatter used to convert components into prompts. - model_options (dict | None): Default model options; if ``None``, an empty dict + model_options (dict | None): Default model options; if `None`, an empty dict is used. 
""" diff --git a/mellea/backends/bedrock.py b/mellea/backends/bedrock.py index a340dba94..500d5d4a0 100644 --- a/mellea/backends/bedrock.py +++ b/mellea/backends/bedrock.py @@ -25,7 +25,7 @@ def list_mantle_models(region: str | None = None) -> list: """Return all models available at a bedrock-mantle endpoint. Args: - region: AWS region name (e.g. ``"us-east-1"``), or ``None`` to use the + region: AWS region name (e.g. `"us-east-1"`), or `None` to use the default region. Returns: @@ -43,10 +43,10 @@ def stringify_mantle_model_ids(region: str | None = None) -> str: """Return a human-readable list of all models available at the mantle endpoint for an AWS region. Args: - region: AWS region name, or ``None`` to use the default region. + region: AWS region name, or `None` to use the default region. Returns: - Newline-separated string of model IDs prefixed with ``" * "``. + Newline-separated string of model IDs prefixed with `" * "`. """ models = list_mantle_models() model_names = "\n * ".join([str(m.id) for m in models]) @@ -60,19 +60,19 @@ def create_bedrock_mantle_backend( Args: model_id (ModelIdentifier | str): The model to use, either as a - ``ModelIdentifier`` (which must have a ``bedrock_name``) or a raw + `ModelIdentifier` (which must have a `bedrock_name`) or a raw Bedrock model ID string. - region (str | None): AWS region name, or ``None`` to use the default - region (``"us-east-1"``). + region (str | None): AWS region name, or `None` to use the default + region (`"us-east-1"`). Returns: - OpenAIBackend: An ``OpenAIBackend`` configured to call the specified model + OpenAIBackend: An `OpenAIBackend` configured to call the specified model via AWS Bedrock Mantle. Raises: - Exception: If ``model_id`` is a ``ModelIdentifier`` with no ``bedrock_name`` + Exception: If `model_id` is a `ModelIdentifier` with no `bedrock_name` set. 
- AssertionError: If the ``AWS_BEARER_TOKEN_BEDROCK`` environment variable is + AssertionError: If the `AWS_BEARER_TOKEN_BEDROCK` environment variable is not set. Exception: If the specified model is not available in the target region. """ diff --git a/mellea/backends/cache.py b/mellea/backends/cache.py index 81f7a7ab3..13393b7f1 100644 --- a/mellea/backends/cache.py +++ b/mellea/backends/cache.py @@ -1,9 +1,9 @@ """Cache abstractions and implementations for model state. -Defines the abstract ``Cache`` interface with ``put``, ``get``, and -``current_size`` methods, and provides a concrete ``SimpleLRUCache`` that evicts +Defines the abstract `Cache` interface with `put`, `get`, and +`current_size` methods, and provides a concrete `SimpleLRUCache` that evicts the least-recently-used entry when capacity is exceeded — optionally calling an -``on_evict`` callback (e.g. to free GPU memory). Used by local HuggingFace backends +`on_evict` callback (e.g. to free GPU memory). Used by local HuggingFace backends to store and reuse KV cache state across requests. """ @@ -40,7 +40,7 @@ def get(self, key: str | int) -> Any | None: key (str | int): The cache key to look up. Returns: - Any | None: The cached value, or ``None`` if ``key`` has no cached entry. + Any | None: The cached value, or `None` if `key` has no cached entry. """ ... @@ -58,7 +58,7 @@ class SimpleLRUCache(Cache): """A simple `LRU `_ cache. Evicts the least-recently-used entry when capacity is exceeded, optionally - invoking an ``on_evict`` callback (e.g. to free GPU memory). Used by local + invoking an `on_evict` callback (e.g. to free GPU memory). Used by local HuggingFace backends to store and reuse KV cache state across requests. Args: @@ -92,7 +92,7 @@ def get(self, key: str | int) -> Any | None: key (str | int): The cache key to look up. Returns: - Any | None: The cached value, or ``None`` if ``key`` is not present. + Any | None: The cached value, or `None` if `key` is not present. 
""" if key not in self.cache: return None @@ -106,7 +106,7 @@ def put(self, key: str | int, value: Any) -> None: """Insert or update a value in the cache. If the cache is at capacity and the key is new, the least-recently-used - entry is evicted first, invoking the ``on_evict`` callback if set. + entry is evicted first, invoking the `on_evict` callback if set. Args: key (str | int): The cache key to store the value under. diff --git a/mellea/backends/dummy.py b/mellea/backends/dummy.py index 81c7ca2c3..691676cc9 100644 --- a/mellea/backends/dummy.py +++ b/mellea/backends/dummy.py @@ -14,18 +14,18 @@ class DummyBackend(Backend): """A backend for smoke testing. - Returns predetermined string responses in sequence, or ``"dummy"`` if no + Returns predetermined string responses in sequence, or `"dummy"` if no responses are provided. Intended for unit tests and integration smoke tests where real model inference is not needed. Args: responses (list[str] | None): Ordered list of strings to return on - successive ``generate_from_context`` calls, or ``None`` to always - return ``"dummy"``. + successive `generate_from_context` calls, or `None` to always + return `"dummy"`. Attributes: - idx (int): Index of the next response to return from ``responses``; - starts at ``0`` and increments on each call. + idx (int): Index of the next response to return from `responses`; + starts at `0` and increments on each call. """ def __init__(self, responses: list[str] | None): @@ -42,16 +42,16 @@ async def _generate_from_context( model_options: dict | None = None, tool_calls: bool = False, ) -> tuple[ModelOutputThunk[C], Context]: - """Return the next predetermined response for ``action`` given ``ctx``. + """Return the next predetermined response for `action` given `ctx`. - If ``responses`` is ``None``, always returns the string ``"dummy"``. - Otherwise returns the next item from ``responses`` in order. + If `responses` is `None`, always returns the string `"dummy"`. 
+ Otherwise returns the next item from `responses` in order. Args: action (Component[C] | CBlock): The component or content block to generate a completion for. ctx (Context): The current generation context. - format (type[BaseModelSubclass] | None): Must be ``None``; constrained + format (type[BaseModelSubclass] | None): Must be `None`; constrained decoding is not supported. model_options (dict | None): Ignored by this backend. tool_calls (bool): Ignored by this backend. @@ -61,8 +61,8 @@ async def _generate_from_context( response and an updated context. Raises: - AssertionError: If ``format`` is not ``None``. - Exception: If all responses from ``responses`` have been consumed. + AssertionError: If `format` is not `None`. + Exception: If all responses from `responses` have been consumed. """ assert format is None, "The DummyBackend does not support constrained decoding." if self.responses is None: diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index cdff13aee..780a3f6c5 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -80,25 +80,25 @@ class HFAloraCacheInfo: """A dataclass for holding a KV cache and associated generation metadata. - Used by ``LocalHFBackend`` to store intermediate model state that can be + Used by `LocalHFBackend` to store intermediate model state that can be reused across generation requests via an LRU cache. Args: - kv_cache (DynamicCache | None): The HuggingFace ``DynamicCache`` holding - precomputed key/value tensors, or ``None`` if not available. + kv_cache (DynamicCache | None): The HuggingFace `DynamicCache` holding + precomputed key/value tensors, or `None` if not available. merged_token_ids (Any): Token IDs corresponding to the cached prefix. merged_attention (Any): Attention mask for the cached prefix tokens. q_end (int): Index of the last prompt token in the merged token sequence; - defaults to ``-1``. + defaults to `-1`. 
scores (Any): Optional logit scores from the generation step; defaults to - ``None``. + `None`. Attributes: kv_cache (DynamicCache | None): The cached key/value tensors. merged_token_ids (Any): Token IDs for the cached prefix. merged_attention (Any): Attention mask for the cached prefix. q_end (int): End index of the prompt portion in merged token IDs. - scores (Any): Logit scores from generation, or ``None``. + scores (Any): Logit scores from generation, or `None`. """ kv_cache: DynamicCache | None @@ -220,22 +220,22 @@ class LocalHFBackend(FormatterBackend, AdapterMixin): Args: model_id (str | ModelIdentifier): Used to load the model and tokenizer via - HuggingFace ``Auto*`` classes. + HuggingFace `Auto*` classes. formatter (ChatFormatter | None): Formatter for rendering components into - prompts. Defaults to ``TemplateFormatter``. - use_caches (bool): If ``False``, KV caching is disabled even if a ``Cache`` + prompts. Defaults to `TemplateFormatter`. + use_caches (bool): If `False`, KV caching is disabled even if a `Cache` is provided. cache (Cache | None): Caching strategy; defaults to - ``SimpleLRUCache(0, on_evict=_cleanup_kv_cache)``. + `SimpleLRUCache(0, on_evict=_cleanup_kv_cache)`. custom_config (TransformersTorchConfig | None): Override for - tokenizer/model/device; if provided, ``model_id`` is not used for loading. - default_to_constraint_checking_alora (bool): If ``False``, aLoRA constraint + tokenizer/model/device; if provided, `model_id` is not used for loading. + default_to_constraint_checking_alora (bool): If `False`, aLoRA constraint checking is deactivated; mainly for benchmarking and debugging. model_options (dict | None): Default model options for generation requests. Attributes: to_mellea_model_opts_map (dict): Mapping from HF-specific option names to - Mellea ``ModelOption`` sentinel keys. + Mellea `ModelOption` sentinel keys. from_mellea_model_opts_map (dict): Mapping from Mellea sentinel keys to HF-specific option names. 
""" @@ -350,9 +350,9 @@ async def _generate_from_context( model_options: dict | None = None, tool_calls: bool = False, ) -> tuple[ModelOutputThunk[C], Context]: - """Generate a completion for ``action`` given ``ctx`` using the HuggingFace model. + """Generate a completion for `action` given `ctx` using the HuggingFace model. - Automatically routes ``Requirement`` and ``Intrinsic`` actions to their + Automatically routes `Requirement` and `Intrinsic` actions to their corresponding aLoRA adapters when available. Args: @@ -363,12 +363,12 @@ async def _generate_from_context( structured/constrained output decoding via llguidance. model_options (dict | None): Per-call model options that override the backend's defaults. - tool_calls (bool): If ``True``, expose available tools to the model and + tool_calls (bool): If `True`, expose available tools to the model and parse tool-call responses. Returns: tuple[ModelOutputThunk[C], Context]: A thunk holding the (lazy) model output - and an updated context that includes ``action`` and the new output. + and an updated context that includes `action` and the new output. """ span = start_generate_span( backend=self, action=action, ctx=ctx, format=format, tool_calls=tool_calls @@ -1035,8 +1035,8 @@ async def processing( """Accumulate decoded text from a streaming chunk or full generation output. For streaming responses the chunk is an already-decoded string from - ``AsyncTextIteratorStreamer``; for non-streaming it is a - ``GenerateDecoderOnlyOutput`` that is decoded here. + `AsyncTextIteratorStreamer`; for non-streaming it is a + `GenerateDecoderOnlyOutput` that is decoded here. Args: mot (ModelOutputThunk): The output thunk being populated. @@ -1082,7 +1082,7 @@ async def post_processing( class used during generation, if any. tool_calls (bool): Whether tool calling was enabled for this request. tools (dict[str, AbstractMelleaTool]): Available tools, keyed by name. - seed: The random seed used during generation, or ``None``. 
+ seed: The random seed used during generation, or `None`. input_ids: The prompt token IDs; used to compute token counts and for KV cache bookkeeping. """ @@ -1261,7 +1261,7 @@ async def generate_from_raw( """Generate completions for multiple actions without chat templating. Passes formatted prompt strings directly to the HuggingFace model's - ``generate`` method as a batch. Tool calling is not supported. + `generate` method as a batch. Tool calling is not supported. Args: actions (Sequence[Component[C] | CBlock]): Actions to generate completions for. @@ -1374,20 +1374,20 @@ async def generate_from_raw( # region cache management def cache_get(self, id: str | int) -> HFAloraCacheInfo | None: - """Retrieve a cached ``HFAloraCacheInfo`` entry by its key. + """Retrieve a cached `HFAloraCacheInfo` entry by its key. Args: id (str | int): The cache key to look up. Returns: - HFAloraCacheInfo | None: The cached entry, or ``None`` if not found. + HFAloraCacheInfo | None: The cached entry, or `None` if not found. """ v = self._cache.get(id) assert v is None or type(v) is HFAloraCacheInfo return v def cache_put(self, id: str | int, v: HFAloraCacheInfo): - """Store an ``HFAloraCacheInfo`` entry in the cache under the given key. + """Store an `HFAloraCacheInfo` entry in the cache under the given key. Args: id (str | int): The cache key to store the entry under. @@ -1477,7 +1477,7 @@ def base_model_name(self): def add_adapter(self, adapter: LocalHFAdapter): """Register a LoRA/aLoRA adapter with this backend so it can be loaded later. - Downloads the adapter weights (via ``adapter.get_local_hf_path``) and records + Downloads the adapter weights (via `adapter.get_local_hf_path`) and records the adapter in the backend's registry. The adapter must not already be registered with a different backend. @@ -1485,7 +1485,7 @@ def add_adapter(self, adapter: LocalHFAdapter): adapter (LocalHFAdapter): The adapter to register with this backend. 
Raises: - Exception: If ``adapter`` has already been added to a different backend. + Exception: If `adapter` has already been added to a different backend. """ if adapter.backend is not None: if adapter.backend is self: @@ -1511,12 +1511,12 @@ def add_adapter(self, adapter: LocalHFAdapter): def load_adapter(self, adapter_qualified_name: str): """Load a previously registered adapter into the underlying HuggingFace model. - The adapter must have been registered via ``add_adapter`` first. Do not call + The adapter must have been registered via `add_adapter` first. Do not call this method while generation requests are in progress. Args: - adapter_qualified_name (str): The ``adapter.qualified_name`` of the adapter - to load (i.e. ``"_"``) + adapter_qualified_name (str): The `adapter.qualified_name` of the adapter + to load (i.e. `"_"`) Raises: ValueError: If no adapter with the given qualified name has been added to @@ -1559,7 +1559,7 @@ def unload_adapter(self, adapter_qualified_name: str): method returns without error. Args: - adapter_qualified_name (str): The ``adapter.qualified_name`` of the adapter + adapter_qualified_name (str): The `adapter.qualified_name` of the adapter to unload. """ # Check if the backend knows about this adapter. @@ -1579,8 +1579,8 @@ def list_adapters(self) -> list[str]: """List the qualified names of all adapters currently loaded in this backend. Returns: - list[str]: Qualified adapter names (i.e. ``adapter.qualified_name``) for - all adapters that have been loaded via ``load_adapter``. + list[str]: Qualified adapter names (i.e. `adapter.qualified_name`) for + all adapters that have been loaded via `load_adapter`. 
""" return list(self._loaded_adapters.keys()) diff --git a/mellea/backends/kv_block_helpers.py b/mellea/backends/kv_block_helpers.py index 4388b1222..4a9152901 100644 --- a/mellea/backends/kv_block_helpers.py +++ b/mellea/backends/kv_block_helpers.py @@ -1,8 +1,8 @@ """Low-level utilities for concatenating transformer KV caches (KV smashing). -Provides functions for merging ``DynamicCache`` and legacy tuple caches along the -time axis (``merge_dynamic_caches``, ``legacy_cache_smash``), and -``tokens_to_legacy_cache`` for converting a tokenized prompt into a prefilled KV +Provides functions for merging `DynamicCache` and legacy tuple caches along the +time axis (`merge_dynamic_caches`, `legacy_cache_smash`), and +`tokens_to_legacy_cache` for converting a tokenized prompt into a prefilled KV cache. These helpers are used internally by local HuggingFace backends that reuse cached prefix computations across multiple generation calls. """ @@ -25,10 +25,10 @@ def legacy_cache_smash(a: LegacyCache, b: LegacyCache) -> LegacyCache: Args: a: First legacy KV cache (tuple of per-layer (K, V) tensor pairs). - b: Second legacy KV cache to concatenate after ``a``. + b: Second legacy KV cache to concatenate after `a`. Returns: - New legacy cache with ``b`` appended to ``a`` along the sequence dimension. + New legacy cache with `b` appended to `a` along the sequence dimension. """ legacy_merged = tuple( (torch.cat([a[i][0], b[i][0]], dim=2), torch.cat([a[i][1], b[i][1]], dim=2)) @@ -41,10 +41,10 @@ def merge_dynamic_caches(caches: Iterable[DynamicCache]) -> DynamicCache: """Merges two DynamicCache Ks and Vs along the time axis. Args: - caches: Iterable of ``DynamicCache`` objects to merge in order. + caches: Iterable of `DynamicCache` objects to merge in order. Returns: - A single ``DynamicCache`` with all caches concatenated along the sequence dimension. + A single `DynamicCache` with all caches concatenated along the sequence dimension. 
""" legacies = [c.to_legacy_cache() for c in caches] # type: ignore assert len(legacies) >= 1 @@ -59,9 +59,9 @@ def tokens_to_legacy_cache( Args: model: The HuggingFace model used for prefill. - device: Target device string (e.g. ``"cuda"``, ``"cpu"``). - tokens_or_cache: Either a ``BatchEncoding`` to prefill, or an existing - ``DynamicCache`` to convert directly. + device: Target device string (e.g. `"cuda"`, `"cpu"`). + tokens_or_cache: Either a `BatchEncoding` to prefill, or an existing + `DynamicCache` to convert directly. Returns: Legacy KV cache representation as a tuple of per-layer (K, V) tensor pairs. diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py index 78a4b8ec4..5e79155a8 100644 --- a/mellea/backends/litellm.py +++ b/mellea/backends/litellm.py @@ -56,17 +56,17 @@ class LiteLLMBackend(FormatterBackend): Args: model_id (str): The LiteLLM model identifier string; typically - ``"//"``. + `"//"`. formatter (ChatFormatter | None): Formatter for rendering components. - Defaults to ``TemplateFormatter``. + Defaults to `TemplateFormatter`. base_url (str | None): Base URL for the LLM API endpoint; defaults to the Ollama local endpoint. model_options (dict | None): Default model options for generation requests. Attributes: to_mellea_model_opts_map (dict): Mapping from backend-specific option names to - Mellea ``ModelOption`` sentinel keys. - from_mellea_model_opts_map (dict): Mapping from Mellea ``ModelOption`` sentinel + Mellea `ModelOption` sentinel keys. + from_mellea_model_opts_map (dict): Mapping from Mellea `ModelOption` sentinel keys to backend-specific option names. """ @@ -134,10 +134,10 @@ async def _generate_from_context( model_options: dict | None = None, tool_calls: bool = False, ) -> tuple[ModelOutputThunk[C], Context]: - """Generate a completion for ``action`` given ``ctx`` via the LiteLLM chat API. + """Generate a completion for `action` given `ctx` via the LiteLLM chat API. 
- Delegates to ``_generate_from_chat_context_standard``. Only chat contexts are - supported; raises ``NotImplementedError`` otherwise. + Delegates to `_generate_from_chat_context_standard`. Only chat contexts are + supported; raises `NotImplementedError` otherwise. Args: action (Component[C] | CBlock): The component or content block to generate @@ -147,12 +147,12 @@ async def _generate_from_context( structured/constrained output decoding. model_options (dict | None): Per-call model options that override the backend's defaults. - tool_calls (bool): If ``True``, expose available tools to the model and + tool_calls (bool): If `True`, expose available tools to the model and parse tool-call responses. Returns: tuple[ModelOutputThunk[C], Context]: A thunk holding the (lazy) model output - and an updated context that includes ``action`` and the new output. + and an updated context that includes `action` and the new output. """ assert ctx.is_chat_context, NotImplementedError( "The Openai backend only supports chat-like contexts." @@ -395,9 +395,9 @@ async def processing( ): """Accumulate content and thinking tokens from a single LiteLLM response chunk. - Called during generation for each ``ModelResponse`` (non-streaming) or - ``ModelResponseStream`` chunk (streaming). Tool call parsing is deferred to - ``post_processing``. + Called during generation for each `ModelResponse` (non-streaming) or + `ModelResponseStream` chunk (streaming). Tool call parsing is deferred to + `post_processing`. Args: mot (ModelOutputThunk): The output thunk being populated. @@ -475,7 +475,7 @@ async def post_processing( used for logging. tools (dict[str, AbstractMelleaTool]): Available tools, keyed by name. thinking: The thinking/reasoning effort level passed to the model, or - ``None`` if reasoning mode was not enabled. + `None` if reasoning mode was not enabled. _format: The structured output format class used during generation, if any. 
""" # Reconstruct the chat_response from chunks if streamed. @@ -634,7 +634,7 @@ async def generate_from_raw( actions (Sequence[Component[C] | CBlock]): Actions to generate completions for. ctx (Context): The current generation context. format (type[BaseModelSubclass] | None): Optional Pydantic model for - structured output; passed as ``guided_json`` in the request body. + structured output; passed as `guided_json` in the request body. model_options (dict | None): Per-call model options. tool_calls (bool): Ignored; tool calling is not supported on this endpoint. diff --git a/mellea/backends/model_ids.py b/mellea/backends/model_ids.py index 61e163c9c..98a1e8e43 100644 --- a/mellea/backends/model_ids.py +++ b/mellea/backends/model_ids.py @@ -1,6 +1,6 @@ -"""``ModelIdentifier`` dataclass and a catalog of pre-defined model IDs. +"""`ModelIdentifier` dataclass and a catalog of pre-defined model IDs. -``ModelIdentifier`` is a frozen dataclass that groups the platform-specific name +`ModelIdentifier` is a frozen dataclass that groups the platform-specific name variants for a model (HuggingFace, Ollama, WatsonX, MLX, OpenAI, Bedrock) so that a single constant can be passed to any backend without manual string translation. The module also ships a curated catalog of ready-to-use constants for popular @@ -19,13 +19,13 @@ class ModelIdentifier: 2. Using raw strings is annoying because: no autocomplete, typos, hallucinated names, mismatched model and tokenizer names, etc. Args: - hf_model_name (str | None): HuggingFace Hub model repository ID (e.g. ``"ibm-granite/granite-3.3-8b-instruct"``). - ollama_name (str | None): Ollama model tag (e.g. ``"granite3.3:8b"``). - watsonx_name (str | None): WatsonX AI model ID (e.g. ``"ibm/granite-3-2b-instruct"``). + hf_model_name (str | None): HuggingFace Hub model repository ID (e.g. `"ibm-granite/granite-3.3-8b-instruct"`). + ollama_name (str | None): Ollama model tag (e.g. `"granite3.3:8b"`). 
+ watsonx_name (str | None): WatsonX AI model ID (e.g. `"ibm/granite-3-2b-instruct"`).
 mlx_name (str | None): MLX model identifier for Apple Silicon inference.
- openai_name (str | None): OpenAI API model name (e.g. ``"gpt-5.1"``).
- bedrock_name (str | None): AWS Bedrock model ID (e.g. ``"openai.gpt-oss-20b"``).
- hf_tokenizer_name (str | None): HuggingFace tokenizer ID; defaults to ``hf_model_name`` if ``None``.
+ openai_name (str | None): OpenAI API model name (e.g. `"gpt-5.1"`).
+ bedrock_name (str | None): AWS Bedrock model ID (e.g. `"openai.gpt-oss-20b"`).
+ hf_tokenizer_name (str | None): HuggingFace tokenizer ID; defaults to `hf_model_name` if `None`.
 """
diff --git a/mellea/backends/model_options.py b/mellea/backends/model_options.py
index 428dcab00..c8e50911d 100644
--- a/mellea/backends/model_options.py
+++ b/mellea/backends/model_options.py
@@ -12,13 +12,13 @@ class ModelOption:
 Create a dictionary containing model options like this:

- ```python
+ ```python
 from mellea.backends import ModelOption
 model_options = { ModelOption.TEMPERATURE : 0.0, ModelOption.SYSTEM_PROMPT : "You are a helpful assistant" }
- ```
+ ```

 Attributes:
 TOOLS (str): Sentinel key for a list or dict of tools to expose for tool calling.
@@ -44,7 +44,7 @@ class ModelOption:
 @staticmethod
 def replace_keys(options: dict, from_to: dict[str, str]) -> dict[str, Any]:
- """Return a new dict with selected keys in ``options`` renamed according to ``from_to``.
+ """Return a new dict with selected keys in `options` renamed according to `from_to`.
 Returns a new dict with the keys in `options` replaced with the
 corresponding value for that key in `from_to`.
@@ -56,14 +56,14 @@ def replace_keys(options: dict, from_to: dict[str, str]) -> dict[str, Any]:
 the source key is always absent in the output.
 Example:
- ```python
+ ```python
 >>> options = {"k1": "v1", "k2": "v2", "M1": "m1"}
 >>> from_to = {"k1": "M1", "k2": "M2"}
 >>> new_options = replace_keys(options, from_to)
 >>> print(new_options)
 ...
{"M1": "m1", "M2": "v2"} - ``` + `` * Notice that "M1" keeps the original value "m1", rather than "v1". * Notice that both "k1" and "k2" are absent in the output. @@ -112,16 +112,16 @@ def replace_keys(options: dict, from_to: dict[str, str]) -> dict[str, Any]: @staticmethod def remove_special_keys(model_options: dict[str, Any]) -> dict[str, Any]: - """Return a copy of ``model_options`` with all sentinel-valued keys removed. + """Return a copy of `model_options` with all sentinel-valued keys removed. - Sentinel keys are those whose names start with ``@@@`` (e.g. ``ModelOption.TOOLS``). + Sentinel keys are those whose names start with `@@@` (e.g. `ModelOption.TOOLS`). These are Mellea-internal keys that must not be forwarded to backend APIs. Args: model_options (dict[str, Any]): A model options dictionary that may contain sentinel keys. Returns: - dict[str, Any]: A new dictionary with all ``@@@``-prefixed keys omitted. + dict[str, Any]: A new dictionary with all `@@@`-prefixed keys omitted. """ new_options = {} for k, v in model_options.items(): @@ -133,7 +133,7 @@ def remove_special_keys(model_options: dict[str, Any]) -> dict[str, Any]: def merge_model_options( persistent_opts: dict[str, Any], overwrite_opts: dict[str, Any] | None ) -> dict[str, Any]: - """Merge two model-options dicts, with ``overwrite_opts`` taking precedence on conflicts. + """Merge two model-options dicts, with `overwrite_opts` taking precedence on conflicts. Creates a new dict that contains all keys and values from persistent opts and overwrite opts. If there are duplicate keys, overwrite opts key value pairs will be used. @@ -141,7 +141,7 @@ def merge_model_options( Args: persistent_opts (dict[str, Any]): Base model options (lower precedence). overwrite_opts (dict[str, Any] | None): Per-call model options that override - ``persistent_opts`` on key conflicts; ``None`` is treated as empty. + `persistent_opts` on key conflicts; `None` is treated as empty. 
Returns: dict[str, Any]: A new merged dictionary. diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index 014734f2c..ff4eac749 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -43,18 +43,18 @@ class OllamaModelBackend(FormatterBackend): Args: model_id (str | ModelIdentifier): Ollama model ID. If a - ``ModelIdentifier`` is passed, its ``ollama_name`` attribute must + `ModelIdentifier` is passed, its `ollama_name` attribute must be set. formatter (ChatFormatter | None): Formatter for rendering components. - Defaults to ``TemplateFormatter``. + Defaults to `TemplateFormatter`. base_url (str | None): Ollama server endpoint; defaults to - ``env(OLLAMA_HOST)`` or ``http://localhost:11434``. + `env(OLLAMA_HOST)` or `http://localhost:11434`. model_options (dict | None): Default model options for generation requests. Attributes: to_mellea_model_opts_map (dict): Mapping from Ollama-specific option names - to Mellea ``ModelOption`` sentinel keys. - from_mellea_model_opts_map (dict): Mapping from Mellea ``ModelOption`` + to Mellea `ModelOption` sentinel keys. + from_mellea_model_opts_map (dict): Mapping from Mellea `ModelOption` sentinel keys to Ollama-specific option names. """ @@ -267,9 +267,9 @@ async def _generate_from_context( model_options: dict | None = None, tool_calls: bool = False, ) -> tuple[ModelOutputThunk[C], Context]: - """Generate a completion for ``action`` given ``ctx`` via the Ollama chat API. + """Generate a completion for `action` given `ctx` via the Ollama chat API. - Delegates to ``generate_from_chat_context``. Only chat contexts are supported. + Delegates to `generate_from_chat_context`. Only chat contexts are supported. Args: action (Component[C] | CBlock): The component or content block to generate @@ -279,12 +279,12 @@ async def _generate_from_context( structured/constrained output decoding. model_options (dict | None): Per-call model options that override the backend's defaults. 
- tool_calls (bool): If ``True``, expose available tools to the model and + tool_calls (bool): If `True`, expose available tools to the model and parse tool-call responses. Returns: tuple[ModelOutputThunk[C], Context]: A thunk holding the (lazy) model output - and an updated context that includes ``action`` and the new output. + and an updated context that includes `action` and the new output. """ from ..telemetry.backend_instrumentation import start_generate_span @@ -319,7 +319,7 @@ async def generate_from_chat_context( ) -> ModelOutputThunk[C]: """Generate a new completion from the provided context using this backend's formatter. - Treats the ``Context`` as a chat history and uses the ``ollama.Client.chat()`` + Treats the `Context` as a chat history and uses the `ollama.Client.chat()` interface to generate a completion. Returns a thunk that lazily resolves the model output. @@ -330,7 +330,7 @@ async def generate_from_chat_context( _format (type[BaseModelSubclass] | None): Optional Pydantic model class for structured output decoding. model_options (dict | None): Per-call model options. - tool_calls (bool): If ``True``, expose available tools and parse responses. + tool_calls (bool): If `True`, expose available tools and parse responses. Returns: ModelOutputThunk[C]: A thunk holding the (lazy) model output. @@ -608,9 +608,9 @@ async def processing( ): """Accumulate text and tool calls from a single Ollama ChatResponse chunk. - Called for each streaming or non-streaming ``ollama.ChatResponse``. Also + Called for each streaming or non-streaming `ollama.ChatResponse`. Also extracts tool call requests inline and merges the chunk into the running - aggregated response stored in ``mot._meta["chat_response"]``. + aggregated response stored in `mot._meta["chat_response"]`. Args: mot (ModelOutputThunk): The output thunk being populated. 
diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index ad0031189..e3eaa9b14 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -69,24 +69,24 @@ class OpenAIBackend(FormatterBackend): Args: model_id (str | ModelIdentifier): OpenAI-compatible model identifier. - Defaults to ``model_ids.OPENAI_GPT_5_1``. + Defaults to `model_ids.OPENAI_GPT_5_1`. formatter (ChatFormatter | None): Formatter for rendering components. - Defaults to ``TemplateFormatter``. + Defaults to `TemplateFormatter`. base_url (str | None): Base URL for the API endpoint; defaults to the standard OpenAI endpoint if not set. model_options (dict | None): Default model options for generation requests. - default_to_constraint_checking_alora (bool): If ``False``, deactivates aLoRA + default_to_constraint_checking_alora (bool): If `False`, deactivates aLoRA constraint checking; primarily for benchmarking and debugging. - api_key (str | None): API key; falls back to ``OPENAI_API_KEY`` env var. + api_key (str | None): API key; falls back to `OPENAI_API_KEY` env var. kwargs: Additional keyword arguments forwarded to the OpenAI client. Attributes: to_mellea_model_opts_map_chats (dict): Mapping from chat-endpoint option names - to Mellea ``ModelOption`` sentinel keys. + to Mellea `ModelOption` sentinel keys. from_mellea_model_opts_map_chats (dict): Mapping from Mellea sentinel keys to chat-endpoint option names. to_mellea_model_opts_map_completions (dict): Mapping from completions-endpoint - option names to Mellea ``ModelOption`` sentinel keys. + option names to Mellea `ModelOption` sentinel keys. from_mellea_model_opts_map_completions (dict): Mapping from Mellea sentinel keys to completions-endpoint option names. """ @@ -226,7 +226,7 @@ def filter_openai_client_kwargs(**kwargs) -> dict: kwargs: Arbitrary keyword arguments to filter. Returns: - dict: A dict containing only keys accepted by ``openai.OpenAI.__init__``. 
+ dict: A dict containing only keys accepted by `openai.OpenAI.__init__`. """ openai_params = set(inspect.signature(openai.OpenAI.__init__).parameters.keys()) # type: ignore openai_params.discard("self") # Remove 'self' parameter @@ -242,7 +242,7 @@ def filter_chat_completions_kwargs(self, model_options: dict) -> dict: model_options (dict): Model options dict that may contain non-chat keys. Returns: - dict: A dict containing only keys accepted by ``chat.completions.create``. + dict: A dict containing only keys accepted by `chat.completions.create`. """ from openai.resources.chat.completions import Completions @@ -260,7 +260,7 @@ def filter_completions_kwargs(self, model_options: dict) -> dict: model_options (dict): Model options dict that may contain non-completions keys. Returns: - dict: A dict containing only keys accepted by ``completions.create``. + dict: A dict containing only keys accepted by `completions.create`. """ from openai.resources.completions import Completions @@ -338,9 +338,9 @@ async def _generate_from_context( model_options: dict | None = None, tool_calls: bool = False, ) -> tuple[ModelOutputThunk[C], Context]: - """Generate a completion for ``action`` given ``ctx`` via the OpenAI chat API. + """Generate a completion for `action` given `ctx` via the OpenAI chat API. - Delegates to ``generate_from_chat_context``. Only chat contexts are supported. + Delegates to `generate_from_chat_context`. Only chat contexts are supported. Args: action (Component[C] | CBlock): The component or content block to generate @@ -350,12 +350,12 @@ async def _generate_from_context( structured/constrained output decoding. model_options (dict | None): Per-call model options that override the backend's defaults. - tool_calls (bool): If ``True``, expose available tools to the model and + tool_calls (bool): If `True`, expose available tools to the model and parse tool-call responses. 
Returns: tuple[ModelOutputThunk[C], Context]: A thunk holding the (lazy) model output - and an updated context that includes ``action`` and the new output. + and an updated context that includes `action` and the new output. """ from ..telemetry.backend_instrumentation import start_generate_span @@ -395,7 +395,7 @@ async def generate_from_chat_context( model_options: dict | None = None, tool_calls: bool = False, ) -> tuple[ModelOutputThunk[C], Context]: - """Generate a new completion from the provided Context using this backend's ``Formatter``. + """Generate a new completion from the provided Context using this backend's `Formatter`. Formats the context and action into OpenAI-compatible chat messages, submits the request asynchronously, and returns a thunk that lazily resolves the output. @@ -407,11 +407,11 @@ async def generate_from_chat_context( _format (type[BaseModelSubclass] | None): Optional Pydantic model class for structured output decoding. model_options (dict | None): Per-call model options. - tool_calls (bool): If ``True``, expose available tools and parse responses. + tool_calls (bool): If `True`, expose available tools and parse responses. Returns: tuple[ModelOutputThunk[C], Context]: A thunk holding the (lazy) model output - and an updated context that includes ``action`` and the new output. + and an updated context that includes `action` and the new output. """ await self.do_generate_walk(action) @@ -580,8 +580,8 @@ async def processing( ): """Accumulate content from a single OpenAI response object into the output thunk. - Called for each ``ChatCompletion`` (non-streaming) or ``ChatCompletionChunk`` - (streaming). Tool call parsing is deferred to ``post_processing``. + Called for each `ChatCompletion` (non-streaming) or `ChatCompletionChunk` + (streaming). Tool call parsing is deferred to `post_processing`. Args: mot (ModelOutputThunk): The output thunk being populated. 
@@ -655,9 +655,9 @@ async def post_processing( tools (dict[str, AbstractMelleaTool]): Available tools, keyed by name. conversation (list[dict]): The chat conversation sent to the model, used for logging. - thinking: The reasoning effort level passed to the model, or ``None`` + thinking: The reasoning effort level passed to the model, or `None` if reasoning mode was not enabled. - seed: The random seed used during generation, or ``None``. + seed: The random seed used during generation, or `None`. _format: The structured output format class used during generation, if any. """ # Reconstruct the chat_response from chunks if streamed. diff --git a/mellea/backends/tools.py b/mellea/backends/tools.py index 125b19938..fb321a9a8 100644 --- a/mellea/backends/tools.py +++ b/mellea/backends/tools.py @@ -1,6 +1,6 @@ """LLM tool definitions, parsing, and validation for mellea backends. -Provides the ``MelleaTool`` class (and the ``@tool`` decorator shorthand) for +Provides the `MelleaTool` class (and the `@tool` decorator shorthand) for wrapping Python callables as OpenAI-compatible tool schemas, with factory methods for LangChain and smolagents interoperability. Also includes helpers for converting tool lists to JSON, extracting tool call requests from raw LLM output strings, and @@ -74,14 +74,14 @@ def from_langchain(cls, tool: Any): """Create a MelleaTool from a LangChain tool object. Args: - tool (Any): A ``langchain_core.tools.BaseTool`` instance to wrap. + tool (Any): A `langchain_core.tools.BaseTool` instance to wrap. Returns: MelleaTool: A Mellea tool wrapping the LangChain tool. Raises: - ImportError: If ``langchain-core`` is not installed. - ValueError: If ``tool`` is not a ``BaseTool`` instance. + ImportError: If `langchain-core` is not installed. + ValueError: If `tool` is not a `BaseTool` instance. 
""" try: from langchain_core.tools import BaseTool # type: ignore[import-not-found] @@ -180,7 +180,7 @@ def from_callable(cls, func: Callable, name: str | None = None): Args: func (Callable): The Python callable to wrap as a tool. - name (str | None): Optional name override; defaults to ``func.__name__``. + name (str | None): Optional name override; defaults to `func.__name__`. Returns: MelleaTool: A Mellea tool wrapping the callable. @@ -274,8 +274,8 @@ def add_tools_from_model_options( Args: tools_dict: Mutable mapping of tool name to tool instance; modified in-place. - model_options: Model options dict that may contain a ``ModelOption.TOOLS`` - entry (either a list of ``MelleaTool`` or a ``dict[str, MelleaTool]``). + model_options: Model options dict that may contain a `ModelOption.TOOLS` + entry (either a list of `MelleaTool` or a `dict[str, MelleaTool]`). """ model_opts_tools = model_options.get(ModelOption.TOOLS, None) if model_opts_tools is None: @@ -317,8 +317,8 @@ def add_tools_from_context_actions( Args: tools_dict: Mutable mapping of tool name to tool instance; modified in-place. - ctx_actions: List of ``Component`` or ``CBlock`` objects whose template - representations may declare tools, or ``None`` to skip. + ctx_actions: List of `Component` or `CBlock` objects whose template + representations may declare tools, or `None` to skip. """ if ctx_actions is None: return @@ -339,7 +339,7 @@ def convert_tools_to_json(tools: dict[str, AbstractMelleaTool]) -> list[dict]: """Convert tools to json dict representation. Args: - tools: Mapping of tool name to ``AbstractMelleaTool`` instance. + tools: Mapping of tool name to `AbstractMelleaTool` instance. Returns: List of OpenAI-compatible JSON tool schema dicts, one per tool. @@ -359,7 +359,7 @@ def json_extraction(text: str) -> Generator[dict, None, None]: text: Input string potentially containing one or more JSON objects. 
Returns: - A generator that yields each valid JSON object found in ``text``, + A generator that yields each valid JSON object found in `text`, in order of appearance. """ index = 0 @@ -383,15 +383,15 @@ def json_extraction(text: str) -> Generator[dict, None, None]: def find_func(d) -> tuple[str | None, Mapping | None]: """Find the first function in a json-like dictionary. - Most llms output tool requests in the form ``...{"name": string, "arguments": {}}...`` + Most llms output tool requests in the form `...{"name": string, "arguments": {}}...` Args: - d: A JSON-like Python object (typically a ``dict``) to search for a function + d: A JSON-like Python object (typically a `dict`) to search for a function call record. Returns: - A ``(name, args)`` tuple where ``name`` is the tool name string and ``args`` - is the arguments mapping, or ``(None, None)`` if no function call was found. + A `(name, args)` tuple where `name` is the tool name string and `args` + is the arguments mapping, or `(None, None)` if no function call was found. """ if not isinstance(d, dict): return None, None @@ -423,7 +423,7 @@ def parse_tools(llm_response: str) -> list[tuple[str, Mapping]]: llm_response: Raw string output from a language model. Returns: - List of ``(tool_name, arguments)`` tuples for each tool call found. + List of `(tool_name, arguments)` tuples for each tool call found. """ processed = " ".join(llm_response.split()) @@ -629,10 +629,10 @@ def validate_tool_arguments( # so that all backends don't need it installed. # https://github.com/ollama/ollama-python/blob/60e7b2f9ce710eeb57ef2986c46ea612ae7516af/ollama/_types.py#L19-L101 class SubscriptableBaseModel(BaseModel): - """Pydantic ``BaseModel`` subclass that also supports subscript (``[]``) access. + """Pydantic `BaseModel` subclass that also supports subscript (`[]`) access. Imported from the Ollama Python client. 
Allows model fields to be accessed - via ``model["field"]`` in addition to ``model.field``, which is required for + via `model["field"]` in addition to `model.field`, which is required for compatibility with Ollama's internal response parsing. """ @@ -711,11 +711,11 @@ def get(self, key: str, default: Any = None) -> Any: Args: key (str): The field name to look up on the model. - default (Any): Value to return when ``key`` is not a field on the model. - Defaults to ``None``. + default (Any): Value to return when `key` is not a field on the model. + Defaults to `None`. Returns: - Any: The field value if the attribute exists, otherwise ``default``. + Any: The field value if the attribute exists, otherwise `default`. >>> msg = Message(role='user') >>> msg.get('role') @@ -738,10 +738,10 @@ class OllamaTool(SubscriptableBaseModel): Represents the JSON structure that Ollama (and OpenAI-compatible endpoints) expect when a tool is passed to the chat API. Mellea builds these objects internally via - ``convert_function_to_ollama_tool`` and never exposes them to end users directly. + `convert_function_to_ollama_tool` and never exposes them to end users directly. Attributes: - type (str | None): Tool type; always ``"function"`` for function-calling tools. + type (str | None): Tool type; always `"function"` for function-calling tools. function (Function | None): Nested object containing the function name, description, and parameters schema. """ @@ -749,7 +749,7 @@ class OllamaTool(SubscriptableBaseModel): type: str | None = "function" class Function(SubscriptableBaseModel): - """Pydantic model for the ``function`` field of an Ollama tool schema, imported from the Ollama Python SDK. + """Pydantic model for the `function` field of an Ollama tool schema, imported from the Ollama Python SDK. Attributes: name (str | None): The name of the function being described. 
@@ -761,11 +761,11 @@ class Function(SubscriptableBaseModel): description: str | None = None class Parameters(SubscriptableBaseModel): - """Pydantic model for the ``parameters`` field of an Ollama function schema, imported from the Ollama Python SDK. + """Pydantic model for the `parameters` field of an Ollama function schema, imported from the Ollama Python SDK. Attributes: - type (Literal["object"] | None): Always ``"object"`` for function parameters. - defs (Any | None): JSON Schema ``$defs`` for referenced sub-schemas. + type (Literal["object"] | None): Always `"object"` for function parameters. + defs (Any | None): JSON Schema `$defs` for referenced sub-schemas. items (Any | None): Array item schema, if applicable. required (Sequence[str] | None): List of required parameter names. properties (Mapping[str, Property] | None): Parameter property definitions. @@ -856,10 +856,10 @@ def convert_function_to_ollama_tool( Args: func: The Python callable to convert. - name: Optional override for the tool name; defaults to ``func.__name__``. + name: Optional override for the tool name; defaults to `func.__name__`. Returns: - An ``OllamaTool`` instance representing the function as an OpenAI-compatible + An `OllamaTool` instance representing the function as an OpenAI-compatible tool schema. """ doc_string_hash = str(hash(inspect.getdoc(func))) diff --git a/mellea/backends/utils.py b/mellea/backends/utils.py index 5044a0f2a..f31373075 100644 --- a/mellea/backends/utils.py +++ b/mellea/backends/utils.py @@ -1,10 +1,10 @@ """Shared utility functions used across formatter-based backend implementations. -Provides ``to_chat``, which converts a ``Context`` and a ``Component`` action into -the list of role/content dicts expected by ``apply_chat_template``; and -``to_tool_calls``, which parses a raw model output string into validated -``ModelToolCall`` objects. These helpers are consumed internally by all -``FormatterBackend`` subclasses. 
+Provides `to_chat`, which converts a `Context` and a `Component` action into +the list of role/content dicts expected by `apply_chat_template`; and +`to_tool_calls`, which parses a raw model output string into validated +`ModelToolCall` objects. These helpers are consumed internally by all +`FormatterBackend` subclasses. """ from __future__ import annotations @@ -56,7 +56,7 @@ def to_chat( system_prompt: Optional system prompt to prepend; overrides any system message in the context. Returns: - List of role/content dicts suitable for ``apply_chat_template``. + List of role/content dicts suitable for `apply_chat_template`. """ assert ctx.is_chat_context @@ -94,11 +94,11 @@ def to_tool_calls( """Parse a tool call string. Args: - tools: Mapping of tool name to the corresponding ``AbstractMelleaTool`` object. + tools: Mapping of tool name to the corresponding `AbstractMelleaTool` object. decoded_result: Raw model output string that may contain tool call markup. Returns: - Dict mapping tool name to validated ``ModelToolCall``, or ``None`` if no tool calls were found. + Dict mapping tool name to validated `ModelToolCall`, or `None` if no tool calls were found. """ model_tool_calls: dict[str, ModelToolCall] = dict() for tool_name, tool_args in parse_tools(decoded_result): diff --git a/mellea/backends/vllm.py b/mellea/backends/vllm.py index bc42af75e..ef1d87377 100644 --- a/mellea/backends/vllm.py +++ b/mellea/backends/vllm.py @@ -73,13 +73,13 @@ class LocalVLLMBackend(FormatterBackend): Args: model_id (str | ModelIdentifier): HuggingFace model ID used to load model weights via vLLM. formatter (ChatFormatter | None): Formatter for rendering components into prompts. - Defaults to a ``TemplateFormatter`` for the given ``model_id``. + Defaults to a `TemplateFormatter` for the given `model_id`. model_options (dict | None): Default model options for generation requests. 
Attributes: to_mellea_model_opts_map (dict): Mapping from backend-specific option names to - Mellea ``ModelOption`` sentinel keys. - from_mellea_model_opts_map (dict): Mapping from Mellea ``ModelOption`` sentinel + Mellea `ModelOption` sentinel keys. + from_mellea_model_opts_map (dict): Mapping from Mellea `ModelOption` sentinel keys to backend-specific option names. engine_args (dict): vLLM engine arguments used at instantiation; retained so the engine can be restarted when the event loop changes. @@ -249,7 +249,7 @@ async def _generate_from_context( generate_logs: list[GenerateLog] | None = None, tool_calls: bool = False, ) -> tuple[ModelOutputThunk[C], Context]: - """Generate a completion for ``action`` given the current ``ctx`` using the vLLM engine. + """Generate a completion for `action` given the current `ctx` using the vLLM engine. Args: action (Component[C] | CBlock): The component or content block to generate a @@ -260,13 +260,13 @@ async def _generate_from_context( model_options (dict | None): Per-call model options that override the backend's defaults. generate_logs (list[GenerateLog] | None): Optional list to which a - ``GenerateLog`` entry will be appended. - tool_calls (bool): If ``True``, expose available tools to the model and + `GenerateLog` entry will be appended. + tool_calls (bool): If `True`, expose available tools to the model and parse tool-call responses. Returns: tuple[ModelOutputThunk[C], Context]: A thunk holding the (lazy) model output - and an updated context that includes ``action`` and the new output. + and an updated context that includes `action` and the new output. """ await self.do_generate_walk(action) @@ -392,7 +392,7 @@ async def processing(self, mot: ModelOutputThunk, chunk: vllm.RequestOutput): """Accumulate text from a single vLLM output chunk into the model output thunk. Called during streaming or final generation to add each incremental result to - ``mot._underlying_value``. + `mot._underlying_value`. 
Args: mot (ModelOutputThunk): The output thunk being populated. @@ -424,7 +424,7 @@ async def post_processing( class used during generation, if any. tool_calls (bool): Whether tool calling was enabled for this request. tools (dict[str, AbstractMelleaTool]): Available tools, keyed by name. - seed: The random seed used during generation, or ``None``. + seed: The random seed used during generation, or `None`. """ # The ModelOutputThunk must be computed by this point. assert mot.value is not None diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py index c356825ee..119ed4bf5 100644 --- a/mellea/backends/watsonx.py +++ b/mellea/backends/watsonx.py @@ -59,24 +59,24 @@ class WatsonxAIBackend(FormatterBackend): Args: model_id (str | ModelIdentifier): WatsonX model identifier. Defaults to - ``model_ids.IBM_GRANITE_4_HYBRID_SMALL``. + `model_ids.IBM_GRANITE_4_HYBRID_SMALL`. formatter (ChatFormatter | None): Formatter for rendering components. - Defaults to ``TemplateFormatter``. + Defaults to `TemplateFormatter`. base_url (str | None): URL for the WatsonX ML deployment endpoint; - defaults to the ``WATSONX_URL`` environment variable. + defaults to the `WATSONX_URL` environment variable. model_options (dict | None): Default model options for generation requests. api_key (str | None): WatsonX API key; defaults to the - ``WATSONX_API_KEY`` environment variable. + `WATSONX_API_KEY` environment variable. project_id (str | None): WatsonX project ID; defaults to the - ``WATSONX_PROJECT_ID`` environment variable. + `WATSONX_PROJECT_ID` environment variable. Attributes: to_mellea_model_opts_map_chats (dict): Mapping from chat-endpoint option names - to Mellea ``ModelOption`` sentinel keys. + to Mellea `ModelOption` sentinel keys. from_mellea_model_opts_map_chats (dict): Mapping from Mellea sentinel keys to chat-endpoint option names. to_mellea_model_opts_map_completions (dict): Mapping from completions-endpoint - option names to Mellea ``ModelOption`` sentinel keys. 
+ option names to Mellea `ModelOption` sentinel keys. from_mellea_model_opts_map_completions (dict): Mapping from Mellea sentinel keys to completions-endpoint option names. """ @@ -272,10 +272,10 @@ async def _generate_from_context( model_options: dict | None = None, tool_calls: bool = False, ) -> tuple[ModelOutputThunk[C], Context]: - """Generate a completion for ``action`` given ``ctx`` via the WatsonX chat API. + """Generate a completion for `action` given `ctx` via the WatsonX chat API. - Delegates to ``generate_from_chat_context``. Only chat contexts are - supported; raises ``NotImplementedError`` otherwise. + Delegates to `generate_from_chat_context`. Only chat contexts are + supported; raises `NotImplementedError` otherwise. Args: action (Component[C] | CBlock): The component or content block to generate @@ -285,12 +285,12 @@ async def _generate_from_context( structured/constrained output decoding. model_options (dict | None): Per-call model options that override the backend's defaults. - tool_calls (bool): If ``True``, expose available tools to the model and + tool_calls (bool): If `True`, expose available tools to the model and parse tool-call responses. Returns: tuple[ModelOutputThunk[C], Context]: A thunk holding the (lazy) model output - and an updated context that includes ``action`` and the new output. + and an updated context that includes `action` and the new output. """ assert ctx.is_chat_context, NotImplementedError( "The watsonx.ai backend only supports chat-like contexts." @@ -334,13 +334,13 @@ async def generate_from_chat_context( _format (type[BaseModelSubclass] | None): Optional Pydantic model class for structured output decoding. model_options (dict | None): Per-call model options. - tool_calls (bool): If ``True``, expose available tools and parse responses. + tool_calls (bool): If `True`, expose available tools and parse responses. Returns: ModelOutputThunk[C]: A thunk holding the (lazy) model output. 
Raises: - Exception: If ``action`` is an ``ALoraRequirement``, which is not + Exception: If `action` is an `ALoraRequirement`, which is not supported by this backend. RuntimeError: If not called from a thread with a running event loop. """ @@ -465,8 +465,8 @@ async def generate_from_chat_context( async def processing(self, mot: ModelOutputThunk, chunk: dict): """Accumulate content from a single WatsonX response dict into the output thunk. - Called for each non-streaming chat dict (with a ``"message"`` key) or - streaming delta dict (with a ``"delta"`` key). Tool call parsing is + Called for each non-streaming chat dict (with a `"message"` key) or + streaming delta dict (with a `"delta"` key). Tool call parsing is handled in the post-processing step. Args: @@ -533,7 +533,7 @@ async def post_processing( conversation (list[dict]): The chat conversation sent to the model, used for logging. tools (dict[str, AbstractMelleaTool]): Available tools, keyed by name. - seed: The random seed used during generation, or ``None``. + seed: The random seed used during generation, or `None`. _format: The structured output format class used during generation, if any. """ # Reconstruct the chat_response from chunks if streamed. @@ -665,7 +665,7 @@ async def generate_from_raw( """Generate completions for multiple actions without chat templating via WatsonX. Passes formatted prompt strings directly to WatsonX's generate endpoint. - The ``format`` parameter is not supported and will be ignored with a warning. + The `format` parameter is not supported and will be ignored with a warning. Args: actions (Sequence[Component[C] | CBlock]): Actions to generate completions for. diff --git a/mellea/core/__init__.py b/mellea/core/__init__.py index b031df6ab..8ba5cf3e5 100644 --- a/mellea/core/__init__.py +++ b/mellea/core/__init__.py @@ -1,10 +1,10 @@ """Core abstractions for the mellea library. 
This package defines the fundamental interfaces and data structures on which every -other layer of mellea is built: the ``Backend``, ``Formatter``, and -``SamplingStrategy`` protocols; the ``Component``, ``CBlock``, ``Context``, and -``ModelOutputThunk`` data types that flow through the inference pipeline; and -``Requirement`` / ``ValidationResult`` for constrained generation. Start here when +other layer of mellea is built: the `Backend`, `Formatter`, and +`SamplingStrategy` protocols; the `Component`, `CBlock`, `Context`, and +`ModelOutputThunk` data types that flow through the inference pipeline; and +`Requirement` / `ValidationResult` for constrained generation. Start here when building a new backend, formatter, or sampling strategy, or when you need the type definitions shared across the library. """ diff --git a/mellea/core/backend.py b/mellea/core/backend.py index 82f9fae5f..376cbeea4 100644 --- a/mellea/core/backend.py +++ b/mellea/core/backend.py @@ -1,10 +1,10 @@ -"""Abstract ``Backend`` interface and generation-walk utilities. +"""Abstract `Backend` interface and generation-walk utilities. -Defines the ``Backend`` abstract base class whose two key abstract methods — -``generate_from_context`` (context-aware single-action generation) and -``generate_from_raw`` (context-free batch generation) — all concrete backends must -implement. Also provides ``generate_walk``, which traverses a ``Component`` tree to -find un-computed ``ModelOutputThunk`` leaves that need to be resolved before rendering. +Defines the `Backend` abstract base class whose two key abstract methods — +`generate_from_context` (context-aware single-action generation) and +`generate_from_raw` (context-free batch generation) — all concrete backends must +implement. Also provides `generate_walk`, which traverses a `Component` tree to +find un-computed `ModelOutputThunk` leaves that need to be resolved before rendering. 
""" import abc @@ -42,10 +42,10 @@ class Backend(abc.ABC): """Abstract base class for all inference backends. - All concrete backends must implement ``generate_from_context`` (context-aware - single-action generation) and ``generate_from_raw`` (context-free batch - generation). The ``do_generate_walk`` / ``do_generate_walks`` helpers can be - used to pre-compute any unresolved ``ModelOutputThunk`` leaves before rendering. + All concrete backends must implement `generate_from_context` (context-aware + single-action generation) and `generate_from_raw` (context-free batch + generation). The `do_generate_walk` / `do_generate_walks` helpers can be + used to pre-compute any unresolved `ModelOutputThunk` leaves before rendering. """ @final @@ -162,16 +162,16 @@ async def generate_from_raw( tool_calls: Always set to false unless supported by backend. Returns: - list[ModelOutputThunk]: A list of output thunks, one per action, in the same order as ``actions``. + list[ModelOutputThunk]: A list of output thunks, one per action, in the same order as `actions`. """ async def do_generate_walk( self, action: CBlock | Component | ModelOutputThunk ) -> None: - """Awaits all uncomputed ``ModelOutputThunk`` leaves reachable from ``action``. + """Awaits all uncomputed `ModelOutputThunk` leaves reachable from `action`. - Traverses the component tree rooted at ``action`` via ``generate_walk``, collects - any uncomputed ``ModelOutputThunk`` nodes, and concurrently awaits them all. + Traverses the component tree rooted at `action` via `generate_walk`, collects + any uncomputed `ModelOutputThunk` nodes, and concurrently awaits them all. Args: action (CBlock | Component | ModelOutputThunk): The root node to traverse. @@ -188,10 +188,10 @@ async def do_generate_walk( async def do_generate_walks( self, actions: list[CBlock | Component | ModelOutputThunk] ) -> None: - """Awaits all uncomputed ``ModelOutputThunk`` leaves reachable from each action in ``actions``. 
+ """Awaits all uncomputed `ModelOutputThunk` leaves reachable from each action in `actions`. - Traverses the component tree of every action in the list via ``generate_walk``, collects - all uncomputed ``ModelOutputThunk`` nodes across all actions, and concurrently awaits them. + Traverses the component tree of every action in the list via `generate_walk`, collects + all uncomputed `ModelOutputThunk` nodes across all actions, and concurrently awaits them. Args: actions (list[CBlock | Component | ModelOutputThunk]): The list of root nodes to traverse. @@ -209,18 +209,18 @@ async def do_generate_walks( def generate_walk(c: CBlock | Component | ModelOutputThunk) -> list[ModelOutputThunk]: - """Return all uncomputed ``ModelOutputThunk`` leaves reachable from ``c``. + """Return all uncomputed `ModelOutputThunk` leaves reachable from `c`. Args: - c: A ``CBlock``, ``Component``, or ``ModelOutputThunk`` to traverse. + c: A `CBlock`, `Component`, or `ModelOutputThunk` to traverse. Returns: - A flat list of uncomputed ``ModelOutputThunk`` instances in the order - they need to be resolved (depth-first over ``Component.parts()``). + A flat list of uncomputed `ModelOutputThunk` instances in the order + they need to be resolved (depth-first over `Component.parts()`). Raises: - ValueError: If any element encountered during traversal is not a ``CBlock``, - ``Component``, or ``ModelOutputThunk``. + ValueError: If any element encountered during traversal is not a `CBlock`, + `Component`, or `ModelOutputThunk`. """ match c: case ModelOutputThunk() if not c.is_computed(): diff --git a/mellea/core/base.py b/mellea/core/base.py index bef046b20..dc90ffd39 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -1,11 +1,11 @@ """Foundational data structures for mellea's generative programming model. 
-Defines the building blocks that flow through every layer of the library: ``CBlock`` -(a content block wrapping a string value), ``Component`` (an abstract composable -generative unit), ``ModelOutputThunk`` (a lazily-evaluated model response), -``Context`` and ``ContextTurn`` (stateful conversation history containers), -``TemplateRepresentation`` (the structured rendering of a component for prompt -templates), ``ImageBlock``, and ``ModelToolCall``. Understanding these types is +Defines the building blocks that flow through every layer of the library: `CBlock` +(a content block wrapping a string value), `Component` (an abstract composable +generative unit), `ModelOutputThunk` (a lazily-evaluated model response), +`Context` and `ContextTurn` (stateful conversation history containers), +`TemplateRepresentation` (the structured rendering of a component for prompt +templates), `ImageBlock`, and `ModelToolCall`. Understanding these types is the starting point for building custom components or sampling strategies. """ @@ -38,7 +38,7 @@ class CBlock: value (str | None): The underlying string content of the block. meta (dict[str, Any] | None): Optional metadata about this block (e.g., the inference engine's completion object). Defaults to an empty dict. - cache (bool): If ``True``, the inference engine may store the KV cache for this block. Experimental. + cache (bool): If `True`, the inference engine may store the KV cache for this block. Experimental. """ @@ -90,7 +90,7 @@ def __init__(self, value: str, meta: dict[str, Any] | None = None): """Initialize ImageBlock with a base64-encoded PNG string, validating the encoding. Raises: - AssertionError: If ``value`` is not a valid base64-encoded PNG string. + AssertionError: If `value` is not a valid base64-encoded PNG string. """ assert self.is_valid_base64_png(value), ( "Invalid base64 string representation of image." 
@@ -108,7 +108,7 @@ def is_valid_base64_png(s: str) -> bool: s (str): The string to validate, optionally prefixed with a data URI header. Returns: - bool: ``True`` if the string decodes to a PNG image, ``False`` otherwise. + bool: `True` if the string decodes to a PNG image, `False` otherwise. """ try: # Check if the string has a data URI prefix and remove it. @@ -154,17 +154,17 @@ def pil_to_base64(image: PILImage.Image) -> str: def from_pil_image( cls, image: PILImage.Image, meta: dict[str, Any] | None = None ) -> ImageBlock: - """Creates an ``ImageBlock`` from a PIL image object. + """Creates an `ImageBlock` from a PIL image object. Converts the image to a base64-encoded PNG string and wraps it in a new - ``ImageBlock`` instance. + `ImageBlock` instance. Args: image (PILImage.Image): The PIL image to encode. meta (dict[str, Any] | None): Optional metadata to associate with the block. Returns: - ImageBlock: A new ``ImageBlock`` containing the base64-encoded PNG. + ImageBlock: A new `ImageBlock` containing the base64-encoded PNG. """ image_base64 = cls.pil_to_base64(image) return cls(image_base64, meta) @@ -190,10 +190,10 @@ class Component(Protocol, Generic[S]): """A `Component` is a composite data structure that is intended to be represented to an LLM.""" def parts(self) -> list[Component | CBlock]: - """Returns the set of all constituent sub-components and content blocks of this ``Component``. + """Returns the set of all constituent sub-components and content blocks of this `Component`. Returns: - list[Component | CBlock]: A list of child ``Component`` or ``CBlock`` objects that make + list[Component | CBlock]: A list of child `Component` or `CBlock` objects that make up this component. The list may be empty for leaf components. 
Raises: @@ -202,10 +202,10 @@ def parts(self) -> list[Component | CBlock]: raise NotImplementedError("parts isn't implemented by default") def format_for_llm(self) -> TemplateRepresentation | str: - """Formats the ``Component`` into a ``TemplateRepresentation`` or plain string for LLM consumption. + """Formats the `Component` into a `TemplateRepresentation` or plain string for LLM consumption. Returns: - TemplateRepresentation | str: A structured ``TemplateRepresentation`` (for components + TemplateRepresentation | str: A structured `TemplateRepresentation` (for components with tools, fields, or templates) or a plain string for simple components. Raises: @@ -214,19 +214,19 @@ def format_for_llm(self) -> TemplateRepresentation | str: raise NotImplementedError("format_for_llm isn't implemented by default") def parse(self, computed: ModelOutputThunk) -> S: - """Parses the expected type ``S`` from a given ``ModelOutputThunk``. + """Parses the expected type `S` from a given `ModelOutputThunk`. - Delegates to the component's underlying ``_parse`` method and wraps any - exception in a ``ComponentParseError`` for uniform error handling. + Delegates to the component's underlying `_parse` method and wraps any + exception in a `ComponentParseError` for uniform error handling. Args: computed (ModelOutputThunk): The model output thunk whose value should be parsed. Returns: - S: The parsed result produced by ``_parse``, typed according to the component's type parameter. + S: The parsed result produced by `_parse`, typed according to the component's type parameter. Raises: - ComponentParseError: If the underlying ``_parse`` call raises any exception. + ComponentParseError: If the underlying `_parse` call raises any exception. """ try: return self._parse(computed) @@ -243,7 +243,7 @@ class GenerateType(enum.Enum): Attributes: NONE (None): No generation function has been set; the thunk is either already computed or uninitialized. 
- ASYNC (int): The generation function is async-compatible; ``avalue``/``astream`` may be used. + ASYNC (int): The generation function is async-compatible; `avalue`/`astream` may be used. SYNC (int): The generation function is synchronous only; async extraction methods are unavailable. """ @@ -256,7 +256,7 @@ class ModelOutputThunk(CBlock, Generic[S]): """A `ModelOutputThunk` is a special type of `CBlock` that we know came from a model's output. It is possible to instantiate one without the output being computed yet. Args: - value (str | None): The raw model output string, or ``None`` if not yet computed. + value (str | None): The raw model output string, or `None` if not yet computed. meta (dict[str, Any] | None): Optional metadata from the inference engine (e.g., completion object). parsed_repr (S | None): An already-parsed representation to attach; set when re-wrapping existing output. tool_calls (dict[str, ModelToolCall] | None): Tool calls returned by the model alongside the text output. @@ -566,9 +566,9 @@ class ContextTurn: Args: model_input (CBlock | Component | None): The input component or content block for this turn, - or ``None`` for an output-only partial turn. + or `None` for an output-only partial turn. output (ModelOutputThunk | None): The model's output thunk for this turn, - or ``None`` for an input-only partial turn. + or `None` for an input-only partial turn. """ @@ -585,11 +585,11 @@ class Context(abc.ABC): A context is immutable. Every alteration leads to a new context. Attributes: - is_root_node (bool): ``True`` when this context is the root (empty) node of the linked list. + is_root_node (bool): `True` when this context is the root (empty) node of the linked list. previous_node (Context | None): The context node from which this one was created, - or ``None`` for the root node. + or `None` for the root node. node_data (Component | CBlock | None): The data associated with this context node, - or ``None`` for the root node. 
+ or `None` for the root node. is_chat_context (bool): Whether this context operates in chat (multi-turn) mode. """ @@ -617,7 +617,7 @@ def from_previous( data (Component | CBlock): The component or content block to associate with the new node. Returns: - ContextT: A new context instance whose ``previous_node`` is ``previous``. + ContextT: A new context instance whose `previous_node` is `previous`. """ assert isinstance(previous, Context), ( "Cannot create a new context from a non-Context object." @@ -677,7 +677,7 @@ def as_list(self, last_n_components: int | None = None) -> list[Component | CBlo Args: last_n_components (int | None): Maximum number of most-recent components to include. - Pass ``None`` to return the full history. + Pass `None` to return the full history. Returns: list[Component | CBlock]: Components in chronological order (oldest first). @@ -708,25 +708,25 @@ def as_list(self, last_n_components: int | None = None) -> list[Component | CBlo def actions_for_available_tools(self) -> list[Component | CBlock] | None: """Provides a list of actions to extract tools from for use during generation. - Returns ``None`` if it is not possible to construct such a list. Can be used to make + Returns `None` if it is not possible to construct such a list. Can be used to make the available tools differ from the tools of all the actions in the context. Can be overridden by subclasses. Returns: list[Component | CBlock] | None: The list of actions whose tools should be made - available during generation, or ``None`` if unavailable. + available during generation, or `None` if unavailable. """ return self.view_for_generation() def last_output(self, check_last_n_components: int = 3) -> ModelOutputThunk | None: - """Returns the most recent ``ModelOutputThunk`` found within the last N context components. + """Returns the most recent `ModelOutputThunk` found within the last N context components. 
Args: check_last_n_components (int): Number of most-recent components to search through. Defaults to 3. Returns: - ModelOutputThunk | None: The most recent output thunk, or ``None`` if none is found + ModelOutputThunk | None: The most recent output thunk, or `None` if none is found within the searched components. """ for c in self.as_list(last_n_components=check_last_n_components)[::-1]: @@ -759,13 +759,13 @@ def last_turn(self) -> ContextTurn | None: @abc.abstractmethod def add(self, c: Component | CBlock) -> Context: - """Returns a new context obtained by appending ``c`` to this context. + """Returns a new context obtained by appending `c` to this context. Args: c (Component | CBlock): The component or content block to add to the context. Returns: - Context: A new context node with ``c`` as its data and this context as its previous node. + Context: A new context node with `c` as its data and this context as its previous node. """ # something along ....from_previous(self, c) ... @@ -774,12 +774,12 @@ def add(self, c: Component | CBlock) -> Context: def view_for_generation(self) -> list[Component | CBlock] | None: """Provides a linear list of context components to use for generation. - Returns ``None`` if it is not possible to construct such a list (e.g., the context + Returns `None` if it is not possible to construct such a list (e.g., the context is in an inconsistent state). Concrete subclasses define the ordering and filtering logic. Returns: list[Component | CBlock] | None: An ordered list of components suitable for passing - to a backend, or ``None`` if generation is not currently possible. + to a backend, or `None` if generation is not currently possible. """ ... @@ -822,7 +822,7 @@ class TemplateRepresentation: obj (Any): The original component object being represented. args (dict): Named arguments extracted from the component for template substitution. 
tools (dict[str, AbstractMelleaTool] | None): Tools available for this representation, - keyed by the tool's function name. Defaults to ``None``. + keyed by the tool's function name. Defaults to `None`. fields (list[Any] | None): An optional ordered list of field values for positional templates. template (str | None): An optional Jinja2 template string to use when rendering. template_order (list[str] | None): An optional ordering hint for template sections/keys. @@ -858,7 +858,7 @@ class GenerateLog: model_options (dict[str, Any] | None): Model configuration options applied to this call. model_output (Any | None): The raw output returned by the backend API. action (Component | CBlock | None): The component or block that triggered the generation. - result (ModelOutputThunk | None): The ``ModelOutputThunk`` produced by this generation call. + result (ModelOutputThunk | None): The `ModelOutputThunk` produced by this generation call. is_final_result (bool | None): Whether this log entry corresponds to the definitive final result. extra (dict[str, Any] | None): Arbitrary extra metadata to attach to the log entry. @@ -883,7 +883,7 @@ class ModelToolCall: Args: name (str): The name of the tool the model requested to call. - func (AbstractMelleaTool): The ``AbstractMelleaTool`` instance that will be invoked. + func (AbstractMelleaTool): The `AbstractMelleaTool` instance that will be invoked. args (Mapping[str, Any]): The keyword arguments the model supplied for the tool call. """ @@ -896,22 +896,22 @@ def call_func(self) -> Any: """Invokes the tool represented by this object and returns the result. Returns: - Any: The value returned by ``func.run(**args)``; the concrete type depends on the tool. + Any: The value returned by `func.run(**args)`; the concrete type depends on the tool. 
""" return self.func.run(**self.args) def blockify(s: str | CBlock | Component) -> CBlock | Component: - """Turn a raw string into a ``CBlock``, leaving ``CBlock`` and ``Component`` objects unchanged. + """Turn a raw string into a `CBlock`, leaving `CBlock` and `Component` objects unchanged. Args: - s: A plain string, ``CBlock``, or ``Component`` to normalise. + s: A plain string, `CBlock`, or `Component` to normalise. Returns: - A ``CBlock`` wrapping ``s`` if it was a string; otherwise ``s`` unchanged. + A `CBlock` wrapping `s` if it was a string; otherwise `s` unchanged. Raises: - Exception: If ``s`` is not a ``str``, ``CBlock``, or ``Component``. + Exception: If `s` is not a `str`, `CBlock`, or `Component`. """ # noinspection PyUnreachableCode match s: @@ -926,14 +926,14 @@ def blockify(s: str | CBlock | Component) -> CBlock | Component: def get_images_from_component(c: Component) -> None | list[ImageBlock]: - """Return the images attached to a ``Component``, or ``None`` if absent or empty. + """Return the images attached to a `Component`, or `None` if absent or empty. Args: - c: The ``Component`` whose ``images`` attribute is inspected. + c: The `Component` whose `images` attribute is inspected. Returns: - A non-empty list of ``ImageBlock`` objects if the component has an - ``images`` attribute with at least one element; ``None`` otherwise. + A non-empty list of `ImageBlock` objects if the component has an + `images` attribute with at least one element; `None` otherwise. """ if hasattr(c, "images"): imgs = c.images # type: ignore diff --git a/mellea/core/formatter.py b/mellea/core/formatter.py index faf53d5fa..3db715530 100644 --- a/mellea/core/formatter.py +++ b/mellea/core/formatter.py @@ -1,9 +1,9 @@ -"""Abstract ``Formatter`` interface for rendering components to strings. +"""Abstract `Formatter` interface for rendering components to strings. -A ``Formatter`` converts ``Component`` and ``CBlock`` objects into the text strings -fed to language model prompts. 
The single abstract method ``print`` encapsulates this -rendering contract; concrete subclasses such as ``ChatFormatter`` and -``TemplateFormatter`` extend it with chat-message and Jinja2-template rendering +A `Formatter` converts `Component` and `CBlock` objects into the text strings +fed to language model prompts. The single abstract method `print` encapsulates this +rendering contract; concrete subclasses such as `ChatFormatter` and +`TemplateFormatter` extend it with chat-message and Jinja2-template rendering respectively. """ @@ -17,12 +17,12 @@ class Formatter(abc.ABC): @abc.abstractmethod def print(self, c: Component | CBlock) -> str: - """Renders a ``Component`` or ``CBlock`` into a string suitable for use as model input. + """Renders a `Component` or `CBlock` into a string suitable for use as model input. Args: c (Component | CBlock): The component or content block to render. Returns: - str: The rendered string representation of ``c``. + str: The rendered string representation of `c`. """ ... diff --git a/mellea/core/requirement.py b/mellea/core/requirement.py index ca780037d..b53b596d8 100644 --- a/mellea/core/requirement.py +++ b/mellea/core/requirement.py @@ -1,10 +1,10 @@ -"""``Requirement`` interface for constrained and validated generation. +"""`Requirement` interface for constrained and validated generation. -A ``Requirement`` pairs a human-readable description with a validation function that -inspects a ``Context`` (and optionally a backend) to determine whether a model output -meets a constraint. ``ValidationResult`` carries the pass/fail verdict along with an -optional reason, score, and the ``ModelOutputThunk`` produced during validation. -Helper factories such as ``default_output_to_bool`` make it easy to build requirements +A `Requirement` pairs a human-readable description with a validation function that +inspects a `Context` (and optionally a backend) to determine whether a model output +meets a constraint. 
`ValidationResult` carries the pass/fail verdict along with an +optional reason, score, and the `ModelOutputThunk` produced during validation. +Helper factories such as `default_output_to_bool` make it easy to build requirements without boilerplate. """ @@ -23,7 +23,7 @@ class ValidationResult: result (bool): Boolean indicating whether the requirement passed. reason (str | None): Optional human-readable explanation for the verdict. score (float | None): Optional numeric score returned by the validator. - thunk (ModelOutputThunk | None): The ``ModelOutputThunk`` produced during LLM-as-a-Judge validation, if applicable. + thunk (ModelOutputThunk | None): The `ModelOutputThunk` produced during LLM-as-a-Judge validation, if applicable. context (Context | None): The context associated with the validation backend call, if applicable. """ @@ -68,7 +68,7 @@ def as_bool(self) -> bool: """Return a boolean value based on the validation result. Returns: - bool: ``True`` if the requirement passed, ``False`` otherwise. + bool: `True` if the requirement passed, `False` otherwise. """ return self._result @@ -84,10 +84,10 @@ def default_output_to_bool(x: CBlock | str) -> bool: also checks if any of the words in the output are "yes" (case-insensitive). Args: - x: The model output to evaluate, as a ``CBlock`` or plain string. + x: The model output to evaluate, as a `CBlock` or plain string. Returns: - ``True`` if the output indicates a "yes" answer, ``False`` otherwise. + `True` if the output indicates a "yes" answer, `False` otherwise. """ output = str(x) @@ -106,12 +106,12 @@ class Requirement(Component[str]): Args: description (str | None): A natural-language description of the requirement. Sometimes included in - ``Instruction`` prompts; use ``check_only=True`` to suppress this. + `Instruction` prompts; use `check_only=True` to suppress this. validation_fn (Callable[[Context], ValidationResult] | None): If provided, this function is executed - instead of LLM-as-a-Judge. 
The ``bool()`` of its return value defines pass/fail. + instead of LLM-as-a-Judge. The `bool()` of its return value defines pass/fail. output_to_bool (Callable[[CBlock | str], bool] | None): Translates LLM-as-a-Judge output to a boolean. Defaults to a "yes"-detection heuristic. - check_only (bool): When ``True``, the requirement description is excluded from ``Instruction`` prompts. + check_only (bool): When `True`, the requirement description is excluded from `Instruction` prompts. Attributes: description (str | None): A natural-language description of the requirement. @@ -119,7 +119,7 @@ class Requirement(Component[str]): output into a boolean pass/fail result. validation_fn (Callable[[Context], ValidationResult] | None): Optional custom validation function that bypasses the LLM-as-a-Judge strategy entirely. - check_only (bool): When ``True``, the requirement description is excluded from ``Instruction`` + check_only (bool): When `True`, the requirement description is excluded from `Instruction` prompts to avoid influencing model output. """ @@ -150,12 +150,12 @@ async def validate( ) -> ValidationResult: """Chooses the appropriate validation strategy and applies it to the given context. - Uses ``validation_fn`` if one was provided, otherwise falls back to LLM-as-a-Judge + Uses `validation_fn` if one was provided, otherwise falls back to LLM-as-a-Judge by generating a judgement response with the backend. Args: backend (Backend): The inference backend used when the LLM-as-a-Judge strategy is selected. - ctx (Context): The context to validate, which must contain a ``ModelOutputThunk`` as its last output. + ctx (Context): The context to validate, which must contain a `ModelOutputThunk` as its last output. format (type[BaseModelSubclass] | None): Optional structured output format for the judgement call. model_options (dict | None): Optional model options to pass to the backend during the judgement call. 
@@ -194,14 +194,14 @@ def parts(self) -> list[Component | CBlock]: return [] def format_for_llm(self) -> TemplateRepresentation | str: - """Returns a ``TemplateRepresentation`` for LLM-as-a-Judge evaluation of this requirement. + """Returns a `TemplateRepresentation` for LLM-as-a-Judge evaluation of this requirement. - Populates the template with the requirement's ``description`` and the stored model - ``_output``. Must only be called from within a ``validate`` call for this same requirement, - after ``_output`` has been set. + Populates the template with the requirement's `description` and the stored model + `_output`. Must only be called from within a `validate` call for this same requirement, + after `_output` has been set. Returns: - TemplateRepresentation | str: A ``TemplateRepresentation`` containing the description + TemplateRepresentation | str: A `TemplateRepresentation` containing the description and the model output to be judged. """ assert self._output is not None, ( diff --git a/mellea/core/sampling.py b/mellea/core/sampling.py index 0bf75badf..f6a245b25 100644 --- a/mellea/core/sampling.py +++ b/mellea/core/sampling.py @@ -1,8 +1,8 @@ """Abstract interfaces for sampling strategies and their results. -``SamplingStrategy`` defines the contract for all sampling algorithms: an async -``sample`` method that takes an action, context, backend, and requirements, and -returns a ``SamplingResult``. ``SamplingResult`` records the chosen generation +`SamplingStrategy` defines the contract for all sampling algorithms: an async +`sample` method that takes an action, context, backend, and requirements, and +returns a `SamplingResult`. `SamplingResult` records the chosen generation alongside the full history of intermediate samples, their validation outcomes, and associated contexts — enabling detailed post-hoc inspection of the sampling process. @@ -20,7 +20,7 @@ class SamplingResult(CBlock, Generic[S]): """Stores the results from a sampling operation. 
This includes successful and failed samplings. Args: - result_index (int): Index into ``sample_generations`` identifying the chosen final output. + result_index (int): Index into `sample_generations` identifying the chosen final output. success (bool): Whether the sampling operation produced a passing result. sample_generations (list[ModelOutputThunk[S]] | None): All output thunks generated during sampling. sample_validations (list[list[tuple[Requirement, ValidationResult]]] | None): Per-generation validation @@ -29,16 +29,16 @@ class SamplingResult(CBlock, Generic[S]): sample_contexts (list[Context] | None): The contexts associated with each generation. Attributes: - result_index (int): Index into ``sample_generations`` identifying the chosen final output. + result_index (int): Index into `sample_generations` identifying the chosen final output. success (bool): Whether the sampling operation produced a passing result. sample_generations (list[ModelOutputThunk[S]]): All output thunks generated during - sampling; always a list (``None`` input is normalised to ``[]``). + sampling; always a list (`None` input is normalised to `[]`). sample_validations (list[list[tuple[Requirement, ValidationResult]]]): Per-generation - validation results; always a list (``None`` input is normalised to ``[]``). + validation results; always a list (`None` input is normalised to `[]`). sample_actions (list[Component]): The actions used to produce each generation; - always a list (``None`` input is normalised to ``[]``). + always a list (`None` input is normalised to `[]`). sample_contexts (list[Context]): The contexts associated with each generation; - always a list (``None`` input is normalised to ``[]``). + always a list (`None` input is normalised to `[]`). """ def __init__( diff --git a/mellea/core/utils.py b/mellea/core/utils.py index 91d87ee39..dbcce31c2 100644 --- a/mellea/core/utils.py +++ b/mellea/core/utils.py @@ -1,9 +1,9 @@ """Logging utilities for the mellea core library. 
-Provides ``FancyLogger``, a singleton logger with colour-coded console output and -an optional REST handler (``RESTHandler``) that forwards log records to a local -``/api/receive`` endpoint when the ``FLOG`` environment variable is set. All -internal mellea modules obtain their logger via ``FancyLogger.get_logger()``. +Provides `FancyLogger`, a singleton logger with colour-coded console output and +an optional REST handler (`RESTHandler`) that forwards log records to a local +`/api/receive` endpoint when the `FLOG` environment variable is set. All +internal mellea modules obtain their logger via `FancyLogger.get_logger()`. """ import json @@ -17,15 +17,15 @@ class RESTHandler(logging.Handler): """Logging handler that forwards records to a local REST endpoint. - Sends log records as JSON to ``/api/receive`` when the ``FLOG`` environment + Sends log records as JSON to `/api/receive` when the `FLOG` environment variable is set. Failures are silently suppressed to avoid disrupting the application. Args: api_url (str): The URL of the REST endpoint that receives log records. - method (str): HTTP method to use when sending records (default ``"POST"``). + method (str): HTTP method to use when sending records (default `"POST"`). headers (dict | None): HTTP headers to send; defaults to - ``{"Content-Type": "application/json"}`` when ``None``. + `{"Content-Type": "application/json"}` when `None`. """ def __init__( @@ -38,7 +38,7 @@ def __init__( self.headers = headers or {"Content-Type": "application/json"} def emit(self, record: logging.LogRecord) -> None: - """Forwards a log record to the REST endpoint when the ``FLOG`` environment variable is set. + """Forwards a log record to the REST endpoint when the `FLOG` environment variable is set. Silently suppresses any network or HTTP errors to avoid disrupting the application. 
@@ -137,18 +137,18 @@ def format(self, record: logging.LogRecord) -> str: class FancyLogger: """Singleton logger with colour-coded console output and optional REST forwarding. - Obtain the shared logger instance via ``FancyLogger.get_logger()``. Log level - defaults to ``INFO`` but can be raised to ``DEBUG`` by setting the ``DEBUG`` - environment variable. When the ``FLOG`` environment variable is set, records are - also forwarded to a local ``/api/receive`` REST endpoint via ``RESTHandler``. + Obtain the shared logger instance via `FancyLogger.get_logger()`. Log level + defaults to `INFO` but can be raised to `DEBUG` by setting the `DEBUG` + environment variable. When the `FLOG` environment variable is set, records are + also forwarded to a local `/api/receive` REST endpoint via `RESTHandler`. Attributes: - logger (logging.Logger | None): The shared ``logging.Logger`` instance; ``None`` until first call to ``get_logger()``. + logger (logging.Logger | None): The shared `logging.Logger` instance; `None` until first call to `get_logger()`. CRITICAL (int): Numeric level for critical log messages (50). - FATAL (int): Alias for ``CRITICAL`` (50). + FATAL (int): Alias for `CRITICAL` (50). ERROR (int): Numeric level for error log messages (40). WARNING (int): Numeric level for warning log messages (30). - WARN (int): Alias for ``WARNING`` (30). + WARN (int): Alias for `WARNING` (30). INFO (int): Numeric level for informational log messages (20). DEBUG (int): Numeric level for debug log messages (10). NOTSET (int): Numeric level meaning no level is set (0). diff --git a/mellea/formatters/__init__.py b/mellea/formatters/__init__.py index c7489dc31..bcbdf96bf 100644 --- a/mellea/formatters/__init__.py +++ b/mellea/formatters/__init__.py @@ -1,11 +1,11 @@ """Formatters for converting components into model-ready prompts. -Formatters translate ``Component`` objects into the prompt strings or chat message -lists that inference backends consume. 
This package exports the abstract ``Formatter`` -interface and two concrete implementations: ``ChatFormatter``, which converts -components into role-labelled chat messages, and ``TemplateFormatter``, which renders +Formatters translate `Component` objects into the prompt strings or chat message +lists that inference backends consume. This package exports the abstract `Formatter` +interface and two concrete implementations: `ChatFormatter`, which converts +components into role-labelled chat messages, and `TemplateFormatter`, which renders them through Jinja2 templates. Pass a formatter when constructing a -``FormatterBackend`` for your chosen model. +`FormatterBackend` for your chosen model. """ # Import from core for ergonomics. diff --git a/mellea/formatters/chat_formatter.py b/mellea/formatters/chat_formatter.py index 6ccfced05..88a291477 100644 --- a/mellea/formatters/chat_formatter.py +++ b/mellea/formatters/chat_formatter.py @@ -1,9 +1,9 @@ -"""``ChatFormatter`` for converting context histories to chat-message lists. +"""`ChatFormatter` for converting context histories to chat-message lists. -``ChatFormatter`` is the standard formatter used by mellea's legacy backends. Its -``to_chat_messages`` method linearises a sequence of ``Component`` and ``CBlock`` -objects into ``Message`` objects with ``user``, ``assistant``, or ``tool`` roles, -handling ``ModelOutputThunk`` responses, image attachments, and parsed structured +`ChatFormatter` is the standard formatter used by mellea's legacy backends. Its +`to_chat_messages` method linearises a sequence of `Component` and `CBlock` +objects into `Message` objects with `user`, `assistant`, or `tool` roles, +handling `ModelOutputThunk` responses, image attachments, and parsed structured outputs. Concrete backends call this formatter when preparing input for a chat completion endpoint. 
""" @@ -25,9 +25,9 @@ def to_chat_messages(self, cs: list[Component | CBlock]) -> list[Message]: """Convert a linearized chat history into a list of chat messages. Iterates over each element in the context history and converts it to a - ``Message`` with an appropriate role. ``ModelOutputThunk`` instances are - treated as assistant responses, while all other ``Component`` and - ``CBlock`` objects default to the ``user`` role. Image attachments and + `Message` with an appropriate role. `ModelOutputThunk` instances are + treated as assistant responses, while all other `Component` and + `CBlock` objects default to the `user` role. Image attachments and parsed structured outputs are handled transparently. Args: @@ -35,7 +35,7 @@ def to_chat_messages(self, cs: list[Component | CBlock]) -> list[Message]: components and code blocks to convert. Returns: - list[Message]: A list of ``Message`` objects ready for submission to + list[Message]: A list of `Message` objects ready for submission to a chat completion endpoint. """ diff --git a/mellea/formatters/granite/base/io.py b/mellea/formatters/granite/base/io.py index ece5ecce0..912e1f3a3 100644 --- a/mellea/formatters/granite/base/io.py +++ b/mellea/formatters/granite/base/io.py @@ -35,10 +35,10 @@ def transform( Args: chat_completion (ChatCompletion): Structured representation of the inputs to the chat completion request. - add_generation_prompt (bool): If ``True``, the returned prompt string will + add_generation_prompt (bool): If `True`, the returned prompt string will contain a prefix of the next assistant response for use as a prompt to a generation request. Otherwise, the prompt will only contain the messages - and documents in ``chat_completion``. Defaults to ``True``. + and documents in `chat_completion`. Defaults to `True`. 
Returns: str: String that can be passed to the model's tokenizer to create a prompt @@ -68,8 +68,8 @@ def transform( model_output (str): String output of the generation request, potentially incomplete if it was a streaming request. chat_completion (ChatCompletion | None): The chat completion request that - produced ``model_output``. Parameters of the request can determine how - the output should be decoded. Defaults to ``None``. + produced `model_output`. Parameters of the request can determine how + the output should be decoded. Defaults to `None`. Returns: AssistantMessage: The parsed output so far, as an instance of @@ -103,7 +103,7 @@ def transform( ChatCompletion: Rewritten copy of the original chat completion request. Raises: - TypeError: If ``chat_completion`` is not a :class:`ChatCompletion` object, + TypeError: If `chat_completion` is not a :class:`ChatCompletion` object, a JSON string, or a dictionary. """ if isinstance(chat_completion, str): @@ -153,16 +153,16 @@ def transform( :class:`ChatCompletionResponse` dataclass, a raw dictionary, or another Pydantic model. chat_completion (ChatCompletion | None): The original chat completion - request that produced ``chat_completion_response``. Required by + request that produced `chat_completion_response`. Required by some implementations to decode references back to the original - request. Defaults to ``None``. + request. Defaults to `None`. Returns: ChatCompletionResponse: Post-processed copy of the chat completion response with model-specific transformations applied. Raises: - TypeError: If ``chat_completion_response`` is not a supported type. + TypeError: If `chat_completion_response` is not a supported type. """ # Convert from over-the-wire format if necessary if isinstance(chat_completion_response, dict): @@ -206,9 +206,9 @@ def retrieve(self, query: str, top_k: int = 10) -> list[Document]: Args: query (str): Query string to use for lookup. - top_k (int): Maximum number of results to return. 
Defaults to ``10``. + top_k (int): Maximum number of results to return. Defaults to `10`. Returns: list[Document]: List of the top-k matching :class:`Document` objects, - each with fields such as ``text``, ``title``, and ``doc_id``. + each with fields such as `text`, `title`, and `doc_id`. """ diff --git a/mellea/formatters/granite/base/optional.py b/mellea/formatters/granite/base/optional.py index 705f1bd2f..5ab91167c 100644 --- a/mellea/formatters/granite/base/optional.py +++ b/mellea/formatters/granite/base/optional.py @@ -2,9 +2,9 @@ """Context-manager helpers for gracefully handling optional import dependencies. -Provides ``import_optional``, a context manager that catches ``ImportError`` and -re-raises it with a human-readable install hint (e.g. ``pip install [extra]``), -and ``nltk_check``, a variant tailored to NLTK data-download errors. Used by Granite +Provides `import_optional`, a context manager that catches `ImportError` and +re-raises it with a human-readable install hint (e.g. `pip install [extra]`), +and `nltk_check`, a variant tailored to NLTK data-download errors. Used by Granite formatter modules that have optional third-party dependencies. """ @@ -27,7 +27,7 @@ def import_optional(extra_name: str): Args: extra_name: Package extra to suggest in the install hint - (e.g. ``pip install granite_io[extra_name]``). + (e.g. `pip install granite_io[extra_name]`). """ try: yield @@ -49,7 +49,7 @@ def nltk_check(feature_name: str): feature_name: Name of the feature that requires NLTK, used in the error message. Raises: - ImportError: If the ``nltk`` package is not installed, re-raised with + ImportError: If the `nltk` package is not installed, re-raised with a descriptive message and installation instructions. 
""" try: diff --git a/mellea/formatters/granite/base/types.py b/mellea/formatters/granite/base/types.py index 8c747c012..928792a88 100644 --- a/mellea/formatters/granite/base/types.py +++ b/mellea/formatters/granite/base/types.py @@ -2,10 +2,10 @@ """Common Pydantic types shared across the Granite formatter package. -Defines reusable Pydantic models and mixins, including ``NoDefaultsMixin`` (which +Defines reusable Pydantic models and mixins, including `NoDefaultsMixin` (which suppresses unset default fields from serialized JSON output) and message/request -types for Granite model chat completions (``ChatMessage``, ``ChatCompletion``, -``VLLMExtraBody``, ``ChatCompletionLogProbs``, and related classes). These types are +types for Granite model chat completions (`ChatMessage`, `ChatCompletion`, +`VLLMExtraBody`, `ChatCompletionLogProbs`, and related classes). These types are consumed internally by the Granite intrinsic formatters. """ @@ -37,7 +37,7 @@ def _workaround_for_design_flaw_in_pydantic(self, nxt): See https://github.com/pydantic/pydantic/issues/4554 for the relevant dismissive comment from the devs. This comment suggests overriding :func:`dict()`, but that method was disabled a year later. Now you need to add a custom serializer method - with a ``@model_serializer`` decorator. + with a `@model_serializer` decorator. See the docs at https://docs.pydantic.dev/latest/api/functional_serializers/ @@ -130,7 +130,7 @@ class UserMessage(_ChatMessageBase): """User message for an IBM Granite model chat completion request. Attributes: - role (str): Always ``"user"``, identifying the message sender. + role (str): Always `"user"`, identifying the message sender. """ role: Literal["user"] = "user" @@ -143,7 +143,7 @@ class DocumentMessage(_ChatMessageBase): completion request. Attributes: - role (str): A string matching the pattern ``"document "``, + role (str): A string matching the pattern `"document "`, identifying this message as a document fragment. 
""" @@ -180,7 +180,7 @@ class AssistantMessage(_ChatMessageBase): completion request. Attributes: - role (str): Always ``"assistant"``, identifying the message sender. + role (str): Always `"assistant"`, identifying the message sender. tool_calls (list[ToolCall] | None): Optional list of tool calls requested by the assistant during this turn. reasoning_content (str | None): Optional chain-of-thought or reasoning @@ -199,7 +199,7 @@ class ToolResultMessage(_ChatMessageBase): request. Attributes: - role (str): Always ``"tool"``, identifying this as a tool-result message. + role (str): Always `"tool"`, identifying this as a tool-result message. tool_call_id (str): The identifier of the tool call this message responds to. """ @@ -211,7 +211,7 @@ class SystemMessage(_ChatMessageBase): """System message for an IBM Granite model chat completion request. Attributes: - role (str): Always ``"system"``, identifying this as a system-level instruction. + role (str): Always `"system"`, identifying this as a system-level instruction. """ role: Literal["system"] = "system" @@ -221,7 +221,7 @@ class DeveloperMessage(_ChatMessageBase): """Developer system message for a chat completion request. Attributes: - role (str): Always ``"developer"``, identifying this as a developer-role message. + role (str): Always `"developer"`, identifying this as a developer-role message. """ role: Literal["developer"] = "developer" @@ -240,7 +240,7 @@ class DeveloperMessage(_ChatMessageBase): class ToolDefinition(pydantic.BaseModel, NoDefaultsMixin): - """An entry in the ``tools`` list in an IBM Granite model chat completion request. + """An entry in the `tools` list in an IBM Granite model chat completion request. Attributes: name (str): The name used to identify and invoke the tool. @@ -281,7 +281,7 @@ class Document(pydantic.BaseModel, NoDefaultsMixin): class ChatTemplateKwargs(pydantic.BaseModel): """Keyword arguments for chat template. 
- Values that can appear in the ``chat_template_kwargs`` portion of a valid chat + Values that can appear in the `chat_template_kwargs` portion of a valid chat completion request for a Granite model. Attributes: @@ -308,8 +308,8 @@ class VLLMExtraBody(pydantic.BaseModel, NoDefaultsMixin): Attributes: documents (list[Document] | None): RAG documents made accessible to the model during generation, if the template supports RAG. - add_generation_prompt (bool): When ``True``, the generation prompt is - appended to the rendered chat template. Defaults to ``True``. + add_generation_prompt (bool): When `True`, the generation prompt is + appended to the rendered chat template. Defaults to `True`. chat_template_kwargs (ChatTemplateKwargs | None): Additional keyword arguments forwarded to the chat template renderer. structured_outputs (dict | None): Optional JSON schema that constrains @@ -396,7 +396,7 @@ def _documents(self) -> list[Document] | None: """Fetch documents attached to chat completion. Convenience method for internal code to fetch documents attached to the - chat completion without having to dig into ``extra_body``. + chat completion without having to dig into `extra_body`. """ if self.extra_body: return self.extra_body.documents @@ -406,7 +406,7 @@ def _chat_template_kwargs(self) -> ChatTemplateKwargs | None: """Fetch chat template arguments. Convenience method for internal code to fetch chat template arguments - without having to dig into ``extra_body``. + without having to dig into `extra_body`. """ if self.extra_body: return self.extra_body.chat_template_kwargs @@ -431,7 +431,7 @@ class GraniteChatCompletion(ChatCompletion): def _validate_vllm_stuff_in_extra_body(self): """Validate non-standard VLLM fields. - Non-standard VLLM fields should be passed via the ``extra_body`` parameter. + Non-standard VLLM fields should be passed via the `extra_body` parameter. 
Make sure the user didn't stuff them into the root, which is currently set up to allow arbitrary additional fields. """ @@ -452,10 +452,10 @@ def _validate_documents_at_top_level(self): """Validate documents at top level. Documents for a Granite model chat completion request should be passed in the - ``documents`` argument at the top level of the ``extra_body`` portion of the + `documents` argument at the top level of the `extra_body` portion of the request. - Detect cases where the documents are hanging off of ``chat_template_kwargs`` + Detect cases where the documents are hanging off of `chat_template_kwargs` and sanitize appropriately. """ if self is None: @@ -528,7 +528,7 @@ class ChatCompletionLogProb(pydantic.BaseModel, NoDefaultsMixin): Attributes: token (str): The decoded token string. logprob (float): The log-probability of the token. Defaults to - ``-9999.0`` when not returned by the server. + `-9999.0` when not returned by the server. bytes (list[int] | None): The UTF-8 byte values of the token, if provided by the server. """ @@ -566,7 +566,7 @@ class ChatCompletionLogProbs(pydantic.BaseModel, NoDefaultsMixin): Attributes: content (list[ChatCompletionLogProbsContent] | None): Per-token - log-probability entries for each generated token, or ``None`` if + log-probability entries for each generated token, or `None` if logprobs were not requested. """ @@ -588,7 +588,7 @@ class ChatCompletionResponseChoice(pydantic.BaseModel, NoDefaultsMixin): logprobs (ChatCompletionLogProbs | None): Token log-probabilities for this choice, if they were requested. finish_reason (str | None): The reason the model stopped generating. - Defaults to ``"stop"`` per the OpenAI specification. + Defaults to `"stop"` per the OpenAI specification. 
""" index: int diff --git a/mellea/formatters/granite/base/util.py b/mellea/formatters/granite/base/util.py index 4d02ef9c6..9dfcf9de0 100644 --- a/mellea/formatters/granite/base/util.py +++ b/mellea/formatters/granite/base/util.py @@ -32,7 +32,7 @@ def import_optional(extra_name: str): Args: extra_name: Package extra to suggest in the install hint - (e.g. ``pip install granite_io[extra_name]``). + (e.g. `pip install granite_io[extra_name]`). """ try: yield @@ -54,7 +54,7 @@ def nltk_check(feature_name: str): feature_name: Name of the feature that requires NLTK, used in the error message. Raises: - ImportError: If the ``nltk`` package is not installed, re-raised with + ImportError: If the `nltk` package is not installed, re-raised with a descriptive message and installation instructions. """ try: @@ -78,7 +78,7 @@ def find_substring_in_text(substring: str, text: str) -> list[dict]: text: The string to search within. Returns: - List of dicts with ``begin_idx`` and ``end_idx`` for each match found. + List of dicts with `begin_idx` and `end_idx` for each match found. """ span_matches = [] @@ -105,18 +105,18 @@ def load_transformers_lora(local_or_remote_path): pass it a LoRA adapter's config, but that auto-loading is very broken as of 8/2025. Workaround powers activate! - Only works if ``transformers`` and ``peft`` are installed. + Only works if `transformers` and `peft` are installed. Args: local_or_remote_path: Local directory path of the LoRA adapter. Returns: - Tuple of ``(model, tokenizer)`` where ``model`` is the loaded LoRA model and - ``tokenizer`` is the corresponding HuggingFace tokenizer. + Tuple of `(model, tokenizer)` where `model` is the loaded LoRA model and + `tokenizer` is the corresponding HuggingFace tokenizer. Raises: - ImportError: If ``peft`` or ``transformers`` packages are not installed. - NotImplementedError: If ``local_or_remote_path`` does not exist locally + ImportError: If `peft` or `transformers` packages are not installed. 
+ NotImplementedError: If `local_or_remote_path` does not exist locally (remote loading from the Hugging Face Hub is not yet implemented). """ with import_optional("peft"): @@ -141,7 +141,7 @@ def chat_completion_request_to_transformers_inputs( """Translate an OpenAI-style chat completion request. Translate an OpenAI-style chat completion request into an input for a Transformers - ``generate()`` call. + `generate()` call. Args: request: Request as parsed JSON or equivalent dataclass. @@ -152,17 +152,17 @@ def chat_completion_request_to_transformers_inputs( constrained_decoding_prefix: Optional generation prefix to append to the prompt. Returns: - Tuple of ``(generate_input, other_input)`` where ``generate_input`` contains - kwargs to pass directly to ``generate()`` and ``other_input`` contains - additional parameters for ``generate_with_transformers``. + Tuple of `(generate_input, other_input)` where `generate_input` contains + kwargs to pass directly to `generate()` and `other_input` contains + additional parameters for `generate_with_transformers`. Raises: - ImportError: If ``torch``, ``transformers``, or ``xgrammar`` packages + ImportError: If `torch`, `transformers`, or `xgrammar` packages are not installed (the latter only when constrained decoding is used). - TypeError: If ``tokenizer.apply_chat_template()`` returns an unexpected type. + TypeError: If `tokenizer.apply_chat_template()` returns an unexpected type. ValueError: If padding or end-of-sequence token IDs cannot be determined from the tokenizer, or if a constrained-decoding request is made - without passing a ``tokenizer`` or ``model`` argument. + without passing a `tokenizer` or `model` argument. """ with import_optional("torch"): # Third Party @@ -329,10 +329,10 @@ def generate_with_transformers( tokenizer: HuggingFace tokenizer for the model, required at several stages of generation. model: Initialized HuggingFace model object. 
- generate_input: Parameters to pass to the ``generate()`` method, usually - produced by ``chat_completion_request_to_transformers_inputs()``. + generate_input: Parameters to pass to the `generate()` method, usually + produced by `chat_completion_request_to_transformers_inputs()`. other_input: Additional kwargs produced by - ``chat_completion_request_to_transformers_inputs()`` for aspects of the + `chat_completion_request_to_transformers_inputs()` for aspects of the original request that Transformers APIs don't handle natively. Returns: diff --git a/mellea/formatters/granite/granite3/granite32/input.py b/mellea/formatters/granite/granite3/granite32/input.py index a5333cf9f..a23a9a6c3 100644 --- a/mellea/formatters/granite/granite3/granite32/input.py +++ b/mellea/formatters/granite/granite3/granite32/input.py @@ -32,7 +32,7 @@ class Granite32InputProcessor(Granite3InputProcessor): This input processor is based on the Jinja template that was used during supervised fine tuning of these models. This template is as follows: - ``` + `` {%- if messages[0]['role'] == 'system' %} {%- set system_message = messages[0]['content'] %} {%- set loop_messages = messages[1:] %} @@ -122,7 +122,7 @@ class Granite32InputProcessor(Granite3InputProcessor): {{- '<|end_of_role|>' }} {%- endif %} {%- endfor %} - ``` + `` """ def _build_default_system_message( @@ -227,7 +227,7 @@ def sanitize(cls, chat_completion, parts="all"): Args: chat_completion: The chat completion request to sanitize. parts (str): Which parts of the chat completion to sanitize; - defaults to ``"all"``. + defaults to `"all"`. Returns: The sanitized chat completion with all Granite 3.2 special tokens @@ -243,17 +243,17 @@ def transform( Args: chat_completion (ChatCompletion): The structured chat completion request to convert into a tokenizer-ready prompt string. 
- add_generation_prompt (bool): When ``True``, appends the assistant role + add_generation_prompt (bool): When `True`, appends the assistant role header to the end of the prompt to trigger generation. Defaults to - ``True``. + `True`. Returns: str: The prompt string formatted for the Granite 3.2 model tokenizer. Raises: ValueError: If conflicting options are specified, such as enabling - ``thinking`` mode together with documents, tools, or a custom - system message; or enabling ``citations`` or ``hallucinations`` + `thinking` mode together with documents, tools, or a custom + system message; or enabling `citations` or `hallucinations` with a custom system message. """ chat_completion = Granite32ChatCompletion.model_validate( diff --git a/mellea/formatters/granite/granite3/granite32/output.py b/mellea/formatters/granite/granite3/granite32/output.py index e3c4b44f0..4d24cc1ab 100644 --- a/mellea/formatters/granite/granite3/granite32/output.py +++ b/mellea/formatters/granite/granite3/granite32/output.py @@ -6,7 +6,7 @@ output. The input to the parser is assumed to be as follows: - ``` + ``` response_text # Citations: @@ -14,7 +14,7 @@ # Hallucinations: hallucinations_text - ``` + ``` The output from the lowest level of the parser is a dictionary as follows: @@ -24,7 +24,7 @@ * "response": Model response text without the above constituents This dict is further refined into dataclasses before being returned as an extended -``AssistantMessage``. +`AssistantMessage`. """ # Standard @@ -75,7 +75,7 @@ def _parse_citations_text(citations_text: str) -> list[dict]: Given the citations text output by model under the "# Citations:" section, extract the citation info as an array of the form: - ``` + ``` [ { "citation_id": "Citation ID output by model", @@ -84,7 +84,7 @@ def _parse_citations_text(citations_text: str) -> list[dict]: }, ... 
] - ``` + `` """ citations: list[dict] = [] @@ -619,9 +619,9 @@ def transform( Args: model_output (str): Raw text output from the Granite 3.2 model. chat_completion (ChatCompletion | None): The original chat completion - request that produced ``model_output``. Used to determine which + request that produced `model_output`. Used to determine which output features (thinking, tools, citations, hallucinations) to - parse. Defaults to ``None``. + parse. Defaults to `None`. Returns: AssistantMessage: A :class:`Granite3AssistantMessage` containing the diff --git a/mellea/formatters/granite/granite3/granite33/input.py b/mellea/formatters/granite/granite3/granite33/input.py index 2bb8f899d..86ff30e54 100644 --- a/mellea/formatters/granite/granite3/granite33/input.py +++ b/mellea/formatters/granite/granite3/granite33/input.py @@ -32,9 +32,9 @@ class Granite33InputProcessor(Granite3InputProcessor): This input processor is based on the Jinja template from tokenizer_config.json. - ``` + `` "chat_template": "{# Alias tools -> available_tools #}\n{%- if tools and not available_tools -%}\n {%- set available_tools = tools -%}\n{%- endif -%}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n {%- else %}\n {%- set system_message = \" Knowledge Cutoff Date: April 2024.\n Today's Date: \" + strftime_now('%B %d, %Y') + \". You are Granite, developed by IBM.\" %}\n {%- if available_tools and documents %}\n {%- set system_message = system_message + \" You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request. \nWrite the response to the user's input by strictly aligning with the facts in the provided documents. 
If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data.\" %}\n {%- elif available_tools %}\n {%- set system_message = system_message + \" You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.\" %}\n {%- elif documents %}\n {%- set system_message = system_message + \" Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data.\" %}\n {%- elif thinking %}\n {%- set system_message = system_message + \" You are a helpful AI assistant.\nRespond to every user query in a comprehensive and detailed way. You can write down your thoughts and reasoning process before responding. In the thought process, engage in a comprehensive cycle of analysis, summarization, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. In the response section, based on various attempts, explorations, and reflections from the thoughts section, systematically present the final solution that you deem correct. The response should summarize the thought process. 
Write your thoughts between and write your response between for each user query.\" %}\n {%- else %}\n {%- set system_message = system_message + \" You are a helpful AI assistant.\" %}\n {%- endif %}\n {%- if 'citations' in controls and documents %}\n {%- set system_message = system_message + ' \nUse the symbols <|start_of_cite|> and <|end_of_cite|> to indicate when a fact comes from a document in the search result, e.g <|start_of_cite|> {document_id: 1}my fact <|end_of_cite|> for a fact from document 1. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}\n {%- endif %}\n {%- if 'hallucinations' in controls and documents %}\n {%- set system_message = system_message + ' \nFinally, after the response is written, include a numbered list of sentences from the response with a corresponding risk value that are hallucinated and not based in the documents.' %}\n {%- endif %}\n {%- set loop_messages = messages %}\n {%- endif %}\n {{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>\n' }}\n {%- if available_tools %}\n {{- '<|start_of_role|>available_tools<|end_of_role|>' }}\n {{- available_tools | tojson(indent=4) }}\n {{- '<|end_of_text|>\n' }}\n {%- endif %}\n {%- if documents %}\n {%- for document in documents %}\n {{- '<|start_of_role|>document {\"document_id\": \"' + document['doc_id'] | string + '\"}<|end_of_role|>\n' }}\n {{- document['text'] }}\n {{- '<|end_of_text|>\n' }}\n {%- endfor %}\n {%- endif %}\n {%- for message in loop_messages %}\n {{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant' }}\n {%- if controls %}\n {{- ' ' + controls | tojson()}}\n {%- endif %}\n {{- '<|end_of_role|>' }}\n {%- endif %}\n {%- endfor %}", - ``` + `` """ def _build_default_system_message( @@ -147,7 +147,7 @@ def sanitize(cls, chat_completion, parts="all"): Args: chat_completion: 
The chat completion request to sanitize. parts (str): Which parts of the chat completion to sanitize; - defaults to ``"all"``. + defaults to `"all"`. Returns: The sanitized chat completion with all Granite 3.3 special tokens @@ -165,17 +165,17 @@ def transform( Args: chat_completion (ChatCompletion): The structured chat completion request to convert into a tokenizer-ready prompt string. - add_generation_prompt (bool): When ``True``, appends the assistant role + add_generation_prompt (bool): When `True`, appends the assistant role header to the end of the prompt to trigger generation. Defaults to - ``True``. + `True`. Returns: str: The prompt string formatted for the Granite 3.3 model tokenizer. Raises: ValueError: If conflicting options are specified, such as enabling - ``thinking`` mode together with documents, tools, or a custom - system message; or enabling ``citations`` or ``hallucinations`` + `thinking` mode together with documents, tools, or a custom + system message; or enabling `citations` or `hallucinations` with a custom system message. """ # Downcast to a Granite-specific request type with possible additional fields. diff --git a/mellea/formatters/granite/granite3/granite33/output.py b/mellea/formatters/granite/granite3/granite33/output.py index 34893de22..9aa084621 100644 --- a/mellea/formatters/granite/granite3/granite33/output.py +++ b/mellea/formatters/granite/granite3/granite33/output.py @@ -10,7 +10,7 @@ * "response": Model response text without the above constituents This dict is further refined into dataclasses before being returned as an extended -``AssistantMessage``. +`AssistantMessage`. 
""" # Standard @@ -63,7 +63,7 @@ def _parse_citations_text(citations_text: str) -> list[dict]: Given the citations text output by model under the "# Citations:" section, extract the citation info as an array of the form: - ``` + ``` [ { "citation_id": "Citation ID output by model", @@ -72,7 +72,7 @@ def _parse_citations_text(citations_text: str) -> list[dict]: }, ... ] - ``` + ``` """ citations: list[dict] = [] @@ -528,9 +528,9 @@ def transform( Args: model_output (str): Raw text output from the Granite 3.3 model. chat_completion (ChatCompletion | None): The original chat completion - request that produced ``model_output``. Used to determine which + request that produced `model_output`. Used to determine which output features (thinking, tools, citations, hallucinations) to - parse. Defaults to ``None``. + parse. Defaults to `None`. Returns: AssistantMessage: A :class:`Granite3AssistantMessage` containing the diff --git a/mellea/formatters/granite/granite3/input.py b/mellea/formatters/granite/granite3/input.py index 558238f05..eb9b4da7a 100644 --- a/mellea/formatters/granite/granite3/input.py +++ b/mellea/formatters/granite/granite3/input.py @@ -81,7 +81,7 @@ def _message_to_prompt_string(message: UserMessage | AssistantMessage) -> str: def _build_controls_record(chat_completion: Granite3ChatCompletion) -> dict | None: """Build a Granite 3 controls record. - Use the output control flags in ``inputs`` to build a version of the + Use the output control flags in `inputs` to build a version of the undocumented arbitrary JSON data regarding output controls that the Jinja template expected to see in the input for each chat completion request. 
diff --git a/mellea/formatters/granite/granite3/output.py b/mellea/formatters/granite/granite3/output.py index d3c31d96a..c6d2af0d6 100644 --- a/mellea/formatters/granite/granite3/output.py +++ b/mellea/formatters/granite/granite3/output.py @@ -81,7 +81,7 @@ def parse_hallucinations_text(hallucinations_text: str) -> list[dict]: hallucinations_text: Raw text from the model's "# Hallucinations:" section. Returns: - List of dicts, each with ``hallucination_id``, ``risk``, and ``response_text`` keys. + List of dicts, each with `hallucination_id`, `risk`, and `response_text` keys. """ hallucinations = [] @@ -178,14 +178,14 @@ def add_hallucination_response_spans( Args: hallucination_info: Parsed hallucination list as returned by - ``parse_hallucinations_text``. + `parse_hallucinations_text`. response_text_without_citations: Full response text with citation tags removed. remove_citations_from_response_text: Callable that strips citation tags from a substring of the response. Returns: - Deep copy of ``hallucination_info`` with ``response_text``, ``response_begin``, - and ``response_end`` populated for each entry. + Deep copy of `hallucination_info` with `response_text`, `response_begin`, + and `response_end` populated for each entry. """ augmented_hallucination_info = copy.deepcopy(hallucination_info) @@ -243,11 +243,11 @@ def add_citation_context_spans( Args: citation_info: List of citation dicts as produced by the model output parser. - docs: List of source document dicts, each with ``citation_id``, ``doc_id``, - and ``text`` keys. + docs: List of source document dicts, each with `citation_id`, `doc_id`, + and `text` keys. Returns: - Deep copy of ``citation_info`` with ``context_begin`` and ``context_end`` + Deep copy of `citation_info` with `context_begin` and `context_end` populated for each entry. 
""" augmented_citation_info = copy.deepcopy(citation_info) diff --git a/mellea/formatters/granite/granite3/types.py b/mellea/formatters/granite/granite3/types.py index 4182537c9..a41c6db95 100644 --- a/mellea/formatters/granite/granite3/types.py +++ b/mellea/formatters/granite/granite3/types.py @@ -24,13 +24,13 @@ class Hallucination(pydantic.BaseModel): Attributes: hallucination_id (str): Unique identifier for the hallucination entry. - risk (str): Risk level of the hallucination, e.g. ``"low"`` or ``"high"``. + risk (str): Risk level of the hallucination, e.g. `"low"` or `"high"`. reasoning (str | None): Optional model-provided reasoning for why this sentence was flagged. response_text (str): The portion of the response text that is flagged. - response_begin (int): Start character offset of ``response_text`` within + response_begin (int): Start character offset of `response_text` within the full response string. - response_end (int): End character offset (exclusive) of ``response_text`` + response_end (int): End character offset (exclusive) of `response_text` within the full response string. """ @@ -49,15 +49,15 @@ class Citation(pydantic.BaseModel): citation_id (str): Unique identifier assigned to this citation. doc_id (str): Identifier of the source document being cited. context_text (str): Verbatim text from the source document that is cited. - context_begin (int): Start character offset of ``context_text`` within + context_begin (int): Start character offset of `context_text` within the source document. - context_end (int): End character offset (exclusive) of ``context_text`` + context_end (int): End character offset (exclusive) of `context_text` within the source document. response_text (str): The portion of the response text that makes this citation. - response_begin (int): Start character offset of ``response_text`` within + response_begin (int): Start character offset of `response_text` within the response string. 
- response_end (int): End character offset (exclusive) of ``response_text`` + response_end (int): End character offset (exclusive) of `response_text` within the response string. """ @@ -79,14 +79,14 @@ class Granite3Controls(pydantic.BaseModel): and originality style. Attributes: - citations (bool | None): When ``True``, instructs the model to annotate + citations (bool | None): When `True`, instructs the model to annotate factual claims with inline citation markers. - hallucinations (bool | None): When ``True``, instructs the model to + hallucinations (bool | None): When `True`, instructs the model to append a list of sentences that may be hallucinated. - length (str | None): Requested response length; must be ``"short"``, - ``"long"``, or ``None`` for no constraint. + length (str | None): Requested response length; must be `"short"`, + `"long"`, or `None` for no constraint. originality (str | None): Requested response originality style; must be - ``"extractive"``, ``"abstractive"``, or ``None``. + `"extractive"`, `"abstractive"`, or `None`. """ citations: bool | None = None @@ -127,8 +127,8 @@ class Granite3Kwargs(ChatTemplateKwargs, NoDefaultsMixin): controls (Granite3Controls | None): Optional output control flags that enable or configure citations, hallucination detection, response length, and originality style. - thinking (bool): When ``True``, enables chain-of-thought reasoning mode. - Defaults to ``False``. + thinking (bool): When `True`, enables chain-of-thought reasoning mode. + Defaults to `False`. """ controls: Granite3Controls | None = None @@ -158,8 +158,8 @@ def thinking(self) -> bool: """Return whether chain-of-thought thinking mode is enabled. Returns: - bool: ``True`` if the ``thinking`` flag is set in the chat template - kwargs; ``False`` otherwise. + bool: `True` if the `thinking` flag is set in the chat template + kwargs; `False` otherwise. 
""" kwargs = self.extra_body.chat_template_kwargs if self.extra_body else None return bool(kwargs and isinstance(kwargs, Granite3Kwargs) and kwargs.thinking) @@ -170,7 +170,7 @@ def _validate_chat_template_kwargs(cls, extra_body: VLLMExtraBody) -> VLLMExtraB """Validate Granite 3 chat template kwargs and convert to dataclass. Validates kwargs that are specific to Granite 3 chat templates and converts - the ``chat_template_kwargs`` field to a Granite 3-specific dataclass. + the `chat_template_kwargs` field to a Granite 3-specific dataclass. Other arguments are currently passed through without checking. """ diff --git a/mellea/formatters/granite/intrinsics/constants.py b/mellea/formatters/granite/intrinsics/constants.py index f418578cc..f880437fa 100644 --- a/mellea/formatters/granite/intrinsics/constants.py +++ b/mellea/formatters/granite/intrinsics/constants.py @@ -24,7 +24,7 @@ "generative-computing/core-intrinsics-lib", ] """Repositories (aka "models") on Hugging Face Hub that use the old layout of -``//``. +`//`. """ BASE_MODEL_TO_CANONICAL_NAME = { diff --git a/mellea/formatters/granite/intrinsics/input.py b/mellea/formatters/granite/intrinsics/input.py index c578c62d2..a44cf5870 100644 --- a/mellea/formatters/granite/intrinsics/input.py +++ b/mellea/formatters/granite/intrinsics/input.py @@ -24,7 +24,7 @@ def _needs_logprobs(transformations: list | None) -> bool: :param transformations: Contents of the field by the same name in the YAML file :type transformations: list - :return: ``True`` if this intrinsic produces a field for which logprobs need to be + :return: `True` if this intrinsic produces a field for which logprobs need to be enabled for downstream result decoding to succeed. :rtype: bool """ @@ -37,7 +37,7 @@ def sentence_delimiter(tag, sentence_num) -> str: """Return a tag string that identifies the beginning of the indicated sentence. Args: - tag: Tag string prefix, e.g. ``"i"`` or ``"c"``. + tag: Tag string prefix, e.g. `"i"` or `"c"`. 
sentence_num: Zero-based index of the sentence. Returns: @@ -53,7 +53,7 @@ def mark_sentence_boundaries( """Modify input strings by inserting sentence boundary markers. Modify one or more input strings by inserting a tag in the form - ``<[prefix][number]>`` + `<[prefix][number]>` at the location of each sentence boundary. Args: @@ -87,19 +87,19 @@ def move_documents_to_message( Args: chat_completion: A chat completion request as dataclass or parsed JSON. - how: How to serialize the documents; supported values are ``"string"``, - ``"json"``, and ``"roles"``. + how: How to serialize the documents; supported values are `"string"`, + `"json"`, and `"roles"`. Returns: - A copy of ``chat_completion`` with any documents under ``extra_body`` + A copy of `chat_completion` with any documents under `extra_body` moved to the first message. Returned type will be the same as the input type. May return original object if no edits are necessary. Raises: - TypeError: If ``chat_completion`` is not a :class:`ChatCompletion` or - ``dict``. - ValueError: If ``how`` is not one of ``"string"``, ``"json"``, or - ``"roles"``. + TypeError: If `chat_completion` is not a :class:`ChatCompletion` or + `dict`. + ValueError: If `how` is not one of `"string"`, `"json"`, or + `"roles"`. """ if isinstance(chat_completion, ChatCompletion): should_return_dataclass = True @@ -172,9 +172,9 @@ class IntrinsicsRewriter(ChatCompletionRewriter): Args: config_file (str | pathlib.Path | None): Path to the YAML configuration file for the - target intrinsic. Mutually exclusive with ``config_dict``. + target intrinsic. Mutually exclusive with `config_dict`. config_dict (dict | None): Inline configuration dictionary. Mutually exclusive with - ``config_file``. + `config_file`. model_name (str | None): Optional model name used to locate model-specific overrides within the configuration. 
@@ -184,18 +184,18 @@ class IntrinsicsRewriter(ChatCompletionRewriter):
         parameters (dict): Additional parameters (key-value pairs) that this
             rewriter adds to all chat completion requests.
         extra_body_parameters (dict): Extended vLLM-specific parameters that go
-            under the ``extra_body`` element of each request. These are merged
-            with any existing ``extra_body`` content in incoming requests.
+            under the `extra_body` element of each request. These are merged
+            with any existing `extra_body` content in incoming requests.
         instruction (str | None): Optional instruction template. When present,
             a new user message is appended with the formatted instruction.
         sentence_boundaries (dict[str, str] | None): Optional sentence-boundary
-            marking specification, mapping location strings (``"last_message"``
-            or ``"documents"``) to marker prefixes (e.g. ``"c"`` produces
-            ``<c0>``, ``<c1>``, …).
+            marking specification, mapping location strings (`"last_message"`
+            or `"documents"`) to marker prefixes (e.g. `"c"` produces
+            `<c0>`, `<c1>`, …).
         docs_as_message (str | None): Optional specification for moving
-            documents from ``extra_body/documents`` to a user message at the
-            start of the messages list. Value must be ``"string"``, ``"json"``,
-            or ``"roles"``.
+            documents from `extra_body/documents` to a user message at the
+            start of the messages list. Value must be `"string"`, `"json"`,
+            or `"roles"`.
     """
 
     config: dict
@@ -209,8 +209,8 @@ class IntrinsicsRewriter(ChatCompletionRewriter):
     completion requests."""
 
     extra_body_parameters: dict
-    """Extended vLLM-specific parameters that go under the ``extra_body`` element of
-    the parameters field. These parameters need to be merged with any ``extra_body``
+    """Extended vLLM-specific parameters that go under the `extra_body` element of
+    the parameters field. 
These parameters need to be merged with any `extra_body` content that is present in incoming requests.""" instruction: str | None @@ -226,7 +226,7 @@ class IntrinsicsRewriter(ChatCompletionRewriter): docs_as_message: str | None """ - Optional specification for moving documents from ``extra_body/documents`` to a + Optional specification for moving documents from `extra_body/documents` to a user message at the beginning of the messages list. Value specifies how to serialize the documents into the message: "string" or "json". """ diff --git a/mellea/formatters/granite/intrinsics/json_util.py b/mellea/formatters/granite/intrinsics/json_util.py index 573802680..cec4c7c49 100644 --- a/mellea/formatters/granite/intrinsics/json_util.py +++ b/mellea/formatters/granite/intrinsics/json_util.py @@ -2,7 +2,7 @@ """JSON parsing utilities for Granite intrinsic formatters. -Provides a fast, position-aware JSON literal parser (``JsonLiteralWithPosition``) used +Provides a fast, position-aware JSON literal parser (`JsonLiteralWithPosition`) used to extract and re-score tokens inside structured model outputs. The module also defines compiled regular expressions for JSON structural characters, numbers, booleans, and null values that are used throughout the Granite intrinsic formatting pipeline. @@ -68,7 +68,7 @@ def find_string_offsets(json_data: str) -> list[tuple[int, int, str]]: json_data: String containing valid JSON. Returns: - Begin and end offsets of all strings in ``json_data``, including + Begin and end offsets of all strings in `json_data`, including the double quotes. """ result = [] @@ -89,11 +89,11 @@ def non_string_offsets(json_str, compiled_regex, string_begins, string_ends): Args: json_str: Original string of valid JSON data. compiled_regex: Compiled regex for the target token type. - string_begins: Table of string begin offsets within ``json_str``. - string_ends: Table of string end offsets within ``json_str``. 
+ string_begins: Table of string begin offsets within `json_str`. + string_ends: Table of string end offsets within `json_str`. Returns: - List of ``(begin, end, matched_string)`` tuples. + List of `(begin, end, matched_string)` tuples. """ offsets = [] for match in compiled_regex.finditer(json_str): @@ -117,7 +117,7 @@ def tokenize_json(json_str: str): json_str: String representation of valid JSON data. Returns: - List of tuples of ``(begin, end, value, type)``. + List of tuples of `(begin, end, value, type)`. """ string_offsets = find_string_offsets(json_str) string_begins = [s[0] for s in string_offsets] @@ -150,11 +150,11 @@ def reparse_value(tokens, offset) -> tuple[Any, int]: Assumes valid JSON. Args: - tokens: Token stream as produced by ``tokenize_json()``. + tokens: Token stream as produced by `tokenize_json()`. offset: Token offset at which to start parsing. Returns: - Tuple of ``(parsed_value, next_offset)``. + Tuple of `(parsed_value, next_offset)`. Raises: ValueError: If an unexpected delimiter token or unknown token type is @@ -177,19 +177,19 @@ def reparse_value(tokens, offset) -> tuple[Any, int]: def reparse_object(tokens, offset) -> tuple[dict, int]: - """Parse a JSON object from the token stream, starting after the opening ``{``. + """Parse a JSON object from the token stream, starting after the opening `{`. Subroutine called by :func:`reparse_value` when an opening curly brace is - encountered. Consumes tokens until the matching closing ``}`` is found. + encountered. Consumes tokens until the matching closing `}` is found. Args: - tokens: Token stream as produced by ``tokenize_json()``. - offset (int): Token offset immediately after the opening ``{`` delimiter. + tokens: Token stream as produced by `tokenize_json()`. + offset (int): Token offset immediately after the opening `{` delimiter. 
Returns: - tuple[dict, int]: A tuple of ``(parsed_dict, next_offset)`` where - ``parsed_dict`` maps string keys to parsed values (possibly - :class:`JsonLiteralWithPosition` instances) and ``next_offset`` + tuple[dict, int]: A tuple of `(parsed_dict, next_offset)` where + `parsed_dict` maps string keys to parsed values (possibly + :class:`JsonLiteralWithPosition` instances) and `next_offset` is the position of the next unconsumed token. Raises: @@ -231,19 +231,19 @@ def reparse_object(tokens, offset) -> tuple[dict, int]: def reparse_list(tokens, offset) -> tuple[list, int]: - """Parse a JSON array from the token stream, starting after the opening ``[``. + """Parse a JSON array from the token stream, starting after the opening `[`. Subroutine called by :func:`reparse_value` when an opening square bracket is - encountered. Consumes tokens until the matching closing ``]`` is found. + encountered. Consumes tokens until the matching closing `]` is found. Args: - tokens: Token stream as produced by ``tokenize_json()``. - offset (int): Token offset immediately after the opening ``[`` delimiter. + tokens: Token stream as produced by `tokenize_json()`. + offset (int): Token offset immediately after the opening `[` delimiter. Returns: - tuple[list, int]: A tuple of ``(parsed_list, next_offset)`` where - ``parsed_list`` contains the parsed elements (possibly - :class:`JsonLiteralWithPosition` instances) and ``next_offset`` + tuple[list, int]: A tuple of `(parsed_list, next_offset)` where + `parsed_list` contains the parsed elements (possibly + :class:`JsonLiteralWithPosition` instances) and `next_offset` is the position of the next unconsumed token. Raises: @@ -276,8 +276,8 @@ def reparse_json_with_offsets(json_str: str) -> Any: json_str: String known to contain valid JSON data. 
Returns: - Parsed representation of ``json_str``, with literals at the leaf nodes of - the parse tree replaced with ``JsonLiteralWithPosition`` instances containing + Parsed representation of `json_str`, with literals at the leaf nodes of + the parse tree replaced with `JsonLiteralWithPosition` instances containing position information. """ tokens = tokenize_json(json_str) @@ -291,7 +291,7 @@ def scalar_paths(parsed_json) -> list[tuple]: parsed_json: JSON data parsed into native Python objects. Returns: - A list of paths to scalar values within ``parsed_json``, where each + A list of paths to scalar values within `parsed_json`, where each path is expressed as a tuple. The root element of a bare scalar is an empty tuple. """ @@ -315,7 +315,7 @@ def all_paths(parsed_json) -> list[tuple]: parsed_json: JSON data parsed into native Python objects. Returns: - A list of paths to all elements of the parse tree of ``parsed_json``, + A list of paths to all elements of the parse tree of `parsed_json`, where each path is expressed as a tuple. The root element of a bare scalar is an empty tuple. """ @@ -335,13 +335,13 @@ def fetch_path(json_value: Any, path: tuple): Args: json_value: Parsed JSON value. path: A tuple of names/numbers that indicates a path from root to a leaf - or internal node of ``json_value``. + or internal node of `json_value`. Returns: The node at the indicated path. Raises: - TypeError: If ``path`` is not a tuple, if a path element is not a string + TypeError: If `path` is not a tuple, if a path element is not a string or integer, or if an intermediate node is not a dict or list. """ if not isinstance(path, tuple): @@ -375,10 +375,10 @@ def replace_path(json_value: Any, path: tuple, new_value: Any) -> Any: new_value: New value to place at the indicated location. Returns: - The modified input, or ``new_value`` itself if the root was replaced. + The modified input, or `new_value` itself if the root was replaced. 
Raises: - TypeError: If ``path`` is not a tuple, or if any error propagated from + TypeError: If `path` is not a tuple, or if any error propagated from :func:`fetch_path` during path traversal. """ if not isinstance(path, tuple): @@ -395,7 +395,7 @@ def parse_inline_json(json_response: dict) -> dict: """Replace the JSON strings in message contents with parsed JSON. Args: - json_response: Parsed JSON representation of a ``ChatCompletionResponse`` object. + json_response: Parsed JSON representation of a `ChatCompletionResponse` object. Returns: Deep copy of the input with JSON message content strings replaced by parsed @@ -416,12 +416,12 @@ def make_begin_to_token_table(logprobs: ChatCompletionLogProbs | None): """Create a table mapping token begin positions to token indices. Args: - logprobs: The token log probabilities from the chat completion, or ``None`` + logprobs: The token log probabilities from the chat completion, or `None` if the chat completion request did not ask for logprobs. Returns: A dictionary mapping token begin positions to token indices, - or ``None`` if ``logprobs`` is ``None``. + or `None` if `logprobs` is `None`. """ if logprobs is None: return None diff --git a/mellea/formatters/granite/intrinsics/output.py b/mellea/formatters/granite/intrinsics/output.py index aaa3efdae..f01651c7b 100644 --- a/mellea/formatters/granite/intrinsics/output.py +++ b/mellea/formatters/granite/intrinsics/output.py @@ -45,12 +45,12 @@ class TransformationRule(abc.ABC): config (dict): Configuration of the parent output processor, as parsed YAML. input_path_expr (list[str | int | None]): Path expression matching all instances of the field that this rule transforms. Elements can be - strings for object fields, integers for list indices, or ``None`` + strings for object fields, integers for list indices, or `None` for wildcard matches. Attributes: YAML_NAME (str | None): The name used to identify this rule in YAML - configuration files. 
Subclasses must set this to a non-``None`` string. + configuration files. Subclasses must set this to a non-`None` string. """ YAML_NAME: str | None = None @@ -92,7 +92,7 @@ def _matching_paths(self, parsed_json: Any) -> list[tuple]: :param parsed_json: Output of running model results through :func:`json.loads()`, plus applying zero or more transformation rules. - :returns: List of paths within ``parsed_json`` that match this rule's input + :returns: List of paths within `parsed_json` that match this rule's input path spec. """ return [p for p in json_util.all_paths(parsed_json) if self._is_input_path(p)] @@ -101,10 +101,10 @@ def rule_name(self) -> str: """Return the YAML name that identifies this transformation rule. Returns: - str: The value of ``YAML_NAME`` for this rule subclass. + str: The value of `YAML_NAME` for this rule subclass. Raises: - ValueError: If ``YAML_NAME`` has not been set by the subclass. + ValueError: If `YAML_NAME` has not been set by the subclass. """ if self.YAML_NAME is None: raise ValueError(f"Attempted to fetch missing rule name for {type(self)}") @@ -144,13 +144,13 @@ def apply( through :func:`json_util.reparse_json_with_offsets()`, preserving position information on literal values. logprobs (ChatCompletionLogProbs | None): Optional logprobs result - associated with the original model output string, or ``None`` + associated with the original model output string, or `None` if no logprobs were present. chat_completion (ChatCompletion | None): The chat completion request that produced this output. Required by some rules. Returns: - Any: Transformed copy of ``parsed_json`` after applying this rule. + Any: Transformed copy of `parsed_json` after applying this rule. 
""" paths = self._matching_paths(parsed_json) prepare_output = self._prepare( @@ -179,7 +179,7 @@ def _apply_at_path(self, result: Any, path: tuple, prepare_output: dict) -> Any: :param prepare_output: Dictionary of global data that this object's :func:`self._prepare()` method has set aside - :returns: A modified version of ``result``, which may be modified in place or + :returns: A modified version of `result`, which may be modified in place or a fresh copy. """ raise NotImplementedError() @@ -188,7 +188,7 @@ def _apply_at_path(self, result: Any, path: tuple, prepare_output: dict) -> Any: class InPlaceTransformation(TransformationRule): """Base class for TransformationRules that replace values in place in JSON. - Base class for ``TransformationRule``s that replace values in place in the source + Base class for `TransformationRule`s that replace values in place in the source JSON. The values replaced can be a scalar, object, or list. """ @@ -218,7 +218,7 @@ def _transform(self, value: Any, path: tuple, prepare_output: dict) -> Any: class AddFieldsTransformation(TransformationRule): """Base class for TransformationRules that add values to JSON. - Base class for ``TransformationRule``s that add one or more values adjacent to + Base class for `TransformationRule`s that add one or more values adjacent to an existing value in the source JSON. """ @@ -231,7 +231,7 @@ def _apply_at_path(self, result: Any, path: tuple, prepare_output: dict) -> Any: :param prepare_output: Dictionary of global data that this object's :func:`self._prepare()` method has set aside - :returns: A modified version of ``result``, which may be modified in place or + :returns: A modified version of `result`, which may be modified in place or a fresh copy. """ if len(path) == 0: @@ -287,11 +287,11 @@ class TokenToFloat(InPlaceTransformation): instances of the field that this rule transforms. 
categories_to_values (dict[str | int | bool, float] | None): Mapping from categorical labels to floating-point values. Defaults to - ``None``. + `None`. Attributes: YAML_NAME (str): YAML configuration key for this rule; always - ``"likelihood"``. + `"likelihood"`. """ YAML_NAME = "likelihood" @@ -430,7 +430,7 @@ def _desplit_sentences( :param tag: String such as that appears in every sentence boundary marker, e.g. "i" => "" :param first_sentence_num: Number we expect to see in the first sentence boundary - marker in ``target_text``. + marker in `target_text`. :returns: Self-describing dictionary of lists. """ @@ -486,26 +486,26 @@ class DecodeSentences(AddFieldsTransformation): input_path_expr (list[str | int | None]): Path expression matching all instances of the field that this rule transforms. source (str): Name of the location to look for sentences; must be - ``"last_message"`` or ``"documents"``. - output_names (dict): Mapping from output role name (``"begin"``, - ``"end"``, ``"text"``, ``"document_id"``) to the name of the new + `"last_message"` or `"documents"`. + output_names (dict): Mapping from output role name (`"begin"`, + `"end"`, `"text"`, `"document_id"`) to the name of the new field to add in the result JSON. Attributes: YAML_NAME (str): YAML configuration key for this rule; always - ``"decode_sentences"``. + `"decode_sentences"`. begin_name (str | None): Name of the output field that receives the - sentence begin offset; extracted from ``output_names``, or ``None`` + sentence begin offset; extracted from `output_names`, or `None` if not configured. end_name (str | None): Name of the output field that receives the - sentence end offset; extracted from ``output_names``, or ``None`` + sentence end offset; extracted from `output_names`, or `None` if not configured. 
text_name (str | None): Name of the output field that receives the - sentence text; extracted from ``output_names``, or ``None`` if not + sentence text; extracted from `output_names`, or `None` if not configured. document_id_name (str | None): Name of the output field that receives - the document ID (only used when ``source="documents"``); extracted - from ``output_names``, or ``None`` if not configured. + the document ID (only used when `source="documents"`); extracted + from `output_names`, or `None` if not configured. """ YAML_NAME = "decode_sentences" @@ -521,10 +521,10 @@ def __init__( """Initialize DecodeSentences with a source location and output field name mapping. Raises: - ValueError: If ``source`` is not ``"last_message"`` or - ``"documents"``, or if an unexpected key is found in - ``output_names``. - TypeError: If ``output_names`` is not a dict. + ValueError: If `source` is not `"last_message"` or + `"documents"`, or if an unexpected key is found in + `output_names`. + TypeError: If `output_names` is not a dict. """ super().__init__(config, input_path_expr) @@ -681,7 +681,7 @@ class Explode(InPlaceTransformation): Attributes: YAML_NAME (str): YAML configuration key for this rule; always - ``"explode"``. + `"explode"`. target_field (str): Name of the list-valued field within each record to expand. """ @@ -749,7 +749,7 @@ class DropDuplicates(InPlaceTransformation): Attributes: YAML_NAME (str): YAML configuration key for this rule; always - ``"drop_duplicates"``. + `"drop_duplicates"`. target_fields (list): Names of fields used to determine whether two records are considered duplicates. """ @@ -806,7 +806,7 @@ class Project(InPlaceTransformation): Attributes: YAML_NAME (str): YAML configuration key for this rule; always - ``"project"``. + `"project"`. retained_fields (dict): Mapping from original field name to the (possibly renamed) output field name. Initialized from either a list of field names (identity mapping) or an explicit mapping. 
@@ -857,7 +857,7 @@ class Nest(InPlaceTransformation): Attributes: YAML_NAME (str): YAML configuration key for this rule; always - ``"nest"``. + `"nest"`. field_name (str): Name of the single field in the output JSON object that wraps each matching value. """ @@ -896,11 +896,11 @@ class MergeSpans(InPlaceTransformation): end_field (str): Name of the field that holds the end offset of spans. text_field (str | None): Optional field containing covered text strings that should be concatenated when spans are merged. Defaults to - ``None``. + `None`. Attributes: YAML_NAME (str): YAML configuration key for this rule; always - ``"merge_spans"``. + `"merge_spans"`. """ YAML_NAME = "merge_spans" @@ -1063,16 +1063,16 @@ def _transform(self, value, path, prepare_output): def _find_final_channel_header(token_strings: list[str]) -> int | None: - """Find the token index of the final ``<|message|>`` token in a token sequence. + """Find the token index of the final `<|message|>` token in a token sequence. - Find the token index of ``<|message|>`` that ends the last - ``<|channel|> final <|message|>`` header in the token sequence. + Find the token index of `<|message|>` that ends the last + `<|channel|> final <|message|>` header in the token sequence. Matches are done on exact token values so that the single special token - ``<|channel|>`` is never confused with regular tokens that happen to - concatenate to the same string (e.g. ``['<|', 'channel', '|>']``). + `<|channel|>` is never confused with regular tokens that happen to + concatenate to the same string (e.g. `['<|', 'channel', '|>']`). - :returns: Index of the ``<|message|>`` token, or ``None``. + :returns: Index of the `<|message|>` token, or `None`. """ last_match = None i = 0 @@ -1111,11 +1111,11 @@ def _logprobs_workaround( This function walks the logprob token sequence, matching individual tokens (not concatenated strings) to locate the final channel header. 
- This ensures that the single special token ``<|channel|>`` is never + This ensures that the single special token `<|channel|>` is never confused with regular tokens that concatenate to the same string. :param logprobs: Logprobs from a chat completion choice. - :returns: ``(content, trimmed_logprobs)`` or ``None`` if no final channel + :returns: `(content, trimmed_logprobs)` or `None` if no final channel is found. """ if logprobs.content is None: @@ -1164,10 +1164,10 @@ class IntrinsicsResultProcessor(ChatCompletionResultProcessor): Args: config_file (str | pathlib.Path | None): Optional path to a YAML - configuration file. Exactly one of ``config_file`` and - ``config_dict`` must be provided. + configuration file. Exactly one of `config_file` and + `config_dict` must be provided. config_dict (dict | None): Optional pre-parsed YAML configuration dict. - Exactly one of ``config_file`` and ``config_dict`` must be + Exactly one of `config_file` and `config_dict` must be provided. """ diff --git a/mellea/formatters/granite/intrinsics/util.py b/mellea/formatters/granite/intrinsics/util.py index 2e4539d1f..2aea15ad3 100644 --- a/mellea/formatters/granite/intrinsics/util.py +++ b/mellea/formatters/granite/intrinsics/util.py @@ -32,17 +32,17 @@ def make_config_dict( Also parses JSON fields. Args: - config_file: Path to a YAML configuration file. Exactly one of ``config_file`` - and ``config_dict`` must be provided. - config_dict: Pre-parsed configuration dict (from ``yaml.safe_load()``). Exactly - one of ``config_file`` and ``config_dict`` must be provided. + config_file: Path to a YAML configuration file. Exactly one of `config_file` + and `config_dict` must be provided. + config_dict: Pre-parsed configuration dict (from `yaml.safe_load()`). Exactly + one of `config_file` and `config_dict` must be provided. 
Returns: - Validated configuration dict with optional fields set to ``None`` and JSON + Validated configuration dict with optional fields set to `None` and JSON string fields parsed to Python objects. Raises: - ValueError: If both or neither of ``config_file`` and ``config_dict`` are + ValueError: If both or neither of `config_file` and `config_dict` are provided, if a required field is missing, if an unexpected top-level field is encountered, or if a JSON field cannot be parsed. """ @@ -115,14 +115,14 @@ def obtain_lora( adapter files on local disk. Args: - intrinsic_name: Short name of the intrinsic model, such as ``"certainty"``. + intrinsic_name: Short name of the intrinsic model, such as `"certainty"`. target_model_name: Name of the base model for the LoRA or aLoRA adapter. repo_id: Hugging Face Hub repository containing a collection of LoRA and/or aLoRA adapters for intrinsics. revision: Git revision of the repository to download from. - alora: If ``True``, load the aLoRA version of the intrinsic; otherwise use LoRA. + alora: If `True`, load the aLoRA version of the intrinsic; otherwise use LoRA. cache_dir: Local directory to use as a cache (Hugging Face Hub format), or - ``None`` to use the default location. + `None` to use the default location. file_glob: Only files matching this glob will be downloaded to the cache. Returns: @@ -181,27 +181,27 @@ def obtain_io_yaml( alora: bool = False, cache_dir: str | None = None, ) -> pathlib.Path: - """Download cached ``io.yaml`` configuration file for an intrinsic. + """Download cached `io.yaml` configuration file for an intrinsic. - Downloads an ``io.yaml`` configuration file for an intrinsic + Downloads an `io.yaml` configuration file for an intrinsic with a model repository that follows the format of the [Granite Intrinsics Library]( https://huggingface.co/ibm-granite/granite-lib-rag-r1.0) if one is not already in the local cache. Args: - intrinsic_name: Short name of the intrinsic model, such as ``"certainty"``. 
+ intrinsic_name: Short name of the intrinsic model, such as `"certainty"`. target_model_name: Name of the base model for the LoRA or aLoRA adapter. repo_id: Hugging Face Hub repository containing a collection of LoRA and/or aLoRA adapters for intrinsics. revision: Git revision of the repository to download from. - alora: If ``True``, load the aLoRA version of the intrinsic; otherwise use LoRA. + alora: If `True`, load the aLoRA version of the intrinsic; otherwise use LoRA. cache_dir: Local directory to use as a cache (Hugging Face Hub format), or - ``None`` to use the default location. + `None` to use the default location. Returns: - Full path to the local copy of the ``io.yaml`` file, suitable for passing to - ``IntrinsicsRewriter``. + Full path to the local copy of the `io.yaml` file, suitable for passing to + `IntrinsicsRewriter`. """ lora_dir = obtain_lora( intrinsic_name, diff --git a/mellea/formatters/granite/retrievers/elasticsearch.py b/mellea/formatters/granite/retrievers/elasticsearch.py index 092b52490..a29d6caf1 100644 --- a/mellea/formatters/granite/retrievers/elasticsearch.py +++ b/mellea/formatters/granite/retrievers/elasticsearch.py @@ -12,15 +12,15 @@ class ElasticsearchRetriever: retrieve the top-k matching documents for a given natural language query. Attributes: - hosts (str): Full ``url:port`` connection string to the Elasticsearch - server; stored from the ``host`` constructor argument. + hosts (str): Full `url:port` connection string to the Elasticsearch + server; stored from the `host` constructor argument. Args: corpus_name (str): Name of the Elasticsearch index to query. - host (str): Full ``url:port`` connection string to the Elasticsearch + host (str): Full `url:port` connection string to the Elasticsearch server. **kwargs (Any): Additional keyword arguments forwarded to the - ``Elasticsearch`` client constructor. + `Elasticsearch` client constructor. 
""" def __init__(self, corpus_name: str, host: str, **kwargs: Any): @@ -67,11 +67,11 @@ def retrieve(self, query: str, top_k: int = 5) -> list[dict]: Args: query (str): Natural language query string to search for. - top_k (int): Maximum number of documents to return. Defaults to ``5``. + top_k (int): Maximum number of documents to return. Defaults to `5`. Returns: - list[dict]: List of matching documents, each with keys ``doc_id``, - ``text``, and ``score``. + list[dict]: List of matching documents, each with keys `doc_id`, + `text`, and `score`. """ body = self.create_es_body(top_k, query) diff --git a/mellea/formatters/granite/retrievers/embeddings.py b/mellea/formatters/granite/retrievers/embeddings.py index e1d4c611a..7db1f7952 100644 --- a/mellea/formatters/granite/retrievers/embeddings.py +++ b/mellea/formatters/granite/retrievers/embeddings.py @@ -94,8 +94,8 @@ def compute_embeddings( """Split documents into windows and compute embeddings for each of the the windows. Args: - corpus: PyArrow Table of documents as returned by ``read_corpus()``. - Should have the columns ``["id", "url", "title", "text"]``. + corpus: PyArrow Table of documents as returned by `read_corpus()`. + Should have the columns `["id", "url", "title", "text"]`. embedding_model_name: Hugging Face model name for the model that computes embeddings. Also used for tokenizing. chunk_size: Maximum size of chunks to split documents into, in embedding @@ -106,7 +106,7 @@ def compute_embeddings( Returns: PyArrow Table of chunks of the corpus, with schema - ``["id", "url", "title", "begin", "end", "text", "embedding"]``. + `["id", "url", "title", "begin", "end", "text", "embedding"]`. """ # Third Party import pyarrow as pa @@ -192,7 +192,7 @@ def write_embeddings( Args: target_dir: Location where the files should be written (in a subdirectory). corpus_name: Corpus name used to generate the output directory name. - embeddings: PyArrow Table produced by ``compute_embeddings()``. 
+ embeddings: PyArrow Table produced by `compute_embeddings()`. chunks_per_partition: Number of document chunks to write to each Parquet partition file. @@ -221,8 +221,8 @@ class InMemoryRetriever: Args: data_file_or_table: Parquet file of document snippets and embeddings, or an equivalent - in-memory PyArrow Table. Should have columns ``id``, ``begin``, ``end``, ``text``, - and ``embedding``. + in-memory PyArrow Table. Should have columns `id`, `begin`, `end`, `text`, + and `embedding`. embedding_model_name (str): Name of the Sentence Transformers model to use for embeddings. Must match the model used to compute embeddings in the data file. """ @@ -263,7 +263,7 @@ def retrieve(self, query: str, top_k: int = 5) -> list[dict]: top_k: Number of top results to return. Returns: - List of dicts with keys ``doc_id``, ``text``, and ``score``. + List of dicts with keys `doc_id`, `text`, and `score`. """ # Third Party import pyarrow as pa diff --git a/mellea/formatters/granite/retrievers/util.py b/mellea/formatters/granite/retrievers/util.py index 3a1cf3784..bf134c395 100644 --- a/mellea/formatters/granite/retrievers/util.py +++ b/mellea/formatters/granite/retrievers/util.py @@ -19,14 +19,14 @@ def download_mtrag_corpus(target_dir: str, corpus_name: str) -> pathlib.Path: Args: target_dir: Location where the file should be written if not already present. - corpus_name: Should be one of ``"cloud"``, ``"clapnq"``, ``"fiqa"``, - or ``"govt"``. + corpus_name: Should be one of `"cloud"`, `"clapnq"`, `"fiqa"`, + or `"govt"`. Returns: Path to the downloaded (or cached) file. Raises: - ValueError: If ``corpus_name`` is not one of the supported corpus names. + ValueError: If `corpus_name` is not one of the supported corpus names. 
""" corpus_names = ("cloud", "clapnq", "fiqa", "govt") if corpus_name not in corpus_names: @@ -53,10 +53,10 @@ def read_mtrag_corpus(corpus_file: str | pathlib.Path) -> pa.Table: Returns: Documents from the corpus as a PyArrow table, with schema - ``["id", "url", "title", "text"]``. + `["id", "url", "title", "text"]`. Raises: - TypeError: If the ID column cannot be identified or if no ``text`` column + TypeError: If the ID column cannot be identified or if no `text` column is present in the corpus file. """ if not isinstance(corpus_file, pathlib.Path): @@ -99,13 +99,13 @@ def download_mtrag_embeddings(embedding_name: str, corpus_name: str, target_dir: Args: embedding_name: Name of the SentenceTransformers embedding model used to create the embeddings. - corpus_name: Should be one of ``"cloud"``, ``"clapnq"``, ``"fiqa"``, - or ``"govt"``. - target_dir: Location where Parquet files named ``"part_001.parquet"``, - ``"part_002.parquet"``, etc. will be written. + corpus_name: Should be one of `"cloud"`, `"clapnq"`, `"fiqa"`, + or `"govt"`. + target_dir: Location where Parquet files named `"part_001.parquet"`, + `"part_002.parquet"`, etc. will be written. Raises: - ValueError: If ``corpus_name`` is not one of the supported corpus names, or + ValueError: If `corpus_name` is not one of the supported corpus names, or if no precomputed embeddings are found for the given corpus and embedding model combination. """ diff --git a/mellea/formatters/template_formatter.py b/mellea/formatters/template_formatter.py index 90c806297..982eef75e 100644 --- a/mellea/formatters/template_formatter.py +++ b/mellea/formatters/template_formatter.py @@ -1,9 +1,9 @@ -"""``TemplateFormatter``: Jinja2-template-based formatter for legacy backends. +"""`TemplateFormatter`: Jinja2-template-based formatter for legacy backends. 
-``TemplateFormatter`` extends ``ChatFormatter`` to look up a per-component Jinja2 -template at rendering time, allowing each ``Component`` type to control its own -prompt representation. Template discovery walks a configurable ``template_path`` and -the built-in templates directory; results are cached in a ``SimpleLRUCache`` for +`TemplateFormatter` extends `ChatFormatter` to look up a per-component Jinja2 +template at rendering time, allowing each `Component` type to control its own +prompt representation. Template discovery walks a configurable `template_path` and +the built-in templates directory; results are cached in a `SimpleLRUCache` for performance. Use this formatter when your backend requires hand-crafted prompts rather than a generic chat-message rendering. """ @@ -26,8 +26,8 @@ class TemplateFormatter(ChatFormatter): """Formatter that uses Jinja2 templates to render components into prompt strings. - Template discovery walks a configurable ``template_path`` and the built-in - templates directory. Results are optionally cached in a ``SimpleLRUCache`` + Template discovery walks a configurable `template_path` and the built-in + templates directory. Results are optionally cached in a `SimpleLRUCache` for performance. Use this formatter when your backend requires hand-crafted prompts rather than generic chat-message rendering. @@ -36,10 +36,10 @@ class TemplateFormatter(ChatFormatter): Should match the template directory structure. template_path (str): An alternate location where templates can be found. Will be preferred over all other template directories even if a less exact match is found. - Defaults to ``""``. - use_template_cache (bool): When ``True``, caches template lookup results. Set to ``False`` - if you plan to change ``model_id`` or ``template_path`` after construction. - Defaults to ``True``. + Defaults to `""`. + use_template_cache (bool): When `True`, caches template lookup results. 
Set to `False` + if you plan to change `model_id` or `template_path` after construction. + Defaults to `True`. Example:: diff --git a/mellea/helpers/__init__.py b/mellea/helpers/__init__.py index 62b22eb6a..da1f06724 100644 --- a/mellea/helpers/__init__.py +++ b/mellea/helpers/__init__.py @@ -1,10 +1,10 @@ """Low-level helpers and utilities supporting mellea backends. This package provides the internal plumbing used by the built-in backend -implementations: async utilities (``send_to_queue``, ``wait_for_all_mots``, -``ClientCache``) for managing concurrent model-output thunks; OpenAI-compatible -message conversion helpers (``message_to_openai_message``, ``messages_to_docs``, -``chat_completion_delta_merge``); and ``_ServerType`` detection for adapting +implementations: async utilities (`send_to_queue`, `wait_for_all_mots`, +`ClientCache`) for managing concurrent model-output thunks; OpenAI-compatible +message conversion helpers (`message_to_openai_message`, `messages_to_docs`, +`chat_completion_delta_merge`); and `_ServerType` detection for adapting structured-output support to the target server. Most user code will not import from this package directly — it is consumed internally by the backend layer. """ diff --git a/mellea/helpers/async_helpers.py b/mellea/helpers/async_helpers.py index 0618d54c1..65d419e19 100644 --- a/mellea/helpers/async_helpers.py +++ b/mellea/helpers/async_helpers.py @@ -1,9 +1,9 @@ """Async helper functions for managing concurrent model output thunks. 
-Provides ``send_to_queue``, which feeds a backend response coroutine or async iterator -into an ``asyncio.Queue`` (including sentinel and error forwarding); ``wait_for_all_mots``, -which gathers multiple ``ModelOutputThunk`` computations in a single ``asyncio.gather`` -call; and ``get_current_event_loop``, a safe wrapper that returns ``None`` instead of +Provides `send_to_queue`, which feeds a backend response coroutine or async iterator +into an `asyncio.Queue` (including sentinel and error forwarding); `wait_for_all_mots`, +which gathers multiple `ModelOutputThunk` computations in a single `asyncio.gather` +call; and `get_current_event_loop`, a safe wrapper that returns `None` instead of raising when no event loop is running. These utilities are used internally by backends that operate in async contexts. """ @@ -23,7 +23,7 @@ async def send_to_queue( Args: co: A coroutine or async iterator producing the backend response. - aqueue: The async queue to send results to. A sentinel ``None`` is appended on + aqueue: The async queue to send results to. A sentinel `None` is appended on completion; an exception instance is appended on error. """ try: @@ -57,7 +57,7 @@ async def wait_for_all_mots(mots: list[ModelOutputThunk]) -> None: functions, session functions, and top-level mellea functions. Args: - mots: List of ``ModelOutputThunk`` objects to await concurrently. + mots: List of `ModelOutputThunk` objects to await concurrently. """ coroutines: list[Coroutine[Any, Any, str]] = [] for mot in mots: @@ -70,7 +70,7 @@ def get_current_event_loop() -> None | asyncio.AbstractEventLoop: """Get the current event loop without having to catch exceptions. Returns: - The running event loop, or ``None`` if no loop is running. + The running event loop, or `None` if no loop is running. """ loop = None try: @@ -113,7 +113,7 @@ def get(self, key: int) -> Any | None: key: Integer cache key. Returns: - The cached value, or ``None`` if the key is not present. 
+ The cached value, or `None` if the key is not present. """ if key not in self.cache: return None diff --git a/mellea/helpers/openai_compatible_helpers.py b/mellea/helpers/openai_compatible_helpers.py index 739677368..77a57401c 100644 --- a/mellea/helpers/openai_compatible_helpers.py +++ b/mellea/helpers/openai_compatible_helpers.py @@ -16,13 +16,13 @@ def extract_model_tool_requests( """Extract tool calls from the dict representation of an OpenAI-like chat response object. Args: - tools: Mapping of tool name to ``AbstractMelleaTool`` for lookup. + tools: Mapping of tool name to `AbstractMelleaTool` for lookup. response: Dict representation of an OpenAI-compatible chat completion message - (must contain a ``"message"`` key). + (must contain a `"message"` key). Returns: - Mapping of tool name to ``ModelToolCall`` for each requested tool call, or - ``None`` if no tool calls were found. + Mapping of tool name to `ModelToolCall` for each requested tool call, or + `None` if no tool calls were found. """ model_tool_calls: dict[str, ModelToolCall] = {} calls = response["message"].get("tool_calls", None) @@ -55,19 +55,19 @@ def extract_model_tool_requests( def chat_completion_delta_merge( chunks: list[dict], force_all_tool_calls_separate: bool = False ) -> dict: - """Merge a list of deltas from ``ChatCompletionChunk``s into a single dict representing the ``ChatCompletion`` choice. + """Merge a list of deltas from `ChatCompletionChunk`s into a single dict representing the `ChatCompletion` choice. Args: chunks: The list of dicts that represent the message deltas. - force_all_tool_calls_separate: If ``True``, tool calls in separate message + force_all_tool_calls_separate: If `True`, tool calls in separate message deltas will not be merged even if their index values are the same. Use when providers do not return the correct index value for tool calls; all tool calls must then be fully populated in a single delta. 
Returns: - A single merged dict representing the assembled ``ChatCompletion`` choice, - with ``finish_reason``, ``index``, and a ``message`` sub-dict containing - ``content``, ``role``, and ``tool_calls``. + A single merged dict representing the assembled `ChatCompletion` choice, + with `finish_reason`, `index`, and a `message` sub-dict containing + `content`, `role`, and `tool_calls`. """ merged: dict[str, Any] = dict() @@ -141,14 +141,14 @@ def chat_completion_delta_merge( def message_to_openai_message(msg: Message): - """Serialise a Mellea ``Message`` to the format required by OpenAI-compatible API providers. + """Serialise a Mellea `Message` to the format required by OpenAI-compatible API providers. Args: - msg: The ``Message`` object to serialise. + msg: The `Message` object to serialise. Returns: - A dict with ``"role"`` and ``"content"`` fields. When the message carries - images, ``"content"`` is a list of text and image-URL dicts; otherwise it + A dict with `"role"` and `"content"` fields. When the message carries + images, `"content"` is a list of text and image-URL dicts; otherwise it is a plain string. """ if msg.images is not None: @@ -182,14 +182,14 @@ def message_to_openai_message(msg: Message): def messages_to_docs(msgs: list[Message]) -> list[dict[str, str]]: - """Extract all ``Document`` objects from a list of ``Message`` objects. + """Extract all `Document` objects from a list of `Message` objects. Args: - msgs: List of ``Message`` objects whose ``_docs`` attributes are inspected. + msgs: List of `Message` objects whose `_docs` attributes are inspected. Returns: - A list of dicts, each with a ``"text"`` key and optional ``"title"`` and - ``"doc_id"`` keys, suitable for passing to an OpenAI-compatible RAG API. + A list of dicts, each with a `"text"` key and optional `"title"` and + `"doc_id"` keys, suitable for passing to an OpenAI-compatible RAG API. 
""" docs: list[Document] = [] for message in msgs: diff --git a/mellea/helpers/server_type.py b/mellea/helpers/server_type.py index b3a770935..6f0375d03 100644 --- a/mellea/helpers/server_type.py +++ b/mellea/helpers/server_type.py @@ -1,11 +1,11 @@ """Utilities for detecting and classifying the target inference server. -Defines the ``_ServerType`` enum (``LOCALHOST``, ``OPENAI``, ``REMOTE_VLLM``, -``UNKNOWN``) and ``_server_type``, which classifies a URL by hostname. Also provides -``is_vllm_server_with_structured_output``, which probes a server's ``/version`` -endpoint to determine whether it supports the ``structured_outputs`` parameter +Defines the `_ServerType` enum (`LOCALHOST`, `OPENAI`, `REMOTE_VLLM`, +`UNKNOWN`) and `_server_type`, which classifies a URL by hostname. Also provides +`is_vllm_server_with_structured_output`, which probes a server's `/version` +endpoint to determine whether it supports the `structured_outputs` parameter introduced in vLLM ≥ 0.12.0. Used by the OpenAI-compatible backend to choose between -``guided_json`` and ``structured_outputs`` request formats. +`guided_json` and `structured_outputs` request formats. 
""" import json diff --git a/mellea/plugins/base.py b/mellea/plugins/base.py index 357e2a515..2e81fcf3f 100644 --- a/mellea/plugins/base.py +++ b/mellea/plugins/base.py @@ -66,13 +66,13 @@ async def redact_input(self, payload, ctx): def __init_subclass__( cls, *, name: str = "", priority: int = 50, **kwargs: Any ) -> None: - """Set plugin metadata on subclasses that provide a ``name``.""" + """Set plugin metadata on subclasses that provide a `name`.""" super().__init_subclass__(**kwargs) if name: cls._mellea_plugin_meta = PluginMeta(name=name, priority=priority) # type: ignore[attr-defined] def __enter__(self) -> Any: - """Register this plugin for the duration of a ``with`` block.""" + """Register this plugin for the duration of a `with` block.""" return _plugin_cm_enter(self) def __exit__( @@ -85,7 +85,7 @@ def __exit__( _plugin_cm_exit(self, exc_type, exc_val, exc_tb) async def __aenter__(self) -> Any: - """Async variant — delegates to ``__enter__``.""" + """Async variant — delegates to `__enter__`.""" return self.__enter__() async def __aexit__( @@ -94,7 +94,7 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: - """Async variant — delegates to ``__exit__``.""" + """Async variant — delegates to `__exit__`.""" self.__exit__(exc_type, exc_val, exc_tb) @@ -139,9 +139,9 @@ def __init__( # noqa: D107 class MelleaBasePayload(PluginPayload): """Frozen base — all payloads are immutable by design. - Plugins must use ``model_copy(update={...})`` to propose modifications - and return the copy via ``PluginResult.modified_payload``. The plugin - manager applies the hook's ``HookPayloadPolicy`` to filter changes to + Plugins must use `model_copy(update={...})` to propose modifications + and return the copy via `PluginResult.modified_payload`. The plugin + manager applies the hook's `HookPayloadPolicy` to filter changes to writable fields only. 
""" @@ -154,9 +154,9 @@ class MelleaBasePayload(PluginPayload): class MelleaPlugin(_CpexPlugin): """Base class for Mellea plugins with lifecycle hooks and typed accessors. - Use this when you need lifecycle hooks (``initialize``/``shutdown``) - or typed context accessors. For simpler plugins, prefer ``@hook`` - on standalone functions or ``@plugin`` on plain classes. + Use this when you need lifecycle hooks (`initialize`/`shutdown`) + or typed context accessors. For simpler plugins, prefer `@hook` + on standalone functions or `@plugin` on plain classes. Instances support the context manager protocol for temporary activation:: @@ -193,7 +193,7 @@ def plugin_config(self) -> dict[str, Any]: return self._config.config or {} def __enter__(self) -> MelleaPlugin: - """Register this plugin for the duration of a ``with`` block.""" + """Register this plugin for the duration of a `with` block.""" if getattr(self, "_scope_id", None) is not None: raise RuntimeError( f"MelleaPlugin {self.name!r} is already active as a context manager. " @@ -218,11 +218,11 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self._scope_id = None # type: ignore[assignment] async def __aenter__(self) -> MelleaPlugin: - """Async variant — delegates to the synchronous ``__enter__``.""" + """Async variant — delegates to the synchronous `__enter__`.""" return self.__enter__() async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - """Async variant — delegates to the synchronous ``__exit__``.""" + """Async variant — delegates to the synchronous `__exit__`.""" self.__exit__(exc_type, exc_val, exc_tb) PluginResult: TypeAlias = _CFPluginResult # type: ignore[misc] @@ -230,7 +230,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: else: # Provide a stub when the plugin framework is not installed. 
class MelleaBasePayload: # type: ignore[no-redef] - """Stub — install ``"mellea[hooks]"`` for full plugin support.""" + """Stub — install `"mellea[hooks]"` for full plugin support.""" def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: D107 raise ImportError( @@ -240,7 +240,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: D107 # Provide a stub when the plugin framework is not installed. class MelleaPlugin: # type: ignore[no-redef] - """Stub — install ``"mellea[hooks]"`` for full plugin support.""" + """Stub — install `"mellea[hooks]"` for full plugin support.""" def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: D107 raise ImportError( diff --git a/mellea/plugins/context.py b/mellea/plugins/context.py index bed622794..3b6b8d0bb 100644 --- a/mellea/plugins/context.py +++ b/mellea/plugins/context.py @@ -16,14 +16,14 @@ def build_global_context(*, backend: Backend | None = None, **extra_fields: Any) -> Any: - """Build a ContextForge ``GlobalContext`` from Mellea domain objects. + """Build a ContextForge `GlobalContext` from Mellea domain objects. The global context carries lightweight, cross-cutting ambient metadata - (e.g. ``backend_name``) that is useful to every hook regardless of type. + (e.g. `backend_name`) that is useful to every hook regardless of type. Hook-specific data (context, session, action, etc.) belongs on the typed payload, not here. - Returns ``None`` if ContextForge is not installed. + Returns `None` if ContextForge is not installed. """ if not _HAS_PLUGIN_FRAMEWORK: return None diff --git a/mellea/plugins/decorators.py b/mellea/plugins/decorators.py index 27789c647..d394bcef6 100644 --- a/mellea/plugins/decorators.py +++ b/mellea/plugins/decorators.py @@ -26,10 +26,10 @@ def hook( """Register an async function or method as a hook handler. Args: - hook_type: The hook point name (e.g., ``"generation_pre_call"``). 
- mode: Execution mode — ``PluginMode.SEQUENTIAL`` (default), ``PluginMode.CONCURRENT``, - ``PluginMode.AUDIT``, or ``PluginMode.FIRE_AND_FORGET``. - priority: Lower numbers execute first. For methods on a ``Plugin`` subclass, falls back + hook_type: The hook point name (e.g., `"generation_pre_call"`). + mode: Execution mode — `PluginMode.SEQUENTIAL` (default), `PluginMode.CONCURRENT`, + `PluginMode.AUDIT`, or `PluginMode.FIRE_AND_FORGET`. + priority: Lower numbers execute first. For methods on a `Plugin` subclass, falls back to the class-level priority, then 50. For standalone functions, defaults to 50. """ diff --git a/mellea/plugins/hooks/component.py b/mellea/plugins/hooks/component.py index b42615c3a..d23936bf1 100644 --- a/mellea/plugins/hooks/component.py +++ b/mellea/plugins/hooks/component.py @@ -8,16 +8,16 @@ class ComponentPreExecutePayload(MelleaBasePayload): - """Payload for ``component_pre_execute`` — before component execution via ``aact()``. + """Payload for `component_pre_execute` — before component execution via `aact()`. Attributes: component_type: Class name of the component being executed. - action: The ``Component`` or ``CBlock`` about to be executed. + action: The `Component` or `CBlock` about to be executed. context_view: Optional snapshot of the context as a list. - requirements: List of ``Requirement`` instances for validation (writable). + requirements: List of `Requirement` instances for validation (writable). model_options: Dict of model options passed to the backend (writable). - format: Optional ``BaseModel`` subclass for structured output / constrained decoding (writable). - strategy: Optional ``SamplingStrategy`` instance controlling retry logic (writable). + format: Optional `BaseModel` subclass for structured output / constrained decoding (writable). + strategy: Optional `SamplingStrategy` instance controlling retry logic (writable). tool_calls_enabled: Whether tool calling is enabled for this execution (writable). 
""" @@ -32,19 +32,19 @@ class ComponentPreExecutePayload(MelleaBasePayload): class ComponentPostSuccessPayload(MelleaBasePayload): - """Payload for ``component_post_success`` — after successful component execution. + """Payload for `component_post_success` — after successful component execution. Attributes: component_type: Class name of the executed component. - action: The ``Component`` or ``CBlock`` that was executed. + action: The `Component` or `CBlock` that was executed. - result: The ``ModelOutputThunk`` containing the generation result. - context_before: The ``Context`` before execution. + result: The `ModelOutputThunk` containing the generation result. + context_before: The `Context` before execution. - context_after: The ``Context`` after execution (with action + result appended). + context_after: The `Context` after execution (with action + result appended). - generate_log: The ``GenerateLog`` from the final generation pass. - sampling_results: Optional list of ``ModelOutputThunk`` from all sampling attempts. + generate_log: The `GenerateLog` from the final generation pass. + sampling_results: Optional list of `ModelOutputThunk` from all sampling attempts. latency_ms: Wall-clock time for the full execution in milliseconds. """ @@ -59,16 +59,16 @@ class ComponentPostSuccessPayload(MelleaBasePayload): class ComponentPostErrorPayload(MelleaBasePayload): - """Payload for ``component_post_error`` — after component execution fails. + """Payload for `component_post_error` — after component execution fails. Attributes: component_type: Class name of the component that failed. - action: The ``Component`` or ``CBlock`` that was being executed. + action: The `Component` or `CBlock` that was being executed. - error: The ``Exception`` that was raised. - error_type: Class name of the exception (e.g. ``"ValueError"``). + error: The `Exception` that was raised. + error_type: Class name of the exception (e.g. `"ValueError"`). stack_trace: Formatted traceback string. 
- context: The ``Context`` at the time of the error. + context: The `Context` at the time of the error. model_options: Dict of model options that were in effect. """ diff --git a/mellea/plugins/hooks/generation.py b/mellea/plugins/hooks/generation.py index 2b83c8c1f..438d58720 100644 --- a/mellea/plugins/hooks/generation.py +++ b/mellea/plugins/hooks/generation.py @@ -8,15 +8,15 @@ class GenerationPreCallPayload(MelleaBasePayload): - """Payload for ``generation_pre_call`` — before LLM backend call. + """Payload for `generation_pre_call` — before LLM backend call. Attributes: - action: The ``Component`` or ``CBlock`` about to be sent to the backend. + action: The `Component` or `CBlock` about to be sent to the backend. - context: The ``Context`` being used for this generation call. + context: The `Context` being used for this generation call. model_options: Dict of model options (writable — plugins may adjust temperature, etc.). - format: Optional ``BaseModel`` subclass for constrained decoding (writable). + format: Optional `BaseModel` subclass for constrained decoding (writable). tool_calls: Whether tool calls are enabled for this generation (writable). """ @@ -28,18 +28,18 @@ class GenerationPreCallPayload(MelleaBasePayload): class GenerationPostCallPayload(MelleaBasePayload): - """Payload for ``generation_post_call`` — fires once the model output is fully computed. + """Payload for `generation_post_call` — fires once the model output is fully computed. - For lazy ``ModelOutputThunk`` objects this hook fires inside - ``ModelOutputThunk.astream`` after ``post_process`` completes, so - ``model_output.value`` is guaranteed to be available. For already-computed - thunks (e.g. cached responses) it fires before ``generate_from_context`` + For lazy `ModelOutputThunk` objects this hook fires inside + `ModelOutputThunk.astream` after `post_process` completes, so + `model_output.value` is guaranteed to be available. For already-computed + thunks (e.g. 
cached responses) it fires before `generate_from_context` returns. Attributes: prompt: The formatted prompt sent to the backend (str or list of message dicts). - model_output: The fully-computed ``ModelOutputThunk``. - latency_ms: Elapsed milliseconds from the ``generate_from_context`` call + model_output: The fully-computed `ModelOutputThunk`. + latency_ms: Elapsed milliseconds from the `generate_from_context` call to when the value was fully materialized. """ diff --git a/mellea/plugins/hooks/sampling.py b/mellea/plugins/hooks/sampling.py index 4e0bceefe..a9dabb680 100644 --- a/mellea/plugins/hooks/sampling.py +++ b/mellea/plugins/hooks/sampling.py @@ -8,15 +8,15 @@ class SamplingLoopStartPayload(MelleaBasePayload): - """Payload for ``sampling_loop_start`` — when sampling strategy begins. + """Payload for `sampling_loop_start` — when sampling strategy begins. Attributes: - strategy_name: Class name of the sampling strategy (e.g. ``"RejectionSamplingStrategy"``). - action: The ``Component`` being sampled. + strategy_name: Class name of the sampling strategy (e.g. `"RejectionSamplingStrategy"`). + action: The `Component` being sampled. - context: The ``Context`` at the start of sampling. + context: The `Context` at the start of sampling. - requirements: List of ``Requirement`` instances to validate against. + requirements: List of `Requirement` instances to validate against. loop_budget: Maximum number of sampling iterations allowed (writable). """ @@ -28,15 +28,15 @@ class SamplingLoopStartPayload(MelleaBasePayload): class SamplingIterationPayload(MelleaBasePayload): - """Payload for ``sampling_iteration`` — after each sampling attempt. + """Payload for `sampling_iteration` — after each sampling attempt. Attributes: iteration: 1-based iteration number within the sampling loop. - action: The ``Component`` used for this attempt. + action: The `Component` used for this attempt. - result: The ``ModelOutputThunk`` produced by this attempt. 
- validation_results: List of ``(Requirement, ValidationResult)`` tuples. - all_validations_passed: ``True`` when **every** requirement in ``validation_results`` + result: The `ModelOutputThunk` produced by this attempt. + validation_results: List of `(Requirement, ValidationResult)` tuples. + all_validations_passed: `True` when **every** requirement in `validation_results` passed for this iteration (i.e., the sampling attempt succeeded). valid_count: Number of requirements that passed. total_count: Total number of requirements evaluated. @@ -52,16 +52,16 @@ class SamplingIterationPayload(MelleaBasePayload): class SamplingRepairPayload(MelleaBasePayload): - """Payload for ``sampling_repair`` — when repair is invoked after validation failure. + """Payload for `sampling_repair` — when repair is invoked after validation failure. Attributes: - repair_type: Kind of repair (strategy-dependent, e.g. ``"rejection"``, ``"template"``). - failed_action: The ``Component`` that failed validation. + repair_type: Kind of repair (strategy-dependent, e.g. `"rejection"`, `"template"`). + failed_action: The `Component` that failed validation. - failed_result: The ``ModelOutputThunk`` that failed validation. - failed_validations: List of ``(Requirement, ValidationResult)`` tuples that failed. - repair_action: The repaired ``Component`` to use for the next attempt. - repair_context: The ``Context`` to use for the next attempt. + failed_result: The `ModelOutputThunk` that failed validation. + failed_validations: List of `(Requirement, ValidationResult)` tuples that failed. + repair_action: The repaired `Component` to use for the next attempt. + repair_context: The `Context` to use for the next attempt. repair_iteration: 1-based iteration at which the repair was triggered. """ @@ -75,20 +75,20 @@ class SamplingRepairPayload(MelleaBasePayload): class SamplingLoopEndPayload(MelleaBasePayload): - """Payload for ``sampling_loop_end`` — when sampling completes. 
+ """Payload for `sampling_loop_end` — when sampling completes. Attributes: - success: ``True`` if at least one attempt passed all requirements. + success: `True` if at least one attempt passed all requirements. iterations_used: Total number of iterations the loop executed. - final_result: The selected ``ModelOutputThunk`` (best success or best failure). - final_action: The ``Component`` that produced ``final_result``. + final_result: The selected `ModelOutputThunk` (best success or best failure). + final_action: The `Component` that produced `final_result`. - final_context: The ``Context`` associated with ``final_result``. + final_context: The `Context` associated with `final_result`. - failure_reason: Human-readable reason when ``success`` is ``False``. - all_results: List of ``ModelOutputThunk`` from every iteration. - all_validations: Nested list — ``all_validations[i]`` is the list of - ``(Requirement, ValidationResult)`` tuples for iteration *i*. + failure_reason: Human-readable reason when `success` is `False`. + all_results: List of `ModelOutputThunk` from every iteration. + all_validations: Nested list — `all_validations[i]` is the list of + `(Requirement, ValidationResult)` tuples for iteration *i*. """ success: bool = False diff --git a/mellea/plugins/hooks/session.py b/mellea/plugins/hooks/session.py index 321f154ee..a8018f5ec 100644 --- a/mellea/plugins/hooks/session.py +++ b/mellea/plugins/hooks/session.py @@ -11,13 +11,13 @@ class SessionPreInitPayload(MelleaBasePayload): - """Payload for ``session_pre_init`` — before backend initialization. + """Payload for `session_pre_init` — before backend initialization. Attributes: - backend_name: Name of the backend (e.g. ``"ollama"``, ``"openai"``). + backend_name: Name of the backend (e.g. `"ollama"`, `"openai"`). model_id: Model identifier string (writable). model_options: Optional dict of model options like temperature, max_tokens (writable). - context_type: Class name of the context being used (e.g. 
``"SimpleContext"``). + context_type: Class name of the context being used (e.g. `"SimpleContext"`). """ backend_name: str @@ -27,12 +27,12 @@ class SessionPreInitPayload(MelleaBasePayload): class SessionPostInitPayload(MelleaBasePayload): - """Payload for ``session_post_init`` — after session is fully initialized. + """Payload for `session_post_init` — after session is fully initialized. Attributes: session_id: UUID string identifying this session. - model_id: Model identifier used by the backend (e.g. ``"granite4:micro"``). - context: The initial ``Context`` instance for this session. + model_id: Model identifier used by the backend (e.g. `"granite4:micro"`). + context: The initial `Context` instance for this session. """ session_id: str = "" @@ -41,10 +41,10 @@ class SessionPostInitPayload(MelleaBasePayload): class SessionResetPayload(MelleaBasePayload): - """Payload for ``session_reset`` — when session context is reset. + """Payload for `session_reset` — when session context is reset. Attributes: - previous_context: The ``Context`` that is about to be discarded (observe-only). + previous_context: The `Context` that is about to be discarded (observe-only). """ @@ -52,10 +52,10 @@ class SessionResetPayload(MelleaBasePayload): class SessionCleanupPayload(MelleaBasePayload): - """Payload for ``session_cleanup`` — before session cleanup/teardown. + """Payload for `session_cleanup` — before session cleanup/teardown. Attributes: - context: The ``Context`` at the time of cleanup (observe-only). + context: The `Context` at the time of cleanup (observe-only). interaction_count: Number of items in the context at cleanup time. """ diff --git a/mellea/plugins/hooks/tool.py b/mellea/plugins/hooks/tool.py index bc92d430e..d8a8b2f6f 100644 --- a/mellea/plugins/hooks/tool.py +++ b/mellea/plugins/hooks/tool.py @@ -8,10 +8,10 @@ class ToolPreInvokePayload(MelleaBasePayload): - """Payload for ``tool_pre_invoke`` — before tool/function invocation. 
+ """Payload for `tool_pre_invoke` — before tool/function invocation. Attributes: - model_tool_call: The ``ModelToolCall`` about to be executed (writable — + model_tool_call: The `ModelToolCall` about to be executed (writable — plugins may modify arguments or swap the tool entirely). """ @@ -19,16 +19,16 @@ class ToolPreInvokePayload(MelleaBasePayload): class ToolPostInvokePayload(MelleaBasePayload): - """Payload for ``tool_post_invoke`` — after tool execution. + """Payload for `tool_post_invoke` — after tool execution. Attributes: - model_tool_call: The ``ModelToolCall`` that was executed. + model_tool_call: The `ModelToolCall` that was executed. tool_output: The return value of the tool function (writable — plugins may transform the output before it is formatted). - tool_message: The ``ToolMessage`` constructed from the output. + tool_message: The `ToolMessage` constructed from the output. execution_time_ms: Wall-clock time of the tool execution in milliseconds. - success: ``True`` if the tool executed without raising an exception. - error: The ``Exception`` raised during execution, or ``None`` on success. + success: `True` if the tool executed without raising an exception. + error: The `Exception` raised during execution, or `None` on success. """ model_tool_call: Any = None diff --git a/mellea/plugins/hooks/validation.py b/mellea/plugins/hooks/validation.py index 4c7f22845..78ef539f7 100644 --- a/mellea/plugins/hooks/validation.py +++ b/mellea/plugins/hooks/validation.py @@ -8,13 +8,13 @@ class ValidationPreCheckPayload(MelleaBasePayload): - """Payload for ``validation_pre_check`` — before requirement validation. + """Payload for `validation_pre_check` — before requirement validation. Attributes: - requirements: List of ``Requirement`` instances to validate (writable). - target: The ``CBlock`` being validated, or ``None`` when validating the full context. + requirements: List of `Requirement` instances to validate (writable). 
+ target: The `CBlock` being validated, or `None` when validating the full context. - context: The ``Context`` used for validation. + context: The `Context` used for validation. model_options: Dict of model options for backend-based validators (writable). """ @@ -26,12 +26,12 @@ class ValidationPreCheckPayload(MelleaBasePayload): class ValidationPostCheckPayload(MelleaBasePayload): - """Payload for ``validation_post_check`` — after validation completes. + """Payload for `validation_post_check` — after validation completes. Attributes: - requirements: List of ``Requirement`` instances that were evaluated. - results: List of ``ValidationResult`` instances (writable). - all_validations_passed: ``True`` when every requirement passed (writable). + requirements: List of `Requirement` instances that were evaluated. + results: List of `ValidationResult` instances (writable). + all_validations_passed: `True` when every requirement passed (writable). passed_count: Number of requirements that passed. failed_count: Number of requirements that failed. """ diff --git a/mellea/plugins/manager.py b/mellea/plugins/manager.py index 25598f520..17a008397 100644 --- a/mellea/plugins/manager.py +++ b/mellea/plugins/manager.py @@ -34,7 +34,7 @@ def has_plugins(hook_type: HookType | None = None) -> bool: """Fast check: are plugins configured and available for the given hook type. - When ``hook_type`` is provided, also checks whether any plugin has + When `hook_type` is provided, also checks whether any plugin has registered a handler for that specific hook, enabling callers to skip payload construction entirely when no plugin subscribes. 
""" @@ -46,7 +46,7 @@ def has_plugins(hook_type: HookType | None = None) -> bool: def get_plugin_manager() -> Any | None: - """Returns the initialized PluginManager, or ``None`` if plugins are not configured.""" + """Returns the initialized PluginManager, or `None` if plugins are not configured.""" return _plugin_manager @@ -151,12 +151,12 @@ async def invoke_hook( ) -> tuple[Any | None, MelleaBasePayload]: """Invoke a hook if plugins are configured. - Returns ``(result, possibly-modified-payload)``. - If plugins are not configured, returns ``(None, original_payload)`` immediately. + Returns `(result, possibly-modified-payload)`. + If plugins are not configured, returns `(None, original_payload)` immediately. Three layers of no-op guards ensure zero overhead when plugins are not configured: - 1. ``_plugins_enabled`` boolean — single pointer dereference - 2. ``has_hooks_for(hook_type)`` — skips when no plugin subscribes + 1. `_plugins_enabled` boolean — single pointer dereference + 2. `has_hooks_for(hook_type)` — skips when no plugin subscribes 3. Returns immediately when either guard fails """ if not _plugins_enabled or _plugin_manager is None: diff --git a/mellea/plugins/pluginset.py b/mellea/plugins/pluginset.py index d485ee65a..86c7d8749 100644 --- a/mellea/plugins/pluginset.py +++ b/mellea/plugins/pluginset.py @@ -10,8 +10,8 @@ class PluginSet: """A named, composable group of hook functions and plugin instances. PluginSets are inert containers — they do not register anything themselves. - Registration happens when they are passed to ``register()`` or - ``start_session(plugins=[...])``. + Registration happens when they are passed to `register()` or + `start_session(plugins=[...])`. PluginSets can be nested: a PluginSet can contain other PluginSets. 
@@ -38,10 +38,10 @@ def __init__( # noqa: D107 self._scope_id: str | None = None def flatten(self) -> list[tuple[Callable | Any, int | None]]: - """Recursively flatten nested PluginSets into ``(item, priority_override)`` pairs. + """Recursively flatten nested PluginSets into `(item, priority_override)` pairs. When this set has a priority, it overrides the priorities of all nested - items — including items inside nested ``PluginSet`` instances. + items — including items inside nested `PluginSet` instances. """ result: list[tuple[Callable | Any, int | None]] = [] for item in self.items: @@ -58,7 +58,7 @@ def flatten(self) -> list[tuple[Callable | Any, int | None]]: return result def __enter__(self) -> PluginSet: - """Register all plugins in this set for the duration of the ``with`` block.""" + """Register all plugins in this set for the duration of the `with` block.""" if self._scope_id is not None: raise RuntimeError( f"PluginSet {self.name!r} is already active as a context manager. " @@ -82,11 +82,11 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self._scope_id = None async def __aenter__(self) -> PluginSet: - """Async variant — delegates to the synchronous ``__enter__``.""" + """Async variant — delegates to the synchronous `__enter__`.""" return self.__enter__() async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - """Async variant — delegates to the synchronous ``__exit__``.""" + """Async variant — delegates to the synchronous `__exit__`.""" self.__exit__(exc_type, exc_val, exc_tb) def __repr__(self) -> str: # noqa: D105 diff --git a/mellea/plugins/policies.py b/mellea/plugins/policies.py index 262114e29..99ec5ee68 100644 --- a/mellea/plugins/policies.py +++ b/mellea/plugins/policies.py @@ -19,18 +19,18 @@ def _build_policies() -> dict[str, Any]: Mutability is enforced at **two layers**: - 1. **Execution mode** (cpex) — only ``SEQUENTIAL`` and ``TRANSFORM`` plugins - can modify payloads. 
``AUDIT``, ``CONCURRENT``, and ``FIRE_AND_FORGET`` + 1. **Execution mode** (cpex) — only `SEQUENTIAL` and `TRANSFORM` plugins + can modify payloads. `AUDIT`, `CONCURRENT`, and `FIRE_AND_FORGET` plugins have their modifications silently discarded by cpex regardless of what this table says. 2. **Field-level policy** (this table) — for modes that *can* modify, this table restricts *which* fields are writable. cpex applies - ``HookPayloadPolicy`` after each plugin returns, accepting only changes + `HookPayloadPolicy` after each plugin returns, accepting only changes to listed fields and discarding the rest. Hooks absent from this table are observe-only; with - ``DefaultHookPolicy.DENY`` (the Mellea default), any modification attempt + `DefaultHookPolicy.DENY` (the Mellea default), any modification attempt on an unlisted hook is rejected by cpex at runtime. """ if not _HAS_PLUGIN_FRAMEWORK: diff --git a/mellea/plugins/registry.py b/mellea/plugins/registry.py index 181044ec6..078d4ff4c 100644 --- a/mellea/plugins/registry.py +++ b/mellea/plugins/registry.py @@ -45,11 +45,11 @@ def _map_mode(mode: PluginMode) -> Any: def modify(payload: Any, **field_updates: Any) -> Any: - """Convenience helper for returning a modifying ``PluginResult``. + """Convenience helper for returning a modifying `PluginResult`. - Creates an immutable copy of ``payload`` with ``field_updates`` applied and - wraps it in a ``PluginResult(continue_processing=True)``. Only fields - listed in the hook's ``HookPayloadPolicy.writable_fields`` will be accepted + Creates an immutable copy of `payload` with `field_updates` applied and + wraps it in a `PluginResult(continue_processing=True)`. Only fields + listed in the hook's `HookPayloadPolicy.writable_fields` will be accepted by the framework; changes to read-only fields are silently discarded. 
Mirrors :func:`block` for the modification case:: @@ -83,12 +83,12 @@ def block( description: str = "", details: dict[str, Any] | None = None, ) -> Any: - """Convenience helper for returning a blocking ``PluginResult``. + """Convenience helper for returning a blocking `PluginResult`. Args: reason: Short reason for the violation. code: Machine-readable violation code. - description: Longer description (defaults to ``reason``). + description: Longer description (defaults to `reason`). details: Additional structured details. """ if not _HAS_PLUGIN_FRAMEWORK: @@ -114,11 +114,11 @@ def register( ) -> None: """Register plugins globally or for a specific session. - When ``session_id`` is ``None``, plugins are global (fire for all invocations). - When ``session_id`` is provided, plugins fire only within that session. + When `session_id` is `None`, plugins are global (fire for all invocations). + When `session_id` is provided, plugins fire only within that session. - Accepts standalone ``@hook`` functions, ``@plugin``-decorated class instances, - ``MelleaPlugin`` instances, ``PluginSet`` instances, or lists thereof. + Accepts standalone `@hook` functions, `@plugin`-decorated class instances, + `MelleaPlugin` instances, `PluginSet` instances, or lists thereof. """ if not _HAS_PLUGIN_FRAMEWORK: raise ImportError( @@ -146,9 +146,9 @@ def _register_single( ) -> None: """Register a single hook function or plugin instance. 
- - Standalone functions with ``_mellea_hook_meta``: wrapped in ``_FunctionHookAdapter`` - - ``@plugin``-decorated class instances: methods with ``_mellea_hook_meta`` discovered - - ``MelleaPlugin`` instances: registered directly + - Standalone functions with `_mellea_hook_meta`: wrapped in `_FunctionHookAdapter` + - `@plugin`-decorated class instances: methods with `_mellea_hook_meta` discovered + - `MelleaPlugin` instances: registered directly """ meta: HookMeta | None = getattr(item, "_mellea_hook_meta", None) plugin_meta: PluginMeta | None = getattr(type(item), "_mellea_plugin_meta", None) @@ -235,7 +235,7 @@ def _register_single( if _HAS_PLUGIN_FRAMEWORK: class _FunctionHookAdapter(Plugin): - """Adapts a standalone ``@hook``-decorated function into a ContextForge Plugin.""" + """Adapts a standalone `@hook`-decorated function into a ContextForge Plugin.""" def __init__( self, @@ -284,15 +284,15 @@ async def _invoke(self, payload: Any, context: Any) -> Any: return result class _MethodHookAdapter(Plugin): - """Adapts a single ``@hook``-decorated bound method from a ``Plugin`` class. + """Adapts a single `@hook`-decorated bound method from a `Plugin` class. - Each ``@hook`` method on a ``@plugin``-decorated class gets its own adapter - so that per-method execution modes (``SEQUENTIAL``, ``FIRE_AND_FORGET``, etc.) - are respected. The adapter name is ``"."``. + Each `@hook` method on a `@plugin`-decorated class gets its own adapter + so that per-method execution modes (`SEQUENTIAL`, `FIRE_AND_FORGET`, etc.) + are respected. The adapter name is `"."`. - Note: ``initialize()`` and ``shutdown()`` delegate to the underlying class + Note: `initialize()` and `shutdown()` delegate to the underlying class instance and may be called once per registered hook method. Make them - idempotent when using the ``Plugin`` base class with multiple hook methods. + idempotent when using the `Plugin` base class with multiple hook methods. 
""" def __init__( @@ -346,7 +346,7 @@ async def _invoke(self, payload: Any, context: Any) -> Any: else: # Provide a stub when the plugin framework is not installed. class _FunctionHookAdapter: # type: ignore[no-redef] - """Stub — install ``"mellea[hooks]"`` for full plugin support.""" + """Stub — install `"mellea[hooks]"` for full plugin support.""" def __init__(self, *args: Any, **kwargs: Any) -> None: raise ImportError( @@ -356,7 +356,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # Provide a stub when the plugin framework is not installed. class _MethodHookAdapter: # type: ignore[no-redef] - """Stub — install ``"mellea[hooks]"`` for full plugin support.""" + """Stub — install `"mellea[hooks]"` for full plugin support.""" def __init__(self, *args: Any, **kwargs: Any) -> None: raise ImportError( @@ -368,7 +368,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: class _PluginScope: """Context manager returned by :func:`plugin_scope`. - Supports both synchronous and asynchronous ``with`` statements. + Supports both synchronous and asynchronous `with` statements. """ def __init__(self, items: list[Callable | Any | PluginSet]) -> None: @@ -453,9 +453,9 @@ def unregister( ) -> None: """Unregister globally-registered plugins. - Accepts the same items as :func:`register`: standalone ``@hook``-decorated - functions, ``Plugin`` subclass instances, ``MelleaPlugin`` instances, - ``PluginSet`` instances, or lists thereof. + Accepts the same items as :func:`register`: standalone `@hook`-decorated + functions, `Plugin` subclass instances, `MelleaPlugin` instances, + `PluginSet` instances, or lists thereof. Silently ignores items that are not currently registered. """ @@ -485,11 +485,11 @@ def unregister( def plugin_scope(*items: Callable | Any | PluginSet) -> _PluginScope: """Return a context manager that temporarily registers plugins for a block of code. 
- Accepts the same items as :func:`register`: standalone ``@hook``-decorated - functions, ``@plugin``-decorated class instances, ``MelleaPlugin`` instances, + Accepts the same items as :func:`register`: standalone `@hook`-decorated + functions, `@plugin`-decorated class instances, `MelleaPlugin` instances, and :class:`~mellea.plugins.PluginSet` instances — or any mix thereof. - Supports both synchronous and asynchronous ``with`` statements:: + Supports both synchronous and asynchronous `with` statements:: # Sync functional API with plugin_scope(log_hook, audit_plugin): diff --git a/mellea/stdlib/__init__.py b/mellea/stdlib/__init__.py index a402d5733..b1d0882a3 100644 --- a/mellea/stdlib/__init__.py +++ b/mellea/stdlib/__init__.py @@ -1,11 +1,11 @@ """The mellea standard library of components, sessions, and sampling strategies. This package provides the high-level building blocks for writing generative programs -with mellea. It contains ready-to-use ``Component`` types (``Instruction``, -``Message``, ``Document``, ``Intrinsic``, ``SimpleComponent``, and more), context -implementations (``ChatContext``, ``SimpleContext``), sampling strategies (rejection -sampling, budget forcing), session management via ``MelleaSession``, and the -``@mify`` decorator for turning ordinary Python objects into components. Import from -the sub-packages — ``mellea.stdlib.components``, ``mellea.stdlib.sampling``, and -``mellea.stdlib.session`` — for day-to-day use. +with mellea. It contains ready-to-use `Component` types (`Instruction`, +`Message`, `Document`, `Intrinsic`, `SimpleComponent`, and more), context +implementations (`ChatContext`, `SimpleContext`), sampling strategies (rejection +sampling, budget forcing), session management via `MelleaSession`, and the +`@mify` decorator for turning ordinary Python objects into components. Import from +the sub-packages — `mellea.stdlib.components`, `mellea.stdlib.sampling`, and +`mellea.stdlib.session` — for day-to-day use. 
""" diff --git a/mellea/stdlib/components/chat.py b/mellea/stdlib/components/chat.py index bea42015d..3a19ad723 100644 --- a/mellea/stdlib/components/chat.py +++ b/mellea/stdlib/components/chat.py @@ -1,11 +1,11 @@ -"""Chat primitives: the ``Message`` and ``ToolMessage`` components. - -Defines ``Message``, the ``Component`` subtype used to represent a single turn in a -chat history with a ``role`` (``user``, ``assistant``, ``system``, or ``tool``), -text ``content``, and optional ``images`` and ``documents`` attachments. Also provides -``ToolMessage`` (a ``Message`` subclass that carries the tool name and arguments) and -the ``as_chat_history`` utility for converting a ``Context`` into a flat list of -``Message`` objects. +"""Chat primitives: the `Message` and `ToolMessage` components. + +Defines `Message`, the `Component` subtype used to represent a single turn in a +chat history with a `role` (`user`, `assistant`, `system`, or `tool`), +text `content`, and optional `images` and `documents` attachments. Also provides +`ToolMessage` (a `Message` subclass that carries the tool name and arguments) and +the `as_chat_history` utility for converting a `Context` into a flat list of +`Message` objects. """ from collections.abc import Mapping @@ -30,8 +30,8 @@ class Message(Component["Message"]): The fact that some Component gets rendered as a chat message is `Formatter` miscellania. Args: - role (str): The role that this message came from (e.g., ``"user"``, - ``"assistant"``). + role (str): The role that this message came from (e.g., `"user"`, + `"assistant"`). content (str): The content of the message. images (list[ImageBlock] | None): Optional images associated with the message. @@ -39,8 +39,8 @@ class Message(Component["Message"]): the message. Attributes: - Role (type): Type alias for the allowed role literals: ``"system"``, - ``"user"``, ``"assistant"``, or ``"tool"``. 
+ Role (type): Type alias for the allowed role literals: `"system"`, + `"user"`, `"assistant"`, or `"tool"`. """ Role = Literal["system", "user", "assistant", "tool"] @@ -185,17 +185,17 @@ class ToolMessage(Message): """Adds the name field for function name. Args: - role (str): The role of this message; most backends use ``"tool"``. + role (str): The role of this message; most backends use `"tool"`. content (str): The content of the message; should be a stringified - version of ``tool_output``. + version of `tool_output`. tool_output (Any): The output of the tool or function call. name (str): The name of the tool or function that was called. args (Mapping[str, Any]): The arguments passed to the tool. - tool (ModelToolCall): The ``ModelToolCall`` representation. + tool (ModelToolCall): The `ModelToolCall` representation. Attributes: arguments (Mapping[str, Any]): The arguments that were passed to the - tool; stored from the ``args`` constructor parameter. + tool; stored from the `args` constructor parameter. """ def __init__( @@ -215,7 +215,7 @@ def __init__( self._tool = tool def format_for_llm(self) -> TemplateRepresentation: - """Return the same representation as ``Message`` with a ``name`` field added to the args dict. + """Return the same representation as `Message` with a `name` field added to the args dict. Returns: TemplateRepresentation: Template representation including the tool @@ -238,17 +238,17 @@ def as_chat_history(ctx: Context) -> list[Message]: """Returns a list of Messages corresponding to a Context. Args: - ctx: A linear ``Context`` whose entries are ``Message`` or ``ModelOutputThunk`` - objects with ``Message`` parsed representations. + ctx: A linear `Context` whose entries are `Message` or `ModelOutputThunk` + objects with `Message` parsed representations. Returns: - List of ``Message`` objects in conversation order. + List of `Message` objects in conversation order. 
Raises: Exception: If the context history is non-linear and cannot be cast to a flat list. AssertionError: If any entry in the context cannot be converted to a - ``Message``. + `Message`. """ def _to_msg(c: CBlock | Component | ModelOutputThunk) -> Message | None: diff --git a/mellea/stdlib/components/docs/document.py b/mellea/stdlib/components/docs/document.py index 6b151d773..42576a071 100644 --- a/mellea/stdlib/components/docs/document.py +++ b/mellea/stdlib/components/docs/document.py @@ -1,8 +1,8 @@ -"""``Document`` component for grounding model inputs with text passages. +"""`Document` component for grounding model inputs with text passages. -``Document`` wraps a text passage with an optional ``title`` and ``doc_id``, and +`Document` wraps a text passage with an optional `title` and `doc_id`, and renders them inline as a formatted citation string for the model. Documents are -typically attached to a ``Message`` via its ``documents`` parameter, enabling +typically attached to a `Message` via its `documents` parameter, enabling retrieval-augmented generation (RAG) workflows. """ @@ -13,7 +13,7 @@ class Document(Component[str]): """A text passage with optional metadata for grounding model inputs. - Documents are typically attached to a ``Message`` via its ``documents`` + Documents are typically attached to a `Message` via its `documents` parameter to enable retrieval-augmented generation (RAG) workflows. Args: @@ -34,7 +34,7 @@ def parts(self) -> list[Component | CBlock]: Returns: list[Component | CBlock]: An empty list by default since the base - ``Document`` class has no constituent parts. Subclasses may override + `Document` class has no constituent parts. Subclasses may override this method to return meaningful parts. 
""" return [] diff --git a/mellea/stdlib/components/docs/richdocument.py b/mellea/stdlib/components/docs/richdocument.py index 7812112ed..69fd3068f 100644 --- a/mellea/stdlib/components/docs/richdocument.py +++ b/mellea/stdlib/components/docs/richdocument.py @@ -1,10 +1,10 @@ -"""``RichDocument``, ``Table``, and related helpers backed by Docling. +"""`RichDocument`, `Table`, and related helpers backed by Docling. -``RichDocument`` wraps a ``DoclingDocument`` (e.g. produced by converting a PDF or -Markdown file) and renders it as Markdown for a language model. ``Table`` represents a -single table within a Docling document and provides ``transpose``, ``to_markdown``, and -query/transform helpers. Use ``RichDocument.from_document_file`` to convert a PDF or -other supported format, and ``get_tables()`` to extract structured table data for +`RichDocument` wraps a `DoclingDocument` (e.g. produced by converting a PDF or +Markdown file) and renders it as Markdown for a language model. `Table` represents a +single table within a Docling document and provides `transpose`, `to_markdown`, and +query/transform helpers. Use `RichDocument.from_document_file` to convert a PDF or +other supported format, and `get_tables()` to extract structured table data for downstream LLM-driven Q&A or transformation tasks. """ @@ -25,11 +25,11 @@ class RichDocument(Component[str]): - """A ``RichDocument`` is a block of content backed by a ``DoclingDocument``. + """A `RichDocument` is a block of content backed by a `DoclingDocument`. Provides helper functions for working with the document and extracting parts - such as tables. Use ``from_document_file`` to convert PDFs or other formats, - and ``save``/``load`` for persistence. + such as tables. Use `from_document_file` to convert PDFs or other formats, + and `save`/`load` for persistence. Args: doc (DoclingDocument): The underlying Docling document to wrap. 
@@ -67,7 +67,7 @@ def _parse(self, computed: ModelOutputThunk) -> str: return computed.value if computed.value is not None else "" def docling(self) -> DoclingDocument: - """Return the underlying ``DoclingDocument``. + """Return the underlying `DoclingDocument`. Returns: DoclingDocument: The wrapped Docling document instance. @@ -82,12 +82,12 @@ def get_tables(self) -> list[Table]: """Return all tables found in this document. Returns: - list[Table]: A list of ``Table`` objects extracted from the document. + list[Table]: A list of `Table` objects extracted from the document. """ return [Table(x, self.docling()) for x in self.docling().tables] def save(self, filename: str | Path) -> None: - """Save the underlying ``DoclingDocument`` to a JSON file for later reuse. + """Save the underlying `DoclingDocument` to a JSON file for later reuse. Args: filename (str | Path): Destination file path for the serialized @@ -99,14 +99,14 @@ def save(self, filename: str | Path) -> None: @classmethod def load(cls, filename: str | Path) -> RichDocument: - """Load a ``RichDocument`` from a previously saved ``DoclingDocument`` JSON file. + """Load a `RichDocument` from a previously saved `DoclingDocument` JSON file. Args: filename (str | Path): Path to a JSON file previously created by - ``RichDocument.save``. + `RichDocument.save`. Returns: - RichDocument: A new ``RichDocument`` wrapping the loaded document. + RichDocument: A new `RichDocument` wrapping the loaded document. """ if type(filename) is str: filename = Path(filename) @@ -115,14 +115,14 @@ def load(cls, filename: str | Path) -> RichDocument: @classmethod def from_document_file(cls, source: str | Path | DocumentStream) -> RichDocument: - """Convert a document file to a ``RichDocument`` using Docling. + """Convert a document file to a `RichDocument` using Docling. Args: source (str | Path | DocumentStream): Path or stream for the source document (e.g. a PDF or Markdown file). 
Returns: - RichDocument: A new ``RichDocument`` wrapping the converted document. + RichDocument: A new `RichDocument` wrapping the converted document. """ pipeline_options = PdfPipelineOptions( images_scale=2.0, generate_picture_images=True @@ -139,7 +139,7 @@ def from_document_file(cls, source: str | Path | DocumentStream) -> RichDocument class TableQuery(Query): - """A ``Query`` component specialised for ``Table`` objects. + """A `Query` component specialised for `Table` objects. Formats the table as Markdown alongside the query string so the LLM receives both the structured table content and the natural-language question. @@ -157,7 +157,7 @@ def parts(self) -> list[Component | CBlock]: """Return the constituent parts of this table query. Returns: - list[Component | CBlock]: A list containing the wrapped ``Table`` + list[Component | CBlock]: A list containing the wrapped `Table` object. """ cs: list[Component | CBlock] = [self._obj] @@ -186,7 +186,7 @@ def format_for_llm(self) -> TemplateRepresentation: class TableTransform(Transform): - """A ``Transform`` component specialised for ``Table`` objects. + """A `Transform` component specialised for `Table` objects. Formats the table as Markdown alongside the transformation instruction so the LLM receives both the structured table content and the mutation description. @@ -204,7 +204,7 @@ def parts(self) -> list[Component | CBlock]: """Return the constituent parts of this table transform. Returns: - list[Component | CBlock]: A list containing the wrapped ``Table`` + list[Component | CBlock]: A list containing the wrapped `Table` object. """ cs: list[Component | CBlock] = [self._obj] @@ -236,11 +236,11 @@ def format_for_llm(self) -> TemplateRepresentation: class Table(MObject): - """A ``Table`` represents a single table within a larger Docling Document. + """A `Table` represents a single table within a larger Docling Document. Args: - ti (TableItem): The Docling ``TableItem`` extracted from the document. 
- doc (DoclingDocument): The parent ``DoclingDocument``. Passing ``None`` + ti (TableItem): The Docling `TableItem` extracted from the document. + doc (DoclingDocument): The parent `DoclingDocument`. Passing `None` may cause downstream Docling functions to fail. """ @@ -252,7 +252,7 @@ def __init__(self, ti: TableItem, doc: DoclingDocument): @classmethod def from_markdown(cls, md: str) -> Table | None: - """Create a ``Table`` from a Markdown string by round-tripping through Docling. + """Create a `Table` from a Markdown string by round-tripping through Docling. Wraps the Markdown in a minimal document, converts it with Docling, and returns the first table found. @@ -261,8 +261,8 @@ def from_markdown(cls, md: str) -> Table | None: md (str): A Markdown string containing at least one table. Returns: - Table | None: The first ``Table`` extracted from the Markdown, or - ``None`` if no table could be found. + Table | None: The first `Table` extracted from the Markdown, or + `None` if no table could be found. """ fake_doc = f"# X\n\n{md}\n" bs = io.BytesIO(fake_doc.encode("utf-8")) @@ -276,7 +276,7 @@ def parts(self): """Return the constituent parts of this table component. The current implementation always returns an empty list because the - table is rendered entirely through ``format_for_llm``. + table is rendered entirely through `format_for_llm`. Returns: list[Component | CBlock]: Always an empty list. @@ -292,11 +292,11 @@ def to_markdown(self) -> str: return self._ti.export_to_markdown(self._doc) def transpose(self) -> Table | None: - """Transpose this table and return the result as a new ``Table``. + """Transpose this table and return the result as a new `Table`. Returns: - Table | None: A new transposed ``Table``, or ``None`` if the - transposed Markdown cannot be parsed back into a ``Table``. + Table | None: A new transposed `Table`, or `None` if the + transposed Markdown cannot be parsed back into a `Table`. 
""" t = self._ti.export_to_dataframe().transpose() return Table.from_markdown(t.to_markdown()) @@ -305,8 +305,8 @@ def format_for_llm(self) -> TemplateRepresentation | str: """Return the table representation for the Formatter. Returns: - TemplateRepresentation | str: A ``TemplateRepresentation`` that - renders the table as its Markdown string using a ``{{table}}`` + TemplateRepresentation | str: A `TemplateRepresentation` that + renders the table as its Markdown string using a `{{table}}` template. """ return TemplateRepresentation( diff --git a/mellea/stdlib/components/genslot.py b/mellea/stdlib/components/genslot.py index 4ac82486d..47fac007b 100644 --- a/mellea/stdlib/components/genslot.py +++ b/mellea/stdlib/components/genslot.py @@ -68,9 +68,9 @@ class FunctionDict(TypedDict): """Return Type for a Function Component. Attributes: - name (str): The function's ``__name__``. + name (str): The function's `__name__`. signature (str): The function's parameter signature as a string. - docstring (str | None): The function's docstring, or ``None`` if absent. + docstring (str | None): The function's docstring, or `None` if absent. """ name: str @@ -116,9 +116,9 @@ def __init__( class Arguments(CBlock): - """A ``CBlock`` that renders a list of ``Argument`` objects as human-readable text. + """A `CBlock` that renders a list of `Argument` objects as human-readable text. - Each argument is formatted as ``"- name: value (type: annotation)"`` and the + Each argument is formatted as `"- name: value (type: annotation)"` and the items are newline-joined into a single string suitable for inclusion in a prompt. Args: @@ -185,10 +185,10 @@ def __init__( class Function(Generic[P, R]): - """Wraps a callable with its introspected ``FunctionDict`` metadata. + """Wraps a callable with its introspected `FunctionDict` metadata. 
Stores the original callable alongside its name, signature, and docstring - as produced by ``describe_function``, so generative slots can render them + as produced by `describe_function`, so generative slots can render them into prompts without re-inspecting the function each time. Args: @@ -318,11 +318,11 @@ def __init__(self): class GenerativeSlot(Component[R], Generic[P, R]): - """Abstract base class for AI-powered function wrappers produced by ``@generative``. + """Abstract base class for AI-powered function wrappers produced by `@generative`. - A ``GenerativeSlot`` wraps a callable and uses an LLM to generate its output. - Subclasses (``SyncGenerativeSlot``, ``AsyncGenerativeSlot``) implement - ``__call__`` for synchronous and asynchronous invocation respectively. + A `GenerativeSlot` wraps a callable and uses an LLM to generate its output. + Subclasses (`SyncGenerativeSlot`, `AsyncGenerativeSlot`) implement + `__call__` for synchronous and asynchronous invocation respectively. The function's signature, docstring, and type hints are rendered into a prompt so the LLM can imitate the function's intended behaviour. @@ -374,9 +374,9 @@ def extract_args_and_kwargs(*args, **kwargs) -> ExtractedArgs: Args: args: Positional arguments; the first must be either a - ``MelleaSession`` or a ``Context`` instance. + `MelleaSession` or a `Context` instance. kwargs: Keyword arguments for both the generative slot machinery - (e.g. ``m``, ``context``, ``backend``, ``requirements``) and the + (e.g. `m`, `context`, `backend`, `requirements`) and the wrapped function's own parameters. Returns: @@ -485,13 +485,13 @@ def parts(self) -> list[Component | CBlock]: def format_for_llm(self) -> TemplateRepresentation: """Format this generative slot for the language model. 
- Builds a ``TemplateRepresentation`` containing the function metadata + Builds a `TemplateRepresentation` containing the function metadata (name, signature, docstring), the bound arguments, and any requirement descriptions. Returns: TemplateRepresentation: The formatted representation ready for the - ``Formatter`` to render into a prompt. + `Formatter` to render into a prompt. """ return TemplateRepresentation( obj=self, @@ -524,9 +524,9 @@ def _parse(self, computed: ModelOutputThunk) -> R: class SyncGenerativeSlot(GenerativeSlot, Generic[P, R]): """A synchronous generative slot that blocks until the LLM response is ready. - Returned by ``@generative`` when the decorated function is not a coroutine. - ``__call__`` returns the parsed result directly (when a session is passed) or a - ``(result, context)`` tuple (when a context and backend are passed). + Returned by `@generative` when the decorated function is not a coroutine. + `__call__` returns the parsed result directly (when a session is passed) or a + `(result, context)` tuple (when a context and backend are passed). """ @overload @@ -853,7 +853,7 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: PreconditionException: (raised when calling the generative slot) if the precondition validation of the args fails; catch the exception to get the validation results Examples: - ```python + ```python >>> from mellea import generative, start_session >>> session = start_session() >>> @generative
>>> >>> reasoning = generate_chain_of_thought(session, problem="How to optimize a slow database query?") - ``` + ``` """ if inspect.iscoroutinefunction(func): return AsyncGenerativeSlot(func) diff --git a/mellea/stdlib/components/instruction.py b/mellea/stdlib/components/instruction.py index 30faaea20..d7cd63f4d 100644 --- a/mellea/stdlib/components/instruction.py +++ b/mellea/stdlib/components/instruction.py @@ -1,7 +1,7 @@ -"""``Instruction`` component for instruct/validate/repair loops. +"""`Instruction` component for instruct/validate/repair loops. -``Instruction`` is the primary component type used with ``MelleaSession.instruct``. It -packages a task ``description``, a list of ``Requirement`` constraints, optional +`Instruction` is the primary component type used with `MelleaSession.instruct`. It +packages a task `description`, a list of `Requirement` constraints, optional in-context-learning examples, a grounding context dict, user variables for Jinja2 template interpolation, and output/input prefix overrides into a single renderable unit. The session's sampling strategy evaluates each requirement against the model's @@ -39,7 +39,7 @@ class Instruction(Component[str]): to all string parameters. prefix (str | CBlock | None): A prefix prepended before the model's generation. output_prefix (str | CBlock | None): A prefix prepended to the model's output token - stream (currently unsupported; must be ``None``). + stream (currently unsupported; must be `None`). images (list[ImageBlock] | None): Images to include in the prompt. Attributes: @@ -217,8 +217,8 @@ def copy_and_repair(self, repair_string: str) -> Instruction: describing which requirements failed and why. Returns: - Instruction: A new ``Instruction`` identical to this one but with - ``_repair_string`` set to ``repair_string``.
""" res = deepcopy(self) res._repair_string = repair_string diff --git a/mellea/stdlib/components/intrinsic/intrinsic.py b/mellea/stdlib/components/intrinsic/intrinsic.py index 497912a91..8ce15ad1b 100644 --- a/mellea/stdlib/components/intrinsic/intrinsic.py +++ b/mellea/stdlib/components/intrinsic/intrinsic.py @@ -1,10 +1,10 @@ -"""``Intrinsic`` component for invoking fine-tuned adapter capabilities. +"""`Intrinsic` component for invoking fine-tuned adapter capabilities. -An ``Intrinsic`` component references a named adapter from Mellea's intrinsic catalog +An `Intrinsic` component references a named adapter from Mellea's intrinsic catalog and transforms a chat completion request — typically by injecting new messages, modifying model parameters, or applying structured output constraints. It must be -paired with a backend that supports adapter loading (e.g. ``LocalHFBackend`` with an -attached ``IntrinsicAdapter``). +paired with a backend that supports adapter loading (e.g. `LocalHFBackend` with an +attached `IntrinsicAdapter`). """ from ....backends.adapters import AdapterType, fetch_intrinsic_metadata @@ -63,9 +63,9 @@ def parts(self) -> list[Component | CBlock]: return [] # TODO revisit this. def format_for_llm(self) -> TemplateRepresentation | str: - """Not implemented for the base ``Intrinsic`` class. + """Not implemented for the base `Intrinsic` class. - ``Intrinsic`` components are intended to be used as the *action* passed + `Intrinsic` components are intended to be used as the *action* passed directly to the backend, not as a part of the context rendered by the formatter. @@ -73,8 +73,8 @@ def format_for_llm(self) -> TemplateRepresentation | str: TemplateRepresentation | str: Never returns; always raises. Raises: - NotImplementedError: Always, because ``Intrinsic`` does not - implement ``format_for_llm`` by default. + NotImplementedError: Always, because `Intrinsic` does not + implement `format_for_llm` by default. 
""" raise NotImplementedError( "`Intrinsic` doesn't implement format_for_llm by default. You should only " diff --git a/mellea/stdlib/components/intrinsic/rag.py b/mellea/stdlib/components/intrinsic/rag.py index cf14d24b1..135dcb4dc 100644 --- a/mellea/stdlib/components/intrinsic/rag.py +++ b/mellea/stdlib/components/intrinsic/rag.py @@ -91,7 +91,7 @@ def check_answerability( Args: question: Question that the user has posed in response to the last turn in - ``context``. + `context`. documents: Document snippets retrieved that may or may not answer the indicated question. context: Chat context containing the conversation thus far. @@ -119,12 +119,12 @@ def rewrite_question( Args: question: Question that the user has posed in response to the last turn in - ``context``. + `context`. context: Chat context containing the conversation thus far. backend: Backend instance that supports adding the LoRA or aLoRA adapters. Returns: - Rewritten version of ``question``. + Rewritten version of `question`. """ result_json = _call_intrinsic( "query_rewrite", context.add(Message("user", question)), backend @@ -176,8 +176,8 @@ def find_citations( Args: response: Potential assistant response. - documents: Documents that were used to generate ``response``. These documents - should set the ``doc_id`` field; otherwise the intrinsic will be unable to + documents: Documents that were used to generate `response`. These documents + should set the `doc_id` field; otherwise the intrinsic will be unable to specify which document was the source of a given citation. context: Context of the dialog between user and assistant at the point where the user has just asked a question that will be answered with RAG documents. @@ -185,9 +185,9 @@ def find_citations( intrinsic. Returns: - List of records with the following fields: ``response_begin``, - ``response_end``, ``response_text``, ``citation_doc_id``, ``citation_begin``, - ``citation_end``, ``citation_text``. 
Begin and end offsets are character + List of records with the following fields: `response_begin`, + `response_end`, `response_text`, `citation_doc_id`, `citation_begin`, + `citation_end`, `citation_text`. Begin and end offsets are character offsets into their respective UTF-8 strings. """ result_json = _call_intrinsic( @@ -241,16 +241,16 @@ def flag_hallucinated_content( Args: response: The assistant's response to the user's question in the last turn - of ``context``. - documents: Document snippets that were used to generate ``response``. + of `context`. + documents: Document snippets that were used to generate `response`. context: A chat log that ends with a user asking a question. backend: Backend instance that supports the adapters that implement this intrinsic. Returns: - List of records with the following fields: ``response_begin``, - ``response_end``, ``response_text``, ``faithfulness_likelihood``, - ``explanation``. + List of records with the following fields: `response_begin`, + `response_end`, `response_text`, `faithfulness_likelihood`, + `explanation`. """ result_json = _call_intrinsic( "hallucination_detection", @@ -272,8 +272,8 @@ def rewrite_answer_for_relevance( Args: response: The assistant's response to the user's question in the last turn - of ``context``. - documents: Document snippets that were used to generate ``response``. + of `context`. + documents: Document snippets that were used to generate `response`. context: A chat log that ends with a user asking a question. backend: Backend instance that supports the adapters that implement this intrinsic. diff --git a/mellea/stdlib/components/mify.py b/mellea/stdlib/components/mify.py index 26aa6da3a..5c25bf36a 100644 --- a/mellea/stdlib/components/mify.py +++ b/mellea/stdlib/components/mify.py @@ -1,10 +1,10 @@ -"""The ``@mify`` decorator for turning Python objects into ``Component``s. +"""The `@mify` decorator for turning Python objects into `Component`s. 
-``mify`` wraps an existing Python class or instance with the ``MifiedProtocol`` +`mify` wraps an existing Python class or instance with the `MifiedProtocol` interface, exposing its fields as named spans and its documented methods as -``MelleaTool`` instances callable by the LLM. The resulting ``MifiedProtocol`` object +`MelleaTool` instances callable by the LLM. The resulting `MifiedProtocol` object can be queried, transformed, and formatted for a language model without any manual -``Component`` subclassing. Use ``mify`` when you have an existing domain object +`Component` subclassing. Use `mify` when you have an existing domain object (dataclass, Pydantic model, or plain class) that you want to expose directly to an LLM-driven pipeline. """ @@ -66,7 +66,7 @@ def get_query_object(self, query: str) -> Query: query (str): The natural-language query string. Returns: - Query: A ``Query`` component wrapping this object and the given query. + Query: A `Query` component wrapping this object and the given query. """ return self._query_type(self, query) @@ -79,7 +79,7 @@ def get_transform_object(self, transformation: str) -> Transform: transformation (str): The natural-language transformation description. Returns: - Transform: A ``Transform`` component wrapping this object and the + Transform: A `Transform` component wrapping this object and the given transformation description. """ return self._transform_type(self, transformation) @@ -89,8 +89,8 @@ def content_as_string(self) -> str: [no-index] - Delegates to the ``stringify_func`` passed to ``mify`` when one was - provided; otherwise falls back to ``str(self)``. + Delegates to the `stringify_func` passed to `mify` when one was + provided; otherwise falls back to `str(self)`. Returns: str: String representation of the mified object's content. 
@@ -193,14 +193,14 @@ def _get_all_fields(self) -> dict[str, Any]: return narrowed def format_for_llm(self) -> TemplateRepresentation: - """Return the ``TemplateRepresentation`` for this mified object. + """Return the `TemplateRepresentation` for this mified object. [no-index] - Sets the ``TemplateRepresentation`` fields based on the object and the - configuration values supplied to ``mify`` (fields, templates, tools, etc.). + Sets the `TemplateRepresentation` fields based on the object and the + configuration values supplied to `mify` (fields, templates, tools, etc.). - See the ``mify`` decorator for more details. + See the `mify` decorator for more details. Returns: TemplateRepresentation: The formatted representation including args, @@ -243,17 +243,17 @@ def parse(self, computed: ModelOutputThunk) -> str: [no-index] - Delegates to ``_parse`` and wraps any exception in a - ``ComponentParseError`` to give callers a consistent error type. + Delegates to `_parse` and wraps any exception in a + `ComponentParseError` to give callers a consistent error type. Args: computed (ModelOutputThunk): The raw model output to parse. Returns: - str: The string value extracted from ``computed``. + str: The string value extracted from `computed`. Raises: - ComponentParseError: If ``_parse`` raises any exception during + ComponentParseError: If `_parse` raises any exception during parsing. """ try: @@ -316,46 +316,46 @@ def mify(*args, **kwargs): # noqa: D417 Args: obj: A class or an instance of a class to mify. Omit when using as a - decorator with arguments (e.g. ``@mify(fields_include={...})``). + decorator with arguments (e.g. `@mify(fields_include={...})`). query_type: A specific query component type to use when querying a model. - Defaults to ``Query``. + Defaults to `Query`. transform_type: A specific transform component type to use when - transforming with a model. Defaults to ``Transform``. + transforming with a model. Defaults to `Transform`. 
fields_include: Fields of the object to include in its representation to - models. When set, ``stringify_func`` is not used. + models. When set, `stringify_func` is not used. fields_exclude: Fields of the object to exclude from its representation to models. funcs_include: Functions of the object to expose as tools to models. funcs_exclude: Functions of the object to hide from models. template: A Jinja2 template string. Takes precedence over - ``template_order`` when provided. + `template_order` when provided. template_order: A template name or list of names used when searching for applicable templates. parsing_func: Not yet implemented. stringify_func: A callable used to create a string representation of the - object for ``content_as_string``. + object for `content_as_string`. Returns: An object if an object was passed in or a decorator (callable) to mify classes. If an object is returned, that object will be the same object that was passed in. For example, - ``` + ``` obj = mify(obj) obj.format_for_llm() - ``` + ``` and - ``` + ``` mify(obj) obj.format_for_llm() - ``` + ``` are equivalent. Most IDEs will not correctly show the type hints for the newly added functions for either an mify object or instances of an mified class. For IDE support, write - ``` + ``` assert isinstance(obj, MifiedProtocol) - ``` + ``` """ # Grab and remove obj if it exists in kwargs. Otherwise, it's the only arg. obj = kwargs.pop("obj", None) diff --git a/mellea/stdlib/components/mobject.py b/mellea/stdlib/components/mobject.py index 38719b213..bea57b2a0 100644 --- a/mellea/stdlib/components/mobject.py +++ b/mellea/stdlib/components/mobject.py @@ -1,10 +1,10 @@ -"""``MObject``, ``Query``, ``Transform``, and ``MObjectProtocol`` for query/transform workflows. +"""`MObject`, `Query`, `Transform`, and `MObjectProtocol` for query/transform workflows. 
-Defines the ``MObjectProtocol`` protocol for objects that can be queried and -transformed by an LLM, and the concrete ``MObject`` base class that implements it. -Also provides the ``Query`` and ``Transform`` ``Component`` subtypes, which wrap an +Defines the `MObjectProtocol` protocol for objects that can be queried and +transformed by an LLM, and the concrete `MObject` base class that implements it. +Also provides the `Query` and `Transform` `Component` subtypes, which wrap an object with a natural-language question or mutation instruction respectively. These -primitives underpin ``@mify`` and can be composed directly to build document Q&A +primitives underpin `@mify` and can be composed directly to build document Q&A or structured extraction pipelines. """ @@ -19,9 +19,9 @@ class Query(Component[str]): - """A ``Component`` that pairs an ``MObject`` with a natural-language question. + """A `Component` that pairs an `MObject` with a natural-language question. - Wraps the object and its query string into a ``TemplateRepresentation`` so the + Wraps the object and its query string into a `TemplateRepresentation` so the formatter can render both together in a prompt, optionally forwarding the object's tools and fields to the template. @@ -47,7 +47,7 @@ def format_for_llm(self) -> TemplateRepresentation | str: """Format this query for the language model. Returns: - TemplateRepresentation | str: A ``TemplateRepresentation`` containing + TemplateRepresentation | str: A `TemplateRepresentation` containing the query string, the wrapped object, and any tools or fields from the object's own representation. """ @@ -77,10 +77,10 @@ def _parse(self, computed: ModelOutputThunk) -> str: class Transform(Component[str]): - """A ``Component`` that pairs an ``MObject`` with a natural-language mutation instruction. + """A `Component` that pairs an `MObject` with a natural-language mutation instruction. 
Wraps the object and its transformation description into a - ``TemplateRepresentation`` so the formatter can render both together in a prompt, + `TemplateRepresentation` so the formatter can render both together in a prompt, optionally forwarding the object's tools and fields to the template. Args: @@ -105,7 +105,7 @@ def format_for_llm(self) -> TemplateRepresentation | str: """Format this transform for the language model. Returns: - TemplateRepresentation | str: A ``TemplateRepresentation`` containing + TemplateRepresentation | str: A `TemplateRepresentation` containing the transformation description, the wrapped object, and any tools or fields from the object's own representation. """ @@ -153,7 +153,7 @@ def get_query_object(self, query: str) -> Query: query (str): The query string. Returns: - Query: A ``Query`` component wrapping this object and the given + Query: A `Query` component wrapping this object and the given query string. """ ... @@ -165,7 +165,7 @@ def get_transform_object(self, transformation: str) -> Transform: transformation (str): The transformation description string. Returns: - Transform: A ``Transform`` component wrapping this object and the + Transform: A `Transform` component wrapping this object and the given transformation description. """ ... @@ -173,7 +173,7 @@ def get_transform_object(self, transformation: str) -> Transform: def content_as_string(self) -> str: """Return the content of this MObject as a plain string. - The default value is just ``str(self)``. + The default value is just `str(self)`. Subclasses should override this method. Returns: @@ -184,7 +184,7 @@ def content_as_string(self) -> str: def _get_all_members(self) -> dict[str, Callable]: """Return all methods from this MObject that are not inherited from the superclass. - Undocumented methods and methods with ``[no-index]`` in their docstring + Undocumented methods and methods with `[no-index]` in their docstring are ignored. """ ... 
@@ -192,8 +192,8 @@ def _get_all_members(self) -> dict[str, Callable]: def format_for_llm(self) -> TemplateRepresentation | str: """Return the template representation used by the formatter. - The default ``TemplateRepresentation`` uses automatic parsing for tools - and fields. Content is retrieved from ``content_as_string()``. + The default `TemplateRepresentation` uses automatic parsing for tools + and fields. Content is retrieved from `content_as_string()`. Returns: TemplateRepresentation | str: The formatted representation for the @@ -207,13 +207,13 @@ def _parse(self, computed: ModelOutputThunk) -> str: class MObject(Component[str]): - """An extension of ``Component`` for adding query and transform operations. + """An extension of `Component` for adding query and transform operations. Args: - query_type (type): The ``Query`` subclass to use when constructing query - components. Defaults to ``Query``. - transform_type (type): The ``Transform`` subclass to use when constructing - transform components. Defaults to ``Transform``. + query_type (type): The `Query` subclass to use when constructing query + components. Defaults to `Query`. + transform_type (type): The `Transform` subclass to use when constructing + transform components. Defaults to `Transform`. """ def __init__( @@ -238,7 +238,7 @@ def get_query_object(self, query: str) -> Query: query (str): The query string. Returns: - Query: A ``Query`` component wrapping this object and the given + Query: A `Query` component wrapping this object and the given query string. """ return self._query_type(self, query) @@ -250,7 +250,7 @@ def get_transform_object(self, transformation: str) -> Transform: transformation (str): The transformation description string. Returns: - Transform: A ``Transform`` component wrapping this object and the + Transform: A `Transform` component wrapping this object and the given transformation description. 
""" return self._transform_type(self, transformation) @@ -258,7 +258,7 @@ def get_transform_object(self, transformation: str) -> Transform: def content_as_string(self) -> str: """Return the content of this MObject as a plain string. - The default value is just ``str(self)``. + The default value is just `str(self)`. Subclasses should override this method. Returns: @@ -269,7 +269,7 @@ def content_as_string(self) -> str: def _get_all_members(self) -> dict[str, Callable]: """Return all methods from this MObject except methods of the superclass. - Undocumented methods and methods with ``[no-index]`` in their docstring + Undocumented methods and methods with `[no-index]` in their docstring are ignored. """ all_members: dict[str, Callable] = dict( @@ -293,8 +293,8 @@ def _get_all_members(self) -> dict[str, Callable]: def format_for_llm(self) -> TemplateRepresentation | str: """Return the template representation used by the formatter. - The default ``TemplateRepresentation`` uses automatic parsing for tools - and fields. Content is retrieved from ``content_as_string()``. + The default `TemplateRepresentation` uses automatic parsing for tools + and fields. Content is retrieved from `content_as_string()`. Returns: TemplateRepresentation | str: The formatted representation for the diff --git a/mellea/stdlib/components/react.py b/mellea/stdlib/components/react.py index bee9eaf4a..198bb3929 100644 --- a/mellea/stdlib/components/react.py +++ b/mellea/stdlib/components/react.py @@ -1,10 +1,10 @@ """Components that implement the ReACT (Reason + Act) agentic pattern. -Provides ``ReactInitiator``, which primes the model with a goal and a tool list, and -``ReactThought``, which signals a thinking step. Also exports the -``MELLEA_FINALIZER_TOOL`` sentinel string used to signal loop termination. 
These -components are consumed by ``mellea.stdlib.frameworks.react``, which orchestrates the -reasoning-acting cycle until the model invokes ``final_answer`` or the step budget +Provides `ReactInitiator`, which primes the model with a goal and a tool list, and +`ReactThought`, which signals a thinking step. Also exports the +`MELLEA_FINALIZER_TOOL` sentinel string used to signal loop termination. These +components are consumed by `mellea.stdlib.frameworks.react`, which orchestrates the +reasoning-acting cycle until the model invokes `final_answer` or the step budget is exhausted. """ @@ -38,7 +38,7 @@ class ReactInitiator(Component[str]): Args: goal (str): The objective of the react loop. tools (list[AbstractMelleaTool] | None): Tools available to the agent. - ``None`` is treated as an empty list. + `None` is treated as an empty list. Attributes: goal (CBlock): The objective of the react loop wrapped as a content block. @@ -99,7 +99,7 @@ def __init__(self): def parts(self) -> list[Component | CBlock]: """Return the constituent parts of this component. - ``ReactThought`` has no sub-components; it solely triggers a thinking step. + `ReactThought` has no sub-components; it solely triggers a thinking step. Returns: list[Component | CBlock]: Always an empty list. diff --git a/mellea/stdlib/components/simple.py b/mellea/stdlib/components/simple.py index 522034075..18bf39936 100644 --- a/mellea/stdlib/components/simple.py +++ b/mellea/stdlib/components/simple.py @@ -1,9 +1,9 @@ -"""``SimpleComponent``: a lightweight named-span component. +"""`SimpleComponent`: a lightweight named-span component. -``SimpleComponent`` accepts arbitrary keyword arguments (strings, ``CBlock``s, or -``Component``s) and renders them as a JSON object keyed by the argument names. It is +`SimpleComponent` accepts arbitrary keyword arguments (strings, `CBlock`s, or +`Component`s) and renders them as a JSON object keyed by the argument names. 
It is the go-to component type for ad-hoc prompts that do not require a dedicated -``Component`` subclass or a Jinja2 template. +`Component` subclass or a Jinja2 template. """ from typing import Any @@ -37,11 +37,11 @@ def _kwargs_type_check(self, kwargs: dict[str, Any]) -> bool: @staticmethod def make_simple_string(kwargs: dict[str, Any]) -> str: - """Render keyword arguments as ``<|key|>value`` tagged strings. + """Render keyword arguments as `<|key|>value` tagged strings. Args: - kwargs (dict[str, Any]): Mapping of span names to their ``CBlock`` or - ``Component`` values. + kwargs (dict[str, Any]): Mapping of span names to their `CBlock` or + `Component` values. Returns: str: Newline-joined tagged representation of all keyword arguments. @@ -54,13 +54,13 @@ def make_simple_string(kwargs: dict[str, Any]) -> str: def make_json_string(kwargs: dict[str, Any]) -> str: """Serialize keyword arguments to a JSON string. - Each value is converted to its string representation: ``CBlock`` and - ``ModelOutputThunk`` values use their ``.value`` attribute, while - ``Component`` values use ``format_for_llm()``. + Each value is converted to its string representation: `CBlock` and + `ModelOutputThunk` values use their `.value` attribute, while + `Component` values use `format_for_llm()`. Args: - kwargs (dict[str, Any]): Mapping of span names to ``CBlock``, ``Component``, - or ``ModelOutputThunk`` values. + kwargs (dict[str, Any]): Mapping of span names to `CBlock`, `Component`, + or `ModelOutputThunk` values. Returns: str: JSON-encoded representation of the keyword arguments. @@ -79,7 +79,7 @@ def make_json_string(kwargs: dict[str, Any]) -> str: def format_for_llm(self) -> str: """Format this component as a JSON string representation for the language model. - Delegates to ``make_json_string`` using the stored keyword arguments. + Delegates to `make_json_string` using the stored keyword arguments. Returns: str: JSON-encoded string of all named spans in this component. 
diff --git a/mellea/stdlib/components/unit_test_eval.py b/mellea/stdlib/components/unit_test_eval.py index d2e9a9cf7..668474884 100644 --- a/mellea/stdlib/components/unit_test_eval.py +++ b/mellea/stdlib/components/unit_test_eval.py @@ -13,8 +13,8 @@ class Message(BaseModel): """Schema for a message in the test data. Attributes: - role (str): The role of the message sender (e.g. ``"user"`` or - ``"assistant"``). + role (str): The role of the message sender (e.g. `"user"` or + `"assistant"`). content (str): The text content of the message. """ @@ -59,7 +59,7 @@ def validate_examples(cls, v: list[Example]) -> list[Example]: """Validate that the examples list is not empty. Args: - v (list[Example]): The value of the ``examples`` field being + v (list[Example]): The value of the `examples` field being validated. Returns: @@ -82,7 +82,7 @@ class TestBasedEval(Component[str]): instructions (str): Evaluation guidelines used by the judge model. inputs (list[str]): The input texts for each example. targets (list[list[str]] | None): Expected target strings for each - input. ``None`` is treated as an empty list. + input. `None` is treated as an empty list. test_id (str | None): Optional unique identifier for this test. input_ids (list[str] | None): Optional identifiers for each input. @@ -112,7 +112,7 @@ def parts(self) -> list[Component | CBlock]: Returns: list[Component | CBlock]: Always an empty list; the component - renders entirely via ``format_for_llm``. + renders entirely via `format_for_llm`. """ return [] @@ -122,7 +122,7 @@ def format_for_llm(self) -> TemplateRepresentation: Returns: TemplateRepresentation: A template representation containing the judge context (input, prediction, target, guidelines) set by - ``set_judge_context``, or an empty args dict if no context has + `set_judge_context`, or an empty args dict if no context has been set yet. 
""" return TemplateRepresentation( @@ -144,7 +144,7 @@ def set_judge_context( input_text (str): The original input text shown to the model. prediction (str): The model's generated output to evaluate. targets_for_input (list[str]): Reference target strings for this - input. An empty list results in ``"N/A"`` as the target text. + input. An empty list results in `"N/A"` as the target text. """ if len(targets_for_input) == 0: # no reference target_text = "N/A" @@ -164,19 +164,19 @@ def set_judge_context( @classmethod def from_json_file(cls, filepath: str) -> list["TestBasedEval"]: - """Load test evaluations from a JSON file, returning one ``TestBasedEval`` per unit test. + """Load test evaluations from a JSON file, returning one `TestBasedEval` per unit test. Args: filepath (str): Path to a JSON file containing one test-data object or a JSON array of test-data objects. Returns: - list[TestBasedEval]: A list of ``TestBasedEval`` instances, one for + list[TestBasedEval]: A list of `TestBasedEval` instances, one for each object found in the file. Raises: ValueError: If any test-data object in the file does not conform to - the ``TestData`` schema. + the `TestData` schema. """ path = Path(filepath) diff --git a/mellea/stdlib/context.py b/mellea/stdlib/context.py index f0e2e8b2f..b8a748fab 100644 --- a/mellea/stdlib/context.py +++ b/mellea/stdlib/context.py @@ -1,9 +1,9 @@ -"""Concrete ``Context`` implementations for common conversation patterns. +"""Concrete `Context` implementations for common conversation patterns. -Provides ``ChatContext``, which accumulates all turns in a sliding-window chat history -(configurable via ``window_size``), and ``SimpleContext``, in which each interaction +Provides `ChatContext`, which accumulates all turns in a sliding-window chat history +(configurable via `window_size`), and `SimpleContext`, in which each interaction is treated as a stateless single-turn exchange (no prior history is passed to the -model). 
Import ``ChatContext`` for multi-turn conversations and ``SimpleContext`` when +model). Import `ChatContext` for multi-turn conversations and `SimpleContext` when you want each call to the model to be independent. """ @@ -18,7 +18,7 @@ class ChatContext(Context): Args: window_size (int | None): Maximum number of context turns to include when - calling ``view_for_generation``. ``None`` (the default) means the full + calling `view_for_generation`. `None` (the default) means the full history is always returned. """ @@ -34,8 +34,8 @@ def add(self, c: Component | CBlock) -> ChatContext: c (Component | CBlock): The component or content block to append. Returns: - ChatContext: A new ``ChatContext`` with the added entry, preserving the - current ``window_size`` setting. + ChatContext: A new `ChatContext` with the added entry, preserving the + current `window_size` setting. """ new = ChatContext.from_previous(self, c) new._window_size = self._window_size @@ -44,13 +44,13 @@ def add(self, c: Component | CBlock) -> ChatContext: def view_for_generation(self) -> list[Component | CBlock] | None: """Return the context entries to pass to the model, respecting the configured window. - Uses the ``window_size`` set during initialisation to limit how many past - turns are included. ``None`` is returned when the underlying history is + Uses the `window_size` set during initialisation to limit how many past + turns are included. `None` is returned when the underlying history is non-linear. Returns: list[Component | CBlock] | None: Ordered list of context entries up to - ``window_size`` turns, or ``None`` if the history is non-linear. + `window_size` turns, or `None` if the history is non-linear. """ return self.as_list(self._window_size) @@ -65,13 +65,13 @@ def add(self, c: Component | CBlock) -> SimpleContext: c (Component | CBlock): The component or content block to record. 
Returns: - SimpleContext: A new ``SimpleContext`` containing only the added entry; + SimpleContext: A new `SimpleContext` containing only the added entry; prior history is not retained. """ return SimpleContext.from_previous(self, c) def view_for_generation(self) -> list[Component | CBlock] | None: - """Return an empty list, since ``SimpleContext`` does not pass history to the model. + """Return an empty list, since `SimpleContext` does not pass history to the model. Each call to the model is treated as a stateless, independent exchange. No prior turns are forwarded. diff --git a/mellea/stdlib/frameworks/react.py b/mellea/stdlib/frameworks/react.py index 810542295..425082e29 100644 --- a/mellea/stdlib/frameworks/react.py +++ b/mellea/stdlib/frameworks/react.py @@ -1,10 +1,10 @@ """ReACT (Reason + Act) agentic pattern implementation. -Provides the ``react()`` async function, which drives a tool-use loop: the model +Provides the `react()` async function, which drives a tool-use loop: the model reasons about a goal, selects a tool, receives the result as an observation, and -repeats until it calls ``final_answer`` or the ``loop_budget`` is exhausted. Accepts -any list of ``AbstractMelleaTool`` instances and a ``ChatContext`` for multi-turn -history tracking. Raises ``RuntimeError`` if the loop ends without a final answer. +repeats until it calls `final_answer` or the `loop_budget` is exhausted. Accepts +any list of `AbstractMelleaTool` instances and a `ChatContext` for multi-turn +history tracking. Raises `RuntimeError` if the loop ends without a final answer. """ # from PIL import Image as PILImage diff --git a/mellea/stdlib/functional.py b/mellea/stdlib/functional.py index a9caab140..5b944ee1c 100644 --- a/mellea/stdlib/functional.py +++ b/mellea/stdlib/functional.py @@ -245,15 +245,15 @@ def chat( content: The message text to send. context: The current conversation context. backend: The backend used to generate the response. 
- role: The role for the outgoing message (default ``"user"``). + role: The role for the outgoing message (default `"user"`). images: Optional list of images to include in the message. - user_variables: Optional Jinja variable substitutions applied to ``content``. + user_variables: Optional Jinja variable substitutions applied to `content`. format: Optional Pydantic model for constrained decoding of the response. model_options: Additional model options to merge with backend defaults. tool_calls: If true, tool calling is enabled. Returns: - Tuple of the assistant ``Message`` and the updated ``Context``. + Tuple of the assistant `Message` and the updated `Context`. """ if user_variables is not None: content_resolved = Instruction.apply_user_dict_from_jinja( @@ -294,17 +294,17 @@ def validate( """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided). Args: - reqs: A single ``Requirement`` or a list of them to validate. + reqs: A single `Requirement` or a list of them to validate. context: The current conversation context. backend: The backend used for LLM-as-a-judge requirements. - output: Optional model output ``CBlock`` to validate against instead of the context. + output: Optional model output `CBlock` to validate against instead of the context. format: Optional Pydantic model for constrained decoding. model_options: Additional model options to merge with backend defaults. generate_logs: Optional list to append generation logs to. - input: Optional input ``CBlock`` to include alongside ``output`` when validating. + input: Optional input `CBlock` to include alongside `output` when validating. Returns: - List of ``ValidationResult`` objects, one per requirement. + List of `ValidationResult` objects, one per requirement. """ # Run everything in the specific event loop for this session. @@ -819,15 +819,15 @@ async def achat( content: The message text to send. context: The current conversation context. 
backend: The backend used to generate the response. - role: The role for the outgoing message (default ``"user"``). + role: The role for the outgoing message (default `"user"`). images: Optional list of images to include in the message. - user_variables: Optional Jinja variable substitutions applied to ``content``. + user_variables: Optional Jinja variable substitutions applied to `content`. format: Optional Pydantic model for constrained decoding of the response. model_options: Additional model options to merge with backend defaults. tool_calls: If true, tool calling is enabled. Returns: - Tuple of the assistant ``Message`` and the updated ``Context``. + Tuple of the assistant `Message` and the updated `Context`. """ if user_variables is not None: content_resolved = Instruction.apply_user_dict_from_jinja( @@ -867,17 +867,17 @@ async def avalidate( """Asynchronous version of .validate; validates a set of requirements over the output (if provided) or the current context (if the output is not provided). Args: - reqs: A single ``Requirement`` or a list of them to validate. + reqs: A single `Requirement` or a list of them to validate. context: The current conversation context. backend: The backend used for LLM-as-a-judge requirements. - output: Optional model output ``CBlock`` to validate against instead of the context. + output: Optional model output `CBlock` to validate against instead of the context. format: Optional Pydantic model for constrained decoding. model_options: Additional model options to merge with backend defaults. generate_logs: Optional list to append generation logs to. - input: Optional input ``CBlock`` to include alongside ``output`` when validating. + input: Optional input `CBlock` to include alongside `output` when validating. Returns: - List of ``ValidationResult`` objects, one per requirement. + List of `ValidationResult` objects, one per requirement. """ # Turn a solitary requirement in to a list of requirements, and then reqify if needed. 
reqs = [reqs] if not isinstance(reqs, list) else reqs diff --git a/mellea/stdlib/requirements/md.py b/mellea/stdlib/requirements/md.py index 5d21223c3..f1da8b610 100644 --- a/mellea/stdlib/requirements/md.py +++ b/mellea/stdlib/requirements/md.py @@ -32,7 +32,7 @@ def as_markdown_list(ctx: Context) -> list[str] | None: Returns: List of rendered list-item strings if the output is a markdown list, - or ``None`` if parsing fails or the output is not a list. + or `None` if parsing fails or the output is not a list. """ mistletoe = _get_mistletoe() xs = list() diff --git a/mellea/stdlib/requirements/python_reqs.py b/mellea/stdlib/requirements/python_reqs.py index ad2d6a74b..7f3bc1ca9 100644 --- a/mellea/stdlib/requirements/python_reqs.py +++ b/mellea/stdlib/requirements/python_reqs.py @@ -69,11 +69,11 @@ def _has_python_code_listing(ctx: Context) -> ValidationResult: # Look for code blocks with python specifier import re - # Pattern for ```python ... ``` blocks - python_blocks = re.findall(r"```python\s*\n(.*?)\n```", content, re.DOTALL) + # Pattern for ```python ... ``` blocks + python_blocks = re.findall(r"```python\s*\n(.*?)\n```", content, re.DOTALL) # Pattern for generic ``` blocks - generic_blocks = re.findall(r"```\s*\n(.*?)\n```", content, re.DOTALL) + generic_blocks = re.findall(r"```\s*\n(.*?)\n```", content, re.DOTALL) all_blocks = [] @@ -146,17 +146,17 @@ class PythonExecutionReq(Requirement): and validates or executes it according to the configured execution mode. Args: - timeout (int): Maximum seconds to allow code to run. Defaults to ``5``. - allow_unsafe_execution (bool): If ``True``, execute code directly with + timeout (int): Maximum seconds to allow code to run. Defaults to `5`. + allow_unsafe_execution (bool): If `True`, execute code directly with subprocess. Use only with trusted sources. allowed_imports (list[str] | None): Allowlist of importable top-level - modules. ``None`` allows any import. 
- use_sandbox (bool): If ``True``, use ``llm-sandbox`` for Docker-based + modules. `None` allows any import. + use_sandbox (bool): If `True`, use `llm-sandbox` for Docker-based isolated execution. Attributes: validation_fn (Callable[[Context], ValidationResult]): The validation - function attached to this requirement; always non-``None``. + function attached to this requirement; always non-`None`. """ def __init__( diff --git a/mellea/stdlib/requirements/requirement.py b/mellea/stdlib/requirements/requirement.py index 3ec5e8164..33f371439 100644 --- a/mellea/stdlib/requirements/requirement.py +++ b/mellea/stdlib/requirements/requirement.py @@ -12,7 +12,7 @@ class LLMaJRequirement(Requirement): """A requirement that always uses LLM-as-a-Judge. Any available constraint ALoRA will be ignored. Attributes: - use_aloras (bool): Always ``False`` for this class; ALoRA adapters are + use_aloras (bool): Always `False` for this class; ALoRA adapters are never used even if they are available. """ @@ -20,14 +20,14 @@ class LLMaJRequirement(Requirement): def requirement_check_to_bool(x: CBlock | str) -> bool: - """Checks if a given output should be marked converted to ``True``. + """Checks if a given output should be marked converted to `True`. - By default, the requirement check alora outputs: ``{"requirement_likelihood": 0.0}``. - Returns ``True`` if the likelihood value is > 0.5. + By default, the requirement check alora outputs: `{"requirement_likelihood": 0.0}`. + Returns `True` if the likelihood value is > 0.5. Args: x: ALoRA output string or CBlock containing JSON with a - ``requirement_likelihood`` field. + `requirement_likelihood` field. Returns: True if the extracted likelihood exceeds 0.5, False otherwise. @@ -51,16 +51,16 @@ def requirement_check_to_bool(x: CBlock | str) -> bool: class ALoraRequirement(Requirement, Intrinsic): """A requirement validated by an ALoRA adapter; falls back to LLM-as-a-Judge only on error. 
- If an exception is thrown during the ALoRA execution path, ``mellea`` will + If an exception is thrown during the ALoRA execution path, `mellea` will fall back to LLMaJ. That is the only case where LLMaJ will be used. Args: description (str): Human-readable requirement description. intrinsic_name (str | None): Name of the ALoRA intrinsic to use. - Defaults to ``"requirement_check"``. + Defaults to `"requirement_check"`. Attributes: - use_aloras (bool): Always ``True``; this class always attempts to use + use_aloras (bool): Always `True`; this class always attempts to use ALoRA adapters for validation. """ @@ -89,13 +89,13 @@ def reqify(r: str | Requirement) -> Requirement: This is a utility method for functions that allow you to pass in Requirements as either explicit Requirement objects or strings that you intend to be interpreted as requirements. Args: - r: A ``Requirement`` object or a plain string description to wrap as one. + r: A `Requirement` object or a plain string description to wrap as one. Returns: - A ``Requirement`` instance. + A `Requirement` instance. Raises: - Exception: If ``r`` is neither a ``str`` nor a ``Requirement`` instance. + Exception: If `r` is neither a `str` nor a `Requirement` instance. """ if type(r) is str: return Requirement(r) @@ -106,27 +106,27 @@ def reqify(r: str | Requirement) -> Requirement: def req(*args, **kwargs) -> Requirement: - """Shorthand for ``Requirement.__init__``. + """Shorthand for `Requirement.__init__`. Args: - *args: Positional arguments forwarded to ``Requirement.__init__``. - **kwargs: Keyword arguments forwarded to ``Requirement.__init__``. + *args: Positional arguments forwarded to `Requirement.__init__`. + **kwargs: Keyword arguments forwarded to `Requirement.__init__`. Returns: - A new ``Requirement`` instance. + A new `Requirement` instance. """ return Requirement(*args, **kwargs) def check(*args, **kwargs) -> Requirement: - """Shorthand for ``Requirement.__init__(..., check_only=True)``. 
+ """Shorthand for `Requirement.__init__(..., check_only=True)`. Args: - *args: Positional arguments forwarded to ``Requirement.__init__``. - **kwargs: Keyword arguments forwarded to ``Requirement.__init__``. + *args: Positional arguments forwarded to `Requirement.__init__`. + **kwargs: Keyword arguments forwarded to `Requirement.__init__`. Returns: - A new ``Requirement`` instance with ``check_only=True``. + A new `Requirement` instance with `check_only=True`. """ return Requirement(*args, **kwargs, check_only=True) @@ -163,11 +163,11 @@ def simple_validate( reason: only used if the provided function returns a bool; if the validation function fails, a static reason for that failure to give to the llm when repairing Returns: - A validation function that takes a ``Context`` and returns a ``ValidationResult``. + A validation function that takes a `Context` and returns a `ValidationResult`. Raises: - ValueError: If ``fn`` returns a type other than ``bool`` or - ``tuple[bool, str]``. + ValueError: If `fn` returns a type other than `bool` or + `tuple[bool, str]`. """ def validate(ctx: Context) -> ValidationResult: diff --git a/mellea/stdlib/requirements/safety/guardian.py b/mellea/stdlib/requirements/safety/guardian.py index 26cd632af..969a62acd 100644 --- a/mellea/stdlib/requirements/safety/guardian.py +++ b/mellea/stdlib/requirements/safety/guardian.py @@ -51,7 +51,7 @@ def get_available_risks(cls) -> list[str]: """Return a list of all available risk type identifiers. Returns: - list[str]: String values of all ``GuardianRisk`` enum members. + list[str]: String values of all `GuardianRisk` enum members. """ return [risk.value for risk in cls] @@ -90,17 +90,17 @@ class GuardianCheck(Requirement): Args: risk (str | GuardianRisk | None): The type of risk to check for. Required - unless ``custom_criteria`` is provided. - backend_type (BackendType): Backend type to use -- ``"ollama"`` or - ``"huggingface"``. + unless `custom_criteria` is provided. 
+ backend_type (BackendType): Backend type to use -- `"ollama"` or + `"huggingface"`. model_version (str | None): Specific Guardian model version. Defaults to the appropriate 8B model for the chosen backend. device (str | None): Device string for HuggingFace inference (e.g. - ``"cuda"``). + `"cuda"`). ollama_url (str): Base URL for the Ollama server. thinking (bool): Enable chain-of-thought reasoning mode in the Guardian model. custom_criteria (str | None): Free-text criteria string used in place of a - standard ``GuardianRisk`` value. + standard `GuardianRisk` value. context_text (str | None): Context document for groundedness checks. tools (list[dict] | None): Tool schemas for function-call validation. backend (Backend | None): Pre-initialised backend instance to reuse; avoids @@ -202,8 +202,8 @@ def __init__( def get_effective_risk(self) -> str: """Return the effective risk criteria to use for validation. - Returns the ``custom_criteria`` string when one was provided, otherwise - returns the ``risk`` identifier set during initialisation. + Returns the `custom_criteria` string when one was provided, otherwise + returns the `risk` identifier set during initialisation. Returns: str: The active risk/criteria string forwarded to the Guardian model. @@ -215,7 +215,7 @@ def get_available_risks(cls) -> list[str]: """Return a list of all available standard risk type identifiers. Returns: - list[str]: String values of all ``GuardianRisk`` enum members. + list[str]: String values of all `GuardianRisk` enum members. """ return GuardianRisk.get_available_risks() @@ -250,7 +250,7 @@ async def validate( """Validate a conversation using Granite Guardian via the selected backend. Builds a minimal chat context from the current session context, invokes the - Guardian model, and parses its ``yes/no`` output. A ``"No"`` + Guardian model, and parses its `yes/no` output. A `"No"` label (risk not detected) is treated as a passing validation result. 
Args: @@ -263,8 +263,8 @@ async def validate( Guardian backend call. Returns: - ValidationResult: ``result=True`` when the content is considered safe - (Guardian returns ``"No"``), ``result=False`` otherwise. + ValidationResult: `result=True` when the content is considered safe + (Guardian returns `"No"`), `result=False` otherwise. """ logger = self._logger diff --git a/mellea/stdlib/requirements/tool_reqs.py b/mellea/stdlib/requirements/tool_reqs.py index 007b60114..59b3ed2cd 100644 --- a/mellea/stdlib/requirements/tool_reqs.py +++ b/mellea/stdlib/requirements/tool_reqs.py @@ -1,9 +1,9 @@ -"""``Requirement`` factories for tool-use validation. +"""`Requirement` factories for tool-use validation. -Provides ``uses_tool``, a ``Requirement`` factory that validates whether a model +Provides `uses_tool`, a `Requirement` factory that validates whether a model response includes a call to a specified tool — useful when you need to enforce tool invocation via rejection sampling rather than relying solely on the model's -``tool_choice`` setting. Also provides ``tool_arg_validator``, which validates the +`tool_choice` setting. Also provides `tool_arg_validator`, which validates the value of a specific argument to a named tool. Both accept either the tool's string name or its callable. """ @@ -30,10 +30,10 @@ def uses_tool(tool_name: str | Callable, check_only: bool = False) -> Requiremen tool_name: The tool that must be called; this can be either the name of the tool or the Callable for the tool. check_only: Propagates to the Requirement. - Use ``tool_choice`` if the OpenAI ``tool_choice`` model option is supported by your model and inference engine. + Use `tool_choice` if the OpenAI `tool_choice` model option is supported by your model and inference engine. Returns: - A ``Requirement`` that validates whether the specified tool was called. + A `Requirement` that validates whether the specified tool was called. 
""" tool_name = _name2str(tool_name) @@ -74,7 +74,7 @@ def tool_arg_validator( 2. should this be done automatically when the user provides asserts in their function body? Returns: - A ``Requirement`` that validates the specified tool argument. + A `Requirement` that validates the specified tool argument. """ if tool_name: tool_name = _name2str(tool_name) diff --git a/mellea/stdlib/sampling/base.py b/mellea/stdlib/sampling/base.py index 59c044dd0..dabe92d43 100644 --- a/mellea/stdlib/sampling/base.py +++ b/mellea/stdlib/sampling/base.py @@ -44,7 +44,7 @@ class BaseSamplingStrategy(SamplingStrategy): Args: loop_budget (int): Maximum number of generate/validate cycles. Must be - greater than 0. Defaults to ``1``. + greater than 0. Defaults to `1`. requirements (list[Requirement] | None): Global requirements evaluated on every sample. When set, overrides per-call requirements. diff --git a/mellea/stdlib/sampling/budget_forcing.py b/mellea/stdlib/sampling/budget_forcing.py index d17402e57..295d547c2 100644 --- a/mellea/stdlib/sampling/budget_forcing.py +++ b/mellea/stdlib/sampling/budget_forcing.py @@ -25,25 +25,25 @@ class BudgetForcingSamplingStrategy(RejectionSamplingStrategy): """Sampling strategy that enforces a token budget for chain-of-thought reasoning. - Extends ``RejectionSamplingStrategy`` with explicit control over the ```` + Extends `RejectionSamplingStrategy` with explicit control over the `` block size and the answer block size. On each loop iteration, - ``think_budget_forcing`` interleaves forced-thinking and final-answer generation, + `think_budget_forcing` interleaves forced-thinking and final-answer generation, after which the standard rejection-sampling validation pass determines whether to accept or retry. - Currently only supports the ``OllamaModelBackend``. + Currently only supports the `OllamaModelBackend`. Args: think_max_tokens (int | None): Tokens allocated for the thinking block. - Defaults to ``4096``. + Defaults to `4096`. 
answer_max_tokens (int | None): Tokens allocated for the answer block. - ``None`` means unbounded. + `None` means unbounded. start_think_token (str | None): Token opening the thinking block. - Defaults to ``""``. + Defaults to `""`. end_think_token (str | None): Token closing the thinking block. - Defaults to ``""``. + Defaults to `""`. begin_response_token (str | None): Optional token opening the response - block. Defaults to ``""``. + block. Defaults to `""`. end_response_token (str): Token closing the response block. think_more_suffix (str | None): Suffix to force continued thinking. Empty string disables forcing. diff --git a/mellea/stdlib/sampling/majority_voting.py b/mellea/stdlib/sampling/majority_voting.py index 05b377c32..f7b7f0931 100644 --- a/mellea/stdlib/sampling/majority_voting.py +++ b/mellea/stdlib/sampling/majority_voting.py @@ -24,16 +24,16 @@ class BaseMBRDSampling(RejectionSamplingStrategy): Args: number_of_samples (int): Number of samples to generate and use for - majority voting. Defaults to ``8``. - weighted (bool): Not yet implemented. If ``True``, weights scores + majority voting. Defaults to `8`. + weighted (bool): Not yet implemented. If `True`, weights scores before majority vote. loop_budget (int): Inner rejection-sampling loop count. Must be > 0. requirements (list[Requirement] | None): Requirements to validate - against. If ``None``, uses per-call requirements. + against. If `None`, uses per-call requirements. Attributes: symmetric (bool): Whether the similarity metric is symmetric, allowing - the upper-triangle score matrix to be mirrored; always ``True`` for + the upper-triangle score matrix to be mirrored; always `True` for this base class. """ @@ -70,15 +70,15 @@ def compare_strings(self, ref: str, pred: str) -> float: pred (str): The predicted string to evaluate. 
Returns: - float: A similarity score, typically in ``[0.0, 1.0]`` where ``1.0`` + float: A similarity score, typically in `[0.0, 1.0]` where `1.0` indicates a perfect match. """ def maybe_apply_weighted(self, scr: np.ndarray) -> np.ndarray: - """Apply per-sample weights to the score vector if ``self.weighted`` is ``True``. + """Apply per-sample weights to the score vector if `self.weighted` is `True`. Currently not implemented; the input array is returned unchanged when - ``self.weighted`` is ``True``. + `self.weighted` is `True`. Args: scr (np.ndarray): 1-D array of aggregated similarity scores, one @@ -191,19 +191,19 @@ class MajorityVotingStrategyForMath(BaseMBRDSampling): """MajorityVoting Sampling Strategy for Math Expressions. Args: - number_of_samples (int): Number of samples to generate. Defaults to ``8``. - float_rounding (int): Decimal places for float comparison. Defaults to ``6``. - strict (bool): Enforce strict comparison mode. Defaults to ``True``. + number_of_samples (int): Number of samples to generate. Defaults to `8`. + float_rounding (int): Decimal places for float comparison. Defaults to `6`. + strict (bool): Enforce strict comparison mode. Defaults to `True`. allow_set_relation_comp (bool): Allow set-relation comparisons. Defaults - to ``False``. - weighted (bool): Not yet implemented. Defaults to ``False``. - loop_budget (int): Rejection-sampling loop count. Defaults to ``1``. + to `False`. + weighted (bool): Not yet implemented. Defaults to `False`. + loop_budget (int): Rejection-sampling loop count. Defaults to `1`. requirements (list[Requirement] | None): Requirements to validate against. Attributes: match_types (list[str]): Extraction target types used for parsing math - expressions; always ``["latex", "axpr"]``, computed at init. - symmetric (bool): Inherited from ``BaseMBRDSampling``; always ``True`` + expressions; always `["latex", "axpr"]`, computed at init. 
+ symmetric (bool): Inherited from `BaseMBRDSampling`; always `True` for this strategy (set explicitly at init). """ @@ -253,16 +253,16 @@ def compare_strings(self, ref: str, pred: str) -> float: """Compare two strings using math-aware extraction and verification. Parses both strings into mathematical expressions using the configured - ``match_types`` (latex and/or expr), then verifies equivalence via - ``math_verify.verify``. + `match_types` (latex and/or expr), then verifies equivalence via + `math_verify.verify`. Args: ref (str): The reference (gold) string containing a math expression. pred (str): The predicted string to compare against the reference. Returns: - float: ``1.0`` if the expressions are considered equivalent, - ``0.0`` otherwise. + float: `1.0` if the expressions are considered equivalent, + `0.0` otherwise. """ # Convert string match_types to ExtractionTarget objects extraction_targets = [] @@ -291,16 +291,16 @@ class MBRDRougeLStrategy(BaseMBRDSampling): """Sampling Strategy that uses RougeL to compute symbol-level distances for majority voting. Args: - number_of_samples (int): Number of samples to generate. Defaults to ``8``. - weighted (bool): Not yet implemented. Defaults to ``False``. - loop_budget (int): Rejection-sampling loop count. Defaults to ``1``. + number_of_samples (int): Number of samples to generate. Defaults to `8`. + weighted (bool): Not yet implemented. Defaults to `False`. + loop_budget (int): Rejection-sampling loop count. Defaults to `1`. requirements (list[Requirement] | None): Requirements to validate against. Attributes: - match_types (list[str]): Rouge metric names used for scoring (``["rougeL"]``). - scorer (RougeScorer): Pre-configured ``RougeScorer`` instance used for + match_types (list[str]): Rouge metric names used for scoring (`["rougeL"]`). + scorer (RougeScorer): Pre-configured `RougeScorer` instance used for pairwise string comparison. 
- symmetric (bool): Inherited from ``BaseMBRDSampling``; always ``True`` for + symmetric (bool): Inherited from `BaseMBRDSampling`; always `True` for RougeL (the score is symmetric by construction). """ @@ -339,7 +339,7 @@ def compare_strings(self, ref: str, pred: str) -> float: pred (str): The predicted string to evaluate. Returns: - float: RougeL F-measure score in the range ``[0.0, 1.0]``. + float: RougeL F-measure score in the range `[0.0, 1.0]`. """ scr: float = self.scorer.score(ref, pred)[self.match_types[-1]].fmeasure return scr diff --git a/mellea/stdlib/sampling/sampling_algos/budget_forcing_alg.py b/mellea/stdlib/sampling/sampling_algos/budget_forcing_alg.py index 7d211e386..dc1fad2c8 100644 --- a/mellea/stdlib/sampling/sampling_algos/budget_forcing_alg.py +++ b/mellea/stdlib/sampling/sampling_algos/budget_forcing_alg.py @@ -1,11 +1,11 @@ """Budget-forcing generation algorithm for thinking models. -Implements ``think_budget_forcing``, which extends a model's reasoning phase by +Implements `think_budget_forcing`, which extends a model's reasoning phase by repeatedly appending a "think more" suffix whenever the model attempts to close its -```` block prematurely, following the method proposed in arXiv:2501.19393. -Generation is split into a thinking pass (bounded by ``think_max_tokens``) and an -answer pass (bounded by ``answer_max_tokens``), using the raw completions API of an -``OllamaModelBackend``. +`` block prematurely, following the method proposed in arXiv:2501.19393. +Generation is split into a thinking pass (bounded by `think_max_tokens`) and an +answer pass (bounded by `answer_max_tokens`), using the raw completions API of an +`OllamaModelBackend`. """ from typing import Any @@ -47,20 +47,20 @@ async def think_budget_forcing( Args: backend: OllamaModelBackend instance to use for generation. - action: The last item of the context, passed as an ``action`` instead of as part - of the ``ctx``. See ``docs/dev/generate_signature_decisions.md``. 
+ action: The last item of the context, passed as an `action` instead of as part + of the `ctx`. See `docs/dev/generate_signature_decisions.md`. ctx: The current conversation context. format: Optional Pydantic model for constrained decoding of the response. - tool_calls: If ``True``, tool calling is enabled. + tool_calls: If `True`, tool calling is enabled. think_max_tokens: Budget in number of tokens allocated for the think block. answer_max_tokens: Budget in number of tokens allocated for the summary and - answer block; ``None`` indicates unbounded answer, generating till EoS. - start_think_token: String indicating start of think block, default ````. - end_think_token: String indicating end of think block, default ````. + answer block; `None` indicates unbounded answer, generating till EoS. + start_think_token: String indicating start of think block, default ``. + end_think_token: String indicating end of think block, default ``. begin_response_token: Used by certain models, string indicating start of - response block, e.g. ``""``, default ``""``. + response block, e.g. `""`, default `""`. think_more_suffix: String to append to force continued thinking, e.g. - ``"\nWait"``; if ``None``, additional thinking is not forced (upper-bound + `"\nWait"`; if `None`, additional thinking is not forced (upper-bound budget case). answer_suffix: String to append to force a final answer. model_options: Any model options to upsert into the defaults for this call. @@ -70,7 +70,7 @@ async def think_budget_forcing( Raises: Exception: If the backend returns generation results without the - required ``meta`` information (e.g. token usage counts). + required `meta` information (e.g. token usage counts). 
Assumptions: - The chat template is applied on prompt, with think mode enabled diff --git a/mellea/stdlib/sampling/sofai.py b/mellea/stdlib/sampling/sofai.py index 0f17e8c1f..793ea49b5 100644 --- a/mellea/stdlib/sampling/sofai.py +++ b/mellea/stdlib/sampling/sofai.py @@ -40,7 +40,7 @@ class SOFAISamplingStrategy(SamplingStrategy): results. If S1 Solver fails after exhausting the budget or shows no improvement, escalates to a single attempt with S2 Solver (slow model). - The strategy leverages ``ValidationResult.reason`` fields to provide targeted + The strategy leverages `ValidationResult.reason` fields to provide targeted feedback for repair, enabling more effective iterative improvement. Args: @@ -51,9 +51,9 @@ class SOFAISamplingStrategy(SamplingStrategy): s2_solver_mode (Literal["fresh_start", "continue_chat", "best_attempt"]): How to invoke the S2 solver when S1 fails. loop_budget (int): Maximum number of S1 repair attempts before escalating - to S2. Must be greater than 0. Defaults to ``3``. + to S2. Must be greater than 0. Defaults to `3`. judge_backend (Backend | None): Optional backend for LLM-as-Judge - validation. If ``None``, falls back to the session backend. + validation. If `None`, falls back to the session backend. feedback_strategy (Literal["simple", "first_error", "all_errors"]): Detail level of repair feedback provided to the S1 solver. diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index acdd42c67..574dd9d75 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -1,10 +1,10 @@ -"""``MelleaSession``: the primary entry point for running generative programs. +"""`MelleaSession`: the primary entry point for running generative programs. -``MelleaSession`` wraps a ``Backend`` and a ``Context`` and exposes high-level methods -(``act``, ``instruct``, ``sample``) that drive the generate-validate-repair loop. 
It -also manages a global context variable (accessible via ``get_session()``) so that +`MelleaSession` wraps a `Backend` and a `Context` and exposes high-level methods +(`act`, `instruct`, `sample`) that drive the generate-validate-repair loop. It +also manages a global context variable (accessible via `get_session()`) so that nested components can reach the current session without explicit threading. Use -``start_session(...)`` as a context manager to create and automatically clean up a +`start_session(...)` as a context manager to create and automatically clean up a session. """ @@ -53,7 +53,7 @@ def get_session() -> MelleaSession: """Get the current session from context. Returns: - The currently active ``MelleaSession``. + The currently active `MelleaSession`. Raises: RuntimeError: If no session is currently active. @@ -70,16 +70,16 @@ def backend_name_to_class(name: str) -> Any: """Resolves backend names to Backend classes. Args: - name: Short backend name, e.g. ``"ollama"``, ``"hf"``, ``"openai"``, - ``"watsonx"``, or ``"litellm"``. + name: Short backend name, e.g. `"ollama"`, `"hf"`, `"openai"`, + `"watsonx"`, or `"litellm"`. Returns: - The corresponding ``Backend`` class, or ``None`` if the name is unrecognised. + The corresponding `Backend` class, or `None` if the name is unrecognised. Raises: ImportError: If the requested backend has optional dependencies that are - not installed (e.g. ``mellea[hf]``, ``mellea[watsonx]``, or - ``mellea[litellm]``). + not installed (e.g. `mellea[hf]`, `mellea[watsonx]`, or + `mellea[litellm]`). """ if name == "ollama": from ..backends.ollama import OllamaModelBackend @@ -154,8 +154,8 @@ def start_session( model_options: Additional model configuration options that will be passed to the backend (e.g., temperature, max_tokens, etc.). plugins: Optional list of plugins scoped to this session. Accepts - ``@hook``-decorated functions, ``@plugin``-decorated class instances, - ``MelleaPlugin`` instances, or ``PluginSet`` instances. 
+ `@hook`-decorated functions, `@plugin`-decorated class instances, + `MelleaPlugin` instances, or `PluginSet` instances. **backend_kwargs: Additional keyword arguments passed to the backend constructor. Returns: @@ -163,13 +163,13 @@ or called directly with session methods. Raises: - Exception: If ``backend_name`` is not one of the recognised backend + Exception: If `backend_name` is not one of the recognised backend identifiers. ImportError: If the requested backend requires optional dependencies that are not installed. Examples: - ```python + ```python # Basic usage with default settings with start_session() as session: response = session.instruct("Explain quantum computing") @@ -188,7 +188,7 @@ session = start_session() response = session.instruct("Explain quantum computing") session.cleanup() - ``` + ``` """ logger = FancyLogger.get_logger() @@ -288,11 +288,11 @@ class MelleaSession: backend (Backend): The backend to use for all model inference in this session. ctx (Context | None): The conversation context. Defaults to a new - ``SimpleContext`` if ``None``. + `SimpleContext` if `None`. Attributes: - ctx (Context): The active conversation context; never ``None`` (defaults - to a fresh ``SimpleContext`` when ``None`` is passed). Updated after + ctx (Context): The active conversation context; never `None` (defaults + to a fresh `SimpleContext` when `None` is passed). Updated after every call that produces model output. id (str): Unique session UUID assigned at construction. """ @@ -346,7 +346,7 @@ def clone(self): a copy of the current session. Keeps the context, backend, and session logger. Examples: - ```python + ```python >>> from mellea import start_session >>> m = start_session() >>> m.instruct("What is 2x2?") @@ -360,15 +360,15 @@ >>> out = m2.instruct("Multiply that by 3") >>> print(out) ... 
12 - ``` + ``` """ return copy(self) def reset(self): """Reset the context state to a fresh, empty context of the same type. - Fires the ``SESSION_RESET`` plugin hook if any plugins are registered, then - replaces ``self.ctx`` with the result of ``ctx.reset_to_new()``, discarding + Fires the `SESSION_RESET` plugin hook if any plugins are registered, then + replaces `self.ctx` with the result of `ctx.reset_to_new()`, discarding all accumulated conversation history. """ if has_plugins(HookType.SESSION_RESET): @@ -543,8 +543,8 @@ def instruct( images: A list of images to be used in the instruction or None if none. Returns: - A ``ModelOutputThunk`` if ``return_sampling_results`` is ``False``, - else a ``SamplingResult``. + A `ModelOutputThunk` if `return_sampling_results` is `False`, + else a `SamplingResult`. """ r = mfuncs.instruct( description, @@ -588,15 +588,15 @@ def chat( Args: content: The message text to send. - role: The role for the outgoing message (default ``"user"``). + role: The role for the outgoing message (default `"user"`). images: Optional list of images to include in the message. - user_variables: Optional Jinja variable substitutions applied to ``content``. + user_variables: Optional Jinja variable substitutions applied to `content`. format: Optional Pydantic model for constrained decoding of the response. model_options: Additional model options to merge with backend defaults. tool_calls: If true, tool calling is enabled. Returns: - The assistant ``Message`` response. + The assistant `Message` response. """ result, context = mfuncs.chat( content=content, @@ -626,15 +626,15 @@ def validate( """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided). Args: - reqs: A single ``Requirement`` or a list of them to validate. - output: Optional model output ``CBlock`` to validate against instead of the context. + reqs: A single `Requirement` or a list of them to validate. 
+ output: Optional model output `CBlock` to validate against instead of the context. format: Optional Pydantic model for constrained decoding. model_options: Additional model options to merge with backend defaults. generate_logs: Optional list to append generation logs to. - input: Optional input ``CBlock`` to include alongside ``output`` when validating. + input: Optional input `CBlock` to include alongside `output` when validating. Returns: - List of ``ValidationResult`` objects, one per requirement. + List of `ValidationResult` objects, one per requirement. """ return mfuncs.validate( reqs=reqs, @@ -856,8 +856,8 @@ async def ainstruct( images: A list of images to be used in the instruction or None if none. Returns: - A ``ModelOutputThunk`` if ``return_sampling_results`` is ``False``, - else a ``SamplingResult``. + A `ModelOutputThunk` if `return_sampling_results` is `False`, + else a `SamplingResult`. """ r = await mfuncs.ainstruct( description, @@ -901,15 +901,15 @@ async def achat( Args: content: The message text to send. - role: The role for the outgoing message (default ``"user"``). + role: The role for the outgoing message (default `"user"`). images: Optional list of images to include in the message. - user_variables: Optional Jinja variable substitutions applied to ``content``. + user_variables: Optional Jinja variable substitutions applied to `content`. format: Optional Pydantic model for constrained decoding of the response. model_options: Additional model options to merge with backend defaults. tool_calls: If true, tool calling is enabled. Returns: - The assistant ``Message`` response. + The assistant `Message` response. """ result, context = await mfuncs.achat( content=content, @@ -939,15 +939,15 @@ async def avalidate( """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided). Args: - reqs: A single ``Requirement`` or a list of them to validate. 
- output: Optional model output ``CBlock`` to validate against instead of the context. + reqs: A single `Requirement` or a list of them to validate. + output: Optional model output `CBlock` to validate against instead of the context. format: Optional Pydantic model for constrained decoding. model_options: Additional model options to merge with backend defaults. generate_logs: Optional list to append generation logs to. - input: Optional input ``CBlock`` to include alongside ``output`` when validating. + input: Optional input `CBlock` to include alongside `output` when validating. Returns: - List of ``ValidationResult`` objects, one per requirement. + List of `ValidationResult` objects, one per requirement. """ return await mfuncs.avalidate( reqs=reqs, @@ -1029,13 +1029,13 @@ async def atransform( def powerup(cls, powerup_cls: type): """Appends methods in a class object `powerup_cls` to MelleaSession. - Iterates over all functions defined on ``powerup_cls`` and attaches each - one as a method on the ``MelleaSession`` class, effectively extending + Iterates over all functions defined on `powerup_cls` and attaches each + one as a method on the `MelleaSession` class, effectively extending the session with domain-specific helpers at runtime. Args: powerup_cls (type): A class whose functions should be added to - ``MelleaSession`` as instance methods. + `MelleaSession` as instance methods. """ for name, fn in inspect.getmembers(powerup_cls, predicate=inspect.isfunction): setattr(cls, name, fn) diff --git a/mellea/stdlib/tools/interpreter.py b/mellea/stdlib/tools/interpreter.py index 237545621..a799ec715 100644 --- a/mellea/stdlib/tools/interpreter.py +++ b/mellea/stdlib/tools/interpreter.py @@ -1,13 +1,13 @@ """Code interpreter tool and execution environments for agentic workflows. 
-Provides ``ExecutionResult`` (capturing stdout, stderr, success, and optional static -analysis output) and three concrete ``ExecutionEnvironment`` implementations: -``StaticAnalysisEnvironment`` (parse and import-check only, no execution), -``UnsafeEnvironment`` (subprocess execution in the current Python environment), and -``LLMSandboxEnvironment`` (Docker-isolated execution via ``llm-sandbox``). All -environments support an optional ``allowed_imports`` allowlist. The top-level -``code_interpreter`` and ``local_code_interpreter`` functions are ready to be wrapped -as ``MelleaTool`` instances for ReACT or other agentic loops. +Provides `ExecutionResult` (capturing stdout, stderr, success, and optional static +analysis output) and three concrete `ExecutionEnvironment` implementations: +`StaticAnalysisEnvironment` (parse and import-check only, no execution), +`UnsafeEnvironment` (subprocess execution in the current Python environment), and +`LLMSandboxEnvironment` (Docker-isolated execution via `llm-sandbox`). All +environments support an optional `allowed_imports` allowlist. The top-level +`code_interpreter` and `local_code_interpreter` functions are ready to be wrapped +as `MelleaTool` instances for ReACT or other agentic loops. """ import ast @@ -40,13 +40,13 @@ class ExecutionResult: TODO: should we also be trying to pass back the value of the final expression evaluated, or the value of locals() and globals()? Args: - success (bool): ``True`` if execution succeeded (exit code 0 or - static-analysis passed); ``False`` otherwise. - stdout (str | None): Captured standard output, or ``None`` if + success (bool): `True` if execution succeeded (exit code 0 or + static-analysis passed); `False` otherwise. + stdout (str | None): Captured standard output, or `None` if execution was skipped. - stderr (str | None): Captured standard error, or ``None`` if + stderr (str | None): Captured standard error, or `None` if execution was skipped. 
- skipped (bool): ``True`` when execution was not attempted. + skipped (bool): `True` when execution was not attempted. skip_message (str | None): Explanation of why execution was skipped. analysis_result (Any | None): Optional payload from static-analysis environments. @@ -95,7 +95,7 @@ class ExecutionEnvironment(ABC): Args: allowed_imports (list[str] | None): Allowlist of top-level module names - that generated code may import. ``None`` disables the import check. + that generated code may import. `None` disables the import check. """ @@ -129,8 +129,8 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: compatibility. Returns: - ExecutionResult: Result with ``skipped=True`` and the parsed AST in - ``analysis_result`` on success, or a syntax-error description on + ExecutionResult: Result with `skipped=True` and the parsed AST in + `analysis_result` on success, or a syntax-error description on failure. """ try: @@ -243,9 +243,9 @@ class LLMSandboxEnvironment(ExecutionEnvironment): def execute(self, code: str, timeout: int) -> ExecutionResult: """Execute code using llm-sandbox in an isolated Docker container. - Checks the import allowlist first, then delegates to a ``SandboxSession`` - from the ``llm-sandbox`` package. Returns a skipped result if - ``llm-sandbox`` is not installed. + Checks the import allowlist first, then delegates to a `SandboxSession` + from the `llm-sandbox` package. Returns a skipped result if + `llm-sandbox` is not installed. Args: code (str): The Python source code to execute. @@ -339,7 +339,7 @@ def code_interpreter(code: str) -> ExecutionResult: code: The Python code to execute. Returns: - An ``ExecutionResult`` with stdout, stderr, and a success flag. + An `ExecutionResult` with stdout, stderr, and a success flag. """ exec_env = LLMSandboxEnvironment(allowed_imports=None) return exec_env.execute(code, 60) @@ -352,7 +352,7 @@ def local_code_interpreter(code: str) -> ExecutionResult: code: The Python code to execute. 
Returns: - An ``ExecutionResult`` with stdout, stderr, and a success flag. + An `ExecutionResult` with stdout, stderr, and a success flag. """ exec_env = UnsafeEnvironment(allowed_imports=None) return exec_env.execute(code, 60) diff --git a/mellea/telemetry/tracing.py b/mellea/telemetry/tracing.py index cab2c2bc6..d6ae9db6d 100644 --- a/mellea/telemetry/tracing.py +++ b/mellea/telemetry/tracing.py @@ -96,7 +96,7 @@ def is_application_tracing_enabled() -> bool: Returns: True if application tracing has been enabled via the - ``MELLEA_TRACE_APPLICATION`` environment variable. + `MELLEA_TRACE_APPLICATION` environment variable. """ return _TRACE_APPLICATION_ENABLED @@ -106,7 +106,7 @@ def is_backend_tracing_enabled() -> bool: Returns: True if backend tracing has been enabled via the - ``MELLEA_TRACE_BACKEND`` environment variable. + `MELLEA_TRACE_BACKEND` environment variable. """ return _TRACE_BACKEND_ENABLED