diff --git a/docs/dev/taint_tracking.md b/docs/dev/taint_tracking.md new file mode 100644 index 000000000..0ec7849eb --- /dev/null +++ b/docs/dev/taint_tracking.md @@ -0,0 +1,112 @@ +# Taint Tracking - Backend Security + +Mellea backends implement thread security using the **SecLevel** model with capability-based access control and taint tracking. Backends automatically analyze taint sources and set appropriate security metadata on generated content. + +## Security Model + +The security system uses three types of security levels: + +```python +SecLevel := None | Classified of AccessType | TaintedBy of (list[CBlock | Component] | None) +``` + +- **SecLevel.none()**: Safe content with no restrictions +- **SecLevel.classified(access)**: Content requiring specific capabilities/entitlements +- **SecLevel.tainted_by(sources)**: Content tainted by one or more CBlocks/Components (list), or None for root tainted nodes + +## Backend Implementation + +All backends follow the same pattern when creating `ModelOutputThunk`: + +```python +# Compute taint sources from action and context +sources = taint_sources(action, ctx) + +# Set security level based on taint sources +from mellea.security import SecLevel +sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + +output = ModelOutputThunk( + value=None, + sec_level=sec_level, + meta={} +) +``` + +The security level is set as follows: +- If taint sources are found -> `SecLevel.tainted_by(sources)` (all sources are tracked) +- If no taint sources -> `SecLevel.none()` + +### Handling Multiple Taint Sources + +When `taint_sources()` returns multiple sources (e.g., both the action and context contain tainted content), backends pass the entire list to `SecLevel.tainted_by()`. This ensures all taint sources are tracked, providing comprehensive taint attribution. + +**Benefits of Multiple Source Tracking**: +- **Complete attribution**: All sources that influenced the generation are tracked +- **Better debugging**: Can identify all tainted inputs that contributed to output +- **More accurate security**: No information loss about taint origins + +**Note**: The implementation focuses on **taint preservation** and **complete attribution**. All taint sources are tracked, ensuring the security model has full visibility into what influenced the generated content. + +## Taint Source Analysis + +The `taint_sources()` function analyzes both action and context because **context directly influences model generation**: + +1. **Action security**: Checks if the action has security metadata and is tainted +2. **Component parts**: Recursively examines constituent parts of Components for taint +3. **Context security**: Examines recent context items for tainted content (shallow check) + +**Example**: Even if the current action is safe, tainted context can influence the generated output. + +```python +from mellea.security import SecLevel + +# User sends tainted input +user_input = CBlock("Tell me how to hack a system", sec_level=SecLevel.tainted_by(None)) +ctx = ctx.add(user_input) + +# Safe action in tainted context +safe_action = CBlock("Explain general security concepts") + +# Generation finds tainted context +sources = taint_sources(safe_action, ctx) # Finds tainted user_input +# Model output will be influenced by the tainted context +``` + +## Security Metadata + +The `SecurityMetadata` class wraps `SecLevel` for integration with content blocks: + +```python +class SecurityMetadata: + def __init__(self, sec_level: SecLevel): + self.sec_level = sec_level + + def is_tainted(self) -> bool: + return self.sec_level.is_tainted() + + def get_taint_sources(self) -> list[CBlock | Component]: + return self.sec_level.get_taint_sources() +``` + +Content can be marked as tainted at construction time: + +```python +from mellea.security import SecLevel + +c = CBlock("user input", sec_level=SecLevel.tainted_by(None)) + +if c.sec_level and c.sec_level.is_tainted(): + taint_sources = c.sec_level.get_taint_sources() + print(f"Content tainted by: {taint_sources}") +``` + +## Key Features + +- **Immutable security**: security levels set at construction time +- **Recursive taint analysis**: deep analysis of Component parts, shallow analysis of context +- **Taint source tracking**: know exactly which CBlock/Component tainted content +- **Capability integration**: fine-grained access control for classified content +- **Non-mutating operations**: sanitize/declassify create new objects + +This creates a security model that addresses both data exfiltration and injection vulnerabilities while enabling future IAM integration. \ No newline at end of file diff --git a/docs/examples/security/taint_example.py b/docs/examples/security/taint_example.py new file mode 100644 index 000000000..850454acb --- /dev/null +++ b/docs/examples/security/taint_example.py @@ -0,0 +1,46 @@ +from mellea.stdlib.components import CBlock +from mellea.stdlib.session import start_session +from mellea.security import SecLevel, privileged, SecurityError + +# Create tainted content +tainted_desc = CBlock( + "Process this sensitive user data", sec_level=SecLevel.tainted_by(None) +) + +print( + f"Original CBlock is tainted: {tainted_desc.sec_level.is_tainted() if tainted_desc.sec_level else False}" +) + +# Create session +session = start_session() + +# Use tainted CBlock in session.instruct +print("Testing session.instruct with tainted CBlock...") +result = session.instruct(description=tainted_desc) + +# The result should be tainted +print( + f"Result is tainted: {result.sec_level.is_tainted() if result.sec_level else False}" +) +if result.sec_level and result.sec_level.is_tainted(): + taint_sources = result.sec_level.get_taint_sources() + print(f"Taint sources: {taint_sources}") + print("✅ SUCCESS: Taint preserved!") +else: + print("❌ FAIL: Result should be tainted but isn't!") + + +# Mock privileged function that requires un-tainted input +@privileged +def process_un_tainted_data(data: CBlock) -> str: + """A function that requires un-tainted input.""" + return f"Processed: {data.value}" + + +print("\nTesting privileged function with tainted result...") +try: + # This should raise a SecurityError + processed = process_un_tainted_data(result) + print("❌ FAIL: Should have raised SecurityError!") +except SecurityError as e: + print(f"✅ SUCCESS: SecurityError raised - {e}") diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index c7886e7e2..012340852 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -43,6 +43,7 @@ from ..core.base import AbstractMelleaTool from ..formatters import ChatFormatter, TemplateFormatter, granite as granite_formatters from ..helpers import message_to_openai_message, messages_to_docs, send_to_queue +from ..security import SecLevel, taint_sources from ..stdlib.components import Intrinsic, Message from ..stdlib.requirements import ALoraRequirement, LLMaJRequirement from ..telemetry.backend_instrumentation import ( @@ -514,8 +515,13 @@ async def _generate_from_intrinsic( other_input, ) - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk(value=None, sec_level=sec_level, meta={}) output._start = datetime.datetime.now() + output._context = ctx.view_for_generation() output._action = action output._model_options = model_options @@ -789,6 +795,11 @@ async def _generate_from_context_with_kv_cache( output = ModelOutputThunk(None) output._start = datetime.datetime.now() + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk(value=None, sec_level=sec_level, meta={}) output._context = ctx.view_for_generation() output._action = action output._model_options = model_options @@ -934,6 +945,11 @@ async def _generate_from_context_standard( output = ModelOutputThunk(None) output._start = datetime.datetime.now() + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk(value=None, sec_level=sec_level, meta={}) output._context = ctx.view_for_generation() output._action = action output._model_options = model_options @@ -1251,8 +1267,12 @@ async def generate_from_raw( for i, decoded_result in enumerate(decoded_results): n_prompt_tokens = inputs["input_ids"][i].size(0) # type: ignore n_completion_tokens = len(sequences_to_decode[i]) + sources = taint_sources(actions[i], ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + result = ModelOutputThunk( value=decoded_result, + sec_level=sec_level, meta={ "usage": { "prompt_tokens": n_prompt_tokens, # type: ignore diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py index 4f5bac389..d95083070 100644 --- a/mellea/backends/litellm.py +++ b/mellea/backends/litellm.py @@ -33,6 +33,7 @@ message_to_openai_message, send_to_queue, ) +from ..security import SecLevel, taint_sources from ..stdlib.components import Message from ..stdlib.requirements import ALoraRequirement from ..telemetry.backend_instrumentation import ( @@ -329,7 +330,11 @@ async def _generate_from_chat_context_standard( **model_specific_options, ) - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk(value=None, sec_level=sec_level, meta={}) output._start = datetime.datetime.now() output._context = linearized_context output._action = action @@ -624,16 +629,22 @@ async def generate_from_raw( ) for res, action, prompt in zip(responses, actions, prompts): - output = ModelOutputThunk(res.text) # type: ignore + sources = taint_sources(action, None) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk( + value=res.text, # type: ignore + sec_level=sec_level, + meta={ + "litellm_chat_response": res.model_dump(), + "usage": completion_response.usage.model_dump() + if completion_response.usage + else None, + }, + ) output._context = None # There is no context for generate_from_raw for now output._action = action output._model_options = model_opts - output._meta = { - "litellm_chat_response": res.model_dump(), - "usage": completion_response.usage.model_dump() - if completion_response.usage - else None, - } output.parsed_repr = ( action.parse(output) if isinstance(action, Component) else output.value diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index afe80ec84..a27fe1098 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -25,6 +25,7 @@ from ..core.base import AbstractMelleaTool from ..formatters import ChatFormatter, TemplateFormatter from ..helpers import ClientCache, get_current_event_loop, send_to_queue +from ..security import SecLevel, taint_sources from ..stdlib.components import Message from ..stdlib.requirements import ALoraRequirement from ..telemetry.backend_instrumentation import ( @@ -364,7 +365,11 @@ async def generate_from_chat_context( format=_format.model_json_schema() if _format is not None else None, # type: ignore ) # type: ignore - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk(value=None, sec_level=sec_level, meta={}) output._start = datetime.datetime.now() output._context = linearized_context output._action = action @@ -470,16 +475,20 @@ async def generate_from_raw( for i, response in enumerate(responses): result = None error = None + sources = taint_sources(actions[i], None) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + if isinstance(response, BaseException): FancyLogger.get_logger().warning( f"generate_from_raw: request {i} failed with " f"{type(response).__name__}: {response}" ) - result = ModelOutputThunk(value="") + result = ModelOutputThunk(value="", sec_level=sec_level, meta={}) error = response else: result = ModelOutputThunk( value=response.response, + sec_level=sec_level, meta={ "generate_response": response.model_dump(), "usage": { diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index faadfc454..daa43eb82 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -42,6 +42,7 @@ messages_to_docs, send_to_queue, ) +from ..security import SecLevel, taint_sources from ..stdlib.components import Intrinsic, Message from ..stdlib.requirements import LLMaJRequirement from ..telemetry.backend_instrumentation import ( @@ -471,7 +472,11 @@ async def _generate_from_chat_context_standard( ), ) # type: ignore - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk(value=None, sec_level=sec_level, meta={}) output._start = datetime.datetime.now() output._context = linearized_context output._action = action @@ -746,16 +751,22 @@ async def generate_from_raw( for response, action, prompt in zip( completion_response.choices, actions, prompts ): - output = ModelOutputThunk(response.text) + sources = taint_sources(action, None) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk( + value=response.text, + sec_level=sec_level, + meta={ + "oai_completion_response": response.model_dump(), + "usage": completion_response.usage.model_dump() + if completion_response.usage + else None, + }, + ) output._context = None # There is no context for generate_from_raw for now output._action = action output._model_options = model_opts - output._meta = { - "oai_completion_response": response.model_dump(), - "usage": completion_response.usage.model_dump() - if completion_response.usage - else None, - } output.parsed_repr = ( action.parse(output) if isinstance(action, Component) else output.value diff --git a/mellea/backends/vllm.py b/mellea/backends/vllm.py index 9ceedd651..0e298898d 100644 --- a/mellea/backends/vllm.py +++ b/mellea/backends/vllm.py @@ -44,6 +44,7 @@ from ..core.base import AbstractMelleaTool from ..formatters import ChatFormatter, TemplateFormatter from ..helpers import get_current_event_loop, send_to_queue +from ..security import SecLevel, taint_sources from .backend import FormatterBackend from .model_options import ModelOption from .tools import ( @@ -325,7 +326,11 @@ async def _generate_from_context_standard( # stream = model_options.get(ModelOption.STREAM, False) # if stream: - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk(value=None, sec_level=sec_level, meta={}) output._start = datetime.datetime.now() generator = self._model.generate( # type: ignore @@ -486,7 +491,11 @@ async def generate(prompt, request_id): tasks = [generate(p, f"{id(prompts)}-{i}") for i, p in enumerate(prompts)] decoded_results = await asyncio.gather(*tasks) - results = [ModelOutputThunk(value=text) for text in decoded_results] + results = [] + for i, text in enumerate(decoded_results): + sources = taint_sources(actions[i], ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + results.append(ModelOutputThunk(value=text, sec_level=sec_level, meta={})) for i, result in enumerate(results): date = datetime.datetime.now() diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py index d6ca943e9..bb4829c32 100644 --- a/mellea/backends/watsonx.py +++ b/mellea/backends/watsonx.py @@ -36,6 +36,7 @@ get_current_event_loop, send_to_queue, ) +from ..security import SecLevel, taint_sources from ..stdlib.components import Message from ..stdlib.requirements import ALoraRequirement from ..telemetry.backend_instrumentation import ( @@ -368,7 +369,11 @@ async def generate_from_chat_context( ), ) - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk(value=None, sec_level=sec_level, meta={}) output._start = datetime.datetime.now() output._context = linearized_context output._action = action @@ -610,8 +615,12 @@ async def generate_from_raw( for i, response in enumerate(responses): output = response["results"][0] + sources = taint_sources(actions[i], ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + result = ModelOutputThunk( value=output["generated_text"], + sec_level=sec_level, meta={ "oai_completion_response": response["results"][0], "usage": { diff --git a/mellea/core/base.py b/mellea/core/base.py index a5c6bc1f8..07b5582cd 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -20,6 +20,7 @@ from ..plugins.manager import has_plugins, invoke_hook from ..plugins.types import HookType +from ..security import SecLevel, TaintChecking class CBlock: @@ -29,6 +30,7 @@ def __init__( self, value: str | None, meta: dict[str, Any] | None = None, + sec_level: Any = None, *, cache: bool = False, ): @@ -37,6 +39,7 @@ def __init__( Args: value: the underlying value stored in this CBlock meta: Any meta-information about this CBlock (e.g., the inference engine's Completion object). + sec_level: Optional SecLevel for security metadata cache: If set to `True` then this CBlock's KV cache might be stored by the inference engine. Experimental. """ if value is not None and not isinstance(value, str): @@ -47,6 +50,9 @@ def __init__( meta = {} self._meta = meta + # Store security level directly + self._sec_level: SecLevel | None = sec_level + @property def value(self) -> str | None: """Gets the value of the block.""" @@ -65,6 +71,15 @@ def __repr__(self): """Provides a python-parsable representation of the block (usually).""" return f"CBlock({self.value}, {self._meta.__repr__()})" + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this CBlock. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level + class ImageBlock(CBlock): """A `ImageBlock` represents an image (as base64 PNG).""" @@ -137,7 +152,7 @@ class ComponentParseError(Exception): @runtime_checkable -class Component(Protocol, Generic[S]): +class Component(TaintChecking, Protocol, Generic[S]): """A `Component` is a composite data structure that is intended to be represented to an LLM.""" def parts(self) -> list[Component | CBlock]: @@ -151,6 +166,15 @@ def format_for_llm(self) -> TemplateRepresentation | str: """ raise NotImplementedError("format_for_llm isn't implemented by default") + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return None + def parse(self, computed: ModelOutputThunk) -> S: """Parse the expected type from a given `ModelOutputThunk`. @@ -183,9 +207,10 @@ def __init__( meta: dict[str, Any] | None = None, parsed_repr: S | None = None, tool_calls: dict[str, ModelToolCall] | None = None, + sec_level: Any = None, ): """Initializes as a cblock, optionally also with a parsed representation from an output formatter.""" - super().__init__(value, meta) + super().__init__(value, meta, sec_level=sec_level) self.parsed_repr: S | None = parsed_repr """Will be non-`None` once computed.""" diff --git a/mellea/core/requirement.py b/mellea/core/requirement.py index 9162c1fca..d0f612674 100644 --- a/mellea/core/requirement.py +++ b/mellea/core/requirement.py @@ -4,6 +4,7 @@ from collections.abc import Callable from copy import copy +from ..security import SecLevel from .backend import Backend, BaseModelSubclass from .base import CBlock, Component, Context, ModelOutputThunk, TemplateRepresentation @@ -112,6 +113,7 @@ def __init__( # Used for validation. Do not manually populate. self._output: str | None = None + self._sec_level: SecLevel | None = None async def validate( self, @@ -149,6 +151,15 @@ async def validate( context=val_ctx, ) + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level + def parts(self): """Returns all of the constituent parts of a Requirement.""" return [] diff --git a/mellea/security/__init__.py b/mellea/security/__init__.py new file mode 100644 index 000000000..52b735a11 --- /dev/null +++ b/mellea/security/__init__.py @@ -0,0 +1,27 @@ +"""Security module for mellea. + +This module provides security features for tracking and managing the security +level of content blocks and components in the mellea library. +""" + +from .core import ( + AccessType, + SecLevel, + SecurityError, + TaintChecking, + classified_sources, + declassify, + privileged, + taint_sources, +) + +__all__ = [ + "AccessType", + "SecLevel", + "SecurityError", + "TaintChecking", + "classified_sources", + "declassify", + "privileged", + "taint_sources", +] diff --git a/mellea/security/core.py b/mellea/security/core.py new file mode 100644 index 000000000..26b994cd9 --- /dev/null +++ b/mellea/security/core.py @@ -0,0 +1,335 @@ +"""Core security functionality for mellea. + +This module provides the fundamental security classes and functions for +tracking security levels of content blocks and enforcing security policies. +""" + +import abc +import functools +from collections.abc import Callable +from enum import Enum +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Protocol, + TypeVar, + Union, + runtime_checkable, +) + +if TYPE_CHECKING: + from ..core.base import CBlock, Component + +T = TypeVar("T") + + +class SecLevelType(str, Enum): + """Security level type constants.""" + + NONE = "none" + CLASSIFIED = "classified" + TAINTED_BY = "tainted_by" + + +class AccessType(Generic[T], abc.ABC): + """Abstract base class for access-based security. + + This trait allows integration with IAM systems and provides fine-grained + access control based on entitlements rather than coarse security levels. + """ + + @abc.abstractmethod + def has_access(self, entitlement: T | None) -> bool: + """Check if the given entitlement has access. + + Args: + entitlement: The entitlement to check (e.g., user role, IAM identifier) + + Returns: + True if the entitlement has access, False otherwise + """ + + +class SecLevel(Generic[T]): + """Security level with access-based control and taint tracking. + + SecLevel := None | Classified of AccessType | TaintedBy of (list[CBlock | Component] | None) + """ + + def __init__(self, level_type: SecLevelType | str, data: Any = None): + """Initialize security level. + + Args: + level_type: Type of security level (SecLevelType enum or string) + data: Associated data (AccessType for classified, list[CBlock|Component] for tainted_by) + """ + # Convert string to enum if needed for backward compatibility + if isinstance(level_type, str): + level_type = SecLevelType(level_type) + self.level_type = level_type + self.data = data + + @classmethod + def none(cls) -> "SecLevel": + """Create a SecLevel with no restrictions (safe).""" + return cls(SecLevelType.NONE) + + @classmethod + def classified(cls, access_type: AccessType[T]) -> "SecLevel": + """Create a SecLevel with classified access requirements.""" + return cls(SecLevelType.CLASSIFIED, access_type) + + @classmethod + def tainted_by( + cls, sources: "CBlock | Component | list[CBlock | Component] | None" + ) -> "SecLevel": + """Create a SecLevel tainted by one or more CBlocks or Components. + + Args: + sources: A single CBlock/Component, a list of CBlocks/Components, or None for root nodes. + If a single source is provided, it will be converted to a list internally. + + Returns: + SecLevel with TAINTED_BY type + """ + # Normalize to list: convert single source to list, None to empty list + if sources is None: + sources_list: list[CBlock | Component] = [] + elif isinstance(sources, list): + sources_list = sources + else: + sources_list = [sources] + + return cls(SecLevelType.TAINTED_BY, sources_list) + + def is_tainted(self) -> bool: + """Check if this security level represents tainted content. + + Returns: + True if tainted, False otherwise + """ + return self.level_type == SecLevelType.TAINTED_BY + + def is_classified(self) -> bool: + """Check if this security level represents classified content. + + Returns: + True if classified, False otherwise + """ + return self.level_type == SecLevelType.CLASSIFIED + + def get_access_type(self) -> AccessType[T] | None: + """Get the AccessType for classified content. + + Returns: + The AccessType if this is classified, None otherwise + """ + if self.level_type == SecLevelType.CLASSIFIED: + return self.data + return None + + def get_taint_sources(self) -> "list[CBlock | Component]": + """Get all sources of taint if this is a tainted level. + + Returns: + List of CBlocks or Components that tainted this content, empty list if not tainted + """ + if self.level_type == SecLevelType.TAINTED_BY: + if isinstance(self.data, list): + return self.data + # Handle legacy single-source format (shouldn't happen in new code) + return [self.data] if self.data is not None else [] + return [] + + +class SecurityError(Exception): + """Exception raised for security-related errors.""" + + +@runtime_checkable +class TaintChecking(Protocol): + """Protocol for objects that can provide security level information. + + This protocol allows uniform access to security levels without + relying on hasattr checks or _meta dictionary access. + """ + + @property + def sec_level(self) -> "SecLevel | None": + """Get the security level for this object. + + Returns: + SecLevel if present, None otherwise + """ + ... + + +def _collect_sources_by_predicate( + action: "Component | CBlock", ctx: Any, predicate: Callable[["SecLevel"], bool] +) -> "list[CBlock | Component]": + """Recursively collect CBlocks/Components whose sec_level satisfies predicate. + + Shared logic for taint_sources and classified_sources. Walks action and + context (shallow), recursing into Component.parts(). + """ + from ..core.base import ( + CBlock, + Component, + ) # Import here to avoid circular dependency + + sources = [] + + if isinstance(action, TaintChecking): + sec_level = action.sec_level + if sec_level is not None and predicate(sec_level): + sources.append(action) + + match action: + case CBlock(): + pass + case _ if isinstance(action, Component): + parts = action.parts() + for part in parts: + if isinstance(part, TaintChecking): + sec_level = part.sec_level + if sec_level is not None and predicate(sec_level): + sources.append(part) + if isinstance(part, Component): + nested = _collect_sources_by_predicate(part, None, predicate) + sources.extend(nested) + + if hasattr(ctx, "as_list"): + try: + context_items = ctx.as_list(last_n_components=5) + for item in context_items: + if isinstance(item, CBlock | Component) and isinstance( + item, TaintChecking + ): + sec_level = item.sec_level + if sec_level is not None and predicate(sec_level): + sources.append(item) + if isinstance(item, Component): + nested = _collect_sources_by_predicate(item, None, predicate) + sources.extend(nested) + except Exception: + pass + + return sources + + +def taint_sources(action: "Component | CBlock", ctx: Any) -> "list[CBlock | Component]": + """Compute taint sources from action and context. + + This function examines the action and context to determine what + security sources might be present. It performs recursive analysis + of Component parts and shallow analysis of context to identify + potential taint sources and returns the actual objects that are tainted. + + Args: + action: The action component or content block + ctx: The context containing previous interactions + + Returns: + List of tainted CBlocks or Components + """ + return _collect_sources_by_predicate(action, ctx, lambda sec: sec.is_tainted()) + + +def classified_sources( + action: "Component | CBlock", ctx: Any = None +) -> "list[CBlock | Component]": + """Compute classified sources from action and context. + + Recursively examines the action and context (same structure as + taint_sources) and returns all CBlocks or Components that have + classified security level. + + Args: + action: The action component or content block + ctx: Optional context containing previous interactions (shallow scan) + + Returns: + List of classified CBlocks or Components + """ + return _collect_sources_by_predicate(action, ctx, lambda sec: sec.is_classified()) + + +F = TypeVar("F", bound=Callable[..., Any]) + + +def _raise_if_privilege_violation( + obj: Any, func_name: str, arg_name: str | None = None +) -> None: + """Raise SecurityError if obj or any nested part is tainted or classified. + + Uses taint_sources() and classified_sources() for recursive detection. + """ + suffix = f" in argument '{arg_name}'" if arg_name else "" + + sources = taint_sources(obj, None) + if sources: + source_names = ", ".join(type(s).__name__ for s in sources) + raise SecurityError( + f"Function {func_name} requires safe input, but received " + f"tainted content (tainted by: {source_names}){suffix}" + ) + + sources = classified_sources(obj, None) + if sources: + source_names = ", ".join(type(s).__name__ for s in sources) + raise SecurityError( + f"Function {func_name} requires safe input, but received " + f"classified content (sources: {source_names}){suffix}" + ) + + +def privileged(func: F) -> F: + """Decorator to mark functions that require safe (non-tainted, non-classified) input. + + Functions decorated with @privileged will raise SecurityError if + called with tainted or classified content blocks. Checks are performed + recursively: if any argument is a Component, its parts (and their parts) + are also checked for taint or classified content. + + Args: + func: The function to decorate + + Returns: + The decorated function + + Raises: + SecurityError: If the function is called with tainted or classified content + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + for arg in args: + _raise_if_privilege_violation(arg, func.__name__, None) + for key, value in kwargs.items(): + _raise_if_privilege_violation(value, func.__name__, key) + return func(*args, **kwargs) + + return wrapper # type: ignore + + +def declassify(cblock: "CBlock") -> "CBlock": + """Create a declassified version of a CBlock (non-mutating). + + This function creates a new CBlock with the same content but marked + as safe (SecLevel.none()). The original CBlock is not modified. + + Args: + cblock: The CBlock to declassify + + Returns: + A new CBlock with safe security level + """ + from ..core.base import CBlock # Import here to avoid circular dependency + + # Return new CBlock with same content but safe security metadata + return CBlock( + cblock.value, + cblock._meta.copy() if cblock._meta else None, + sec_level=SecLevel.none(), + ) diff --git a/mellea/stdlib/components/chat.py b/mellea/stdlib/components/chat.py index 22c11369d..450e639bf 100644 --- a/mellea/stdlib/components/chat.py +++ b/mellea/stdlib/components/chat.py @@ -12,6 +12,7 @@ ModelToolCall, TemplateRepresentation, ) +from ...security import SecLevel from .docs.document import Document @@ -45,6 +46,7 @@ def __init__( self._content_cblock = CBlock(self.content) self._images = images self._docs = documents + self._sec_level: SecLevel | None = None @property def images(self) -> None | list[str]: @@ -53,6 +55,15 @@ def images(self) -> None | list[str]: return [str(i) for i in self._images] return None + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level + def parts(self) -> list[Component | CBlock]: """Returns all of the constituent parts of an Instruction.""" parts: list[Component | CBlock] = [self._content_cblock] diff --git a/mellea/stdlib/components/docs/document.py b/mellea/stdlib/components/docs/document.py index 577a6639a..d59dc3b72 100644 --- a/mellea/stdlib/components/docs/document.py +++ b/mellea/stdlib/components/docs/document.py @@ -1,6 +1,7 @@ """Document component.""" from ....core import CBlock, Component, ModelOutputThunk +from ....security import SecLevel # TODO: Add support for passing in docs as model options. @@ -12,6 +13,12 @@ def __init__(self, text: str, title: str | None = None, doc_id: str | None = Non self.text = text self.title = title self.doc_id = doc_id + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component.""" + return self._sec_level def parts(self) -> list[Component | CBlock]: """The set of all the constituent parts of the `Component`.""" diff --git a/mellea/stdlib/components/docs/richdocument.py b/mellea/stdlib/components/docs/richdocument.py index 67bcb01fa..9210c8ac4 100644 --- a/mellea/stdlib/components/docs/richdocument.py +++ b/mellea/stdlib/components/docs/richdocument.py @@ -13,6 +13,7 @@ from ....backends.tools import MelleaTool from ....core import CBlock, Component, ModelOutputThunk, TemplateRepresentation +from ....security import SecLevel from ..mobject import MObject, Query, Transform @@ -25,6 +26,16 @@ class RichDocument(Component[str]): def __init__(self, doc: DoclingDocument): """A `RichDocument` is a block of content with an underlying DoclingDocument.""" self._doc = doc + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level def parts(self) -> list[Component | CBlock]: """RichDocument has no parts. diff --git a/mellea/stdlib/components/genslot.py b/mellea/stdlib/components/genslot.py index eff9ae753..751edaa89 100644 --- a/mellea/stdlib/components/genslot.py +++ b/mellea/stdlib/components/genslot.py @@ -24,6 +24,7 @@ TemplateRepresentation, ValidationResult, ) +from ...security import SecLevel from ..requirements.requirement import reqify from ..session import MelleaSession @@ -289,6 +290,16 @@ def __init__(self, func: Callable[P, R]): # Set when calling the decorated func. self.precondition_requirements: list[Requirement] = [] self.requirements: list[Requirement] = [] + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level @abc.abstractmethod def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: diff --git a/mellea/stdlib/components/instruction.py b/mellea/stdlib/components/instruction.py index 288d9e2d8..f7f4fe089 100644 --- a/mellea/stdlib/components/instruction.py +++ b/mellea/stdlib/components/instruction.py @@ -15,6 +15,7 @@ TemplateRepresentation, blockify, ) +from ...security import SecLevel from ..requirements.requirement import reqify @@ -126,6 +127,12 @@ def __init__( ) self._images = images self._repair_string: str | None = None + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component.""" + return self._sec_level def parts(self): """Returns all of the constituent parts of an Instruction.""" diff --git a/mellea/stdlib/components/intrinsic/intrinsic.py b/mellea/stdlib/components/intrinsic/intrinsic.py index c12fa54fe..1731bca37 100644 --- a/mellea/stdlib/components/intrinsic/intrinsic.py +++ b/mellea/stdlib/components/intrinsic/intrinsic.py @@ -2,6 +2,7 @@ from ....backends.adapters import AdapterType, fetch_intrinsic_metadata from ....core import CBlock, Component, ModelOutputThunk, TemplateRepresentation +from ....security import SecLevel class Intrinsic(Component[str]): @@ -30,6 +31,16 @@ def __init__( if intrinsic_kwargs is None: intrinsic_kwargs = {} self.intrinsic_kwargs = intrinsic_kwargs + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level @property def intrinsic_name(self): diff --git a/mellea/stdlib/components/mify.py b/mellea/stdlib/components/mify.py index c54f16056..92dec0cbf 100644 --- a/mellea/stdlib/components/mify.py +++ b/mellea/stdlib/components/mify.py @@ -13,6 +13,7 @@ ModelOutputThunk, TemplateRepresentation, ) +from ...security import SecLevel from .mobject import MObjectProtocol, Query, Transform @@ -219,6 +220,15 @@ def parse(self, computed: ModelOutputThunk) -> str: except Exception as e: raise ComponentParseError(f"component parsing failed: {e}") + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return getattr(self, "_sec_level", None) + T = TypeVar("T") @@ -349,6 +359,16 @@ def mification(obj: T) -> T: # For objects, have to specifically bind methods. setattr(obj, name, types.MethodType(func, obj)) + # Add properties from MifiedProtocol (properties are descriptors, not methods) + # Create sec_level property directly to ensure Component protocol compliance + if "sec_level" not in current_members.keys(): + # Create a property that returns _sec_level attribute + sec_level_prop = property( + lambda self: getattr(self, "_sec_level", None), + doc="Get the security level for this Component.", + ) + setattr(obj, "sec_level", sec_level_prop) + # Set the defaults for the object/class. setattr(obj, "_query_type", query_type) setattr(obj, "_transform_type", transform_type) diff --git a/mellea/stdlib/components/mobject.py b/mellea/stdlib/components/mobject.py index b5dc25f87..2b0e0294d 100644 --- a/mellea/stdlib/components/mobject.py +++ b/mellea/stdlib/components/mobject.py @@ -8,6 +8,7 @@ from ...backends.tools import MelleaTool from ...core import CBlock, Component, ModelOutputThunk, TemplateRepresentation +from ...security import SecLevel class Query(Component[str]): @@ -22,6 +23,16 @@ def __init__(self, obj: Component, query: str) -> None: """ self._obj = obj self._query = query + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level def parts(self) -> list[Component | CBlock]: """Get the parts of the query.""" @@ -66,6 +77,16 @@ def __init__(self, obj: Component, transformation: str) -> None: """ self._obj = obj self._transformation = transformation + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level def parts(self) -> list[Component | CBlock]: """Get the parts of the transform.""" @@ -165,6 +186,16 @@ def __init__( """ self._query_type = query_type self._transform_type = transform_type + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level def parts(self) -> list[Component | CBlock]: """MObject has no parts because of how format_for_llm is defined.""" diff --git a/mellea/stdlib/components/simple.py b/mellea/stdlib/components/simple.py index 2d0f7dccc..de05ffea2 100644 --- a/mellea/stdlib/components/simple.py +++ b/mellea/stdlib/components/simple.py @@ -1,6 +1,7 @@ """SimpleComponent.""" from ...core import CBlock, Component, ModelOutputThunk +from ...security import SecLevel class SimpleComponent(Component[str]): @@ -13,6 +14,12 @@ def __init__(self, **kwargs): kwargs[key] = CBlock(value=kwargs[key]) self._kwargs_type_check(kwargs) self._kwargs = kwargs + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component.""" + return self._sec_level def parts(self): """Returns the values of the kwargs.""" @@ -21,9 +28,9 @@ def parts(self): def _kwargs_type_check(self, kwargs): for key in kwargs.keys(): value = kwargs[key] - assert issubclass(type(value), Component) or issubclass( - type(value), CBlock - ), f"Expected span but found {type(value)} of value: {value}" + assert isinstance(value, Component) or isinstance(value, CBlock), ( + f"Expected span but found {type(value)} of value: {value}" + ) assert type(key) is str return True diff --git a/mellea/stdlib/components/unit_test_eval.py b/mellea/stdlib/components/unit_test_eval.py index f7f4bf6e9..31f3572a9 100644 --- a/mellea/stdlib/components/unit_test_eval.py +++ b/mellea/stdlib/components/unit_test_eval.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator from ...core import CBlock, Component, ModelOutputThunk, TemplateRepresentation +from ...security import SecLevel class Message(BaseModel): @@ -63,6 +64,16 @@ def __init__( self.targets = targets or [] self.test_id = test_id self.input_ids = input_ids or [] + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level def parts(self) -> list[Component | CBlock]: """The set of constituent parts of the Component.""" diff --git a/mellea/stdlib/functional.py b/mellea/stdlib/functional.py index 41aa1fbf7..4d0dda34d 100644 --- a/mellea/stdlib/functional.py +++ b/mellea/stdlib/functional.py @@ -116,7 +116,7 @@ def act( @overload def instruct( - description: str, + description: str | CBlock, context: Context, backend: Backend, *, @@ -137,7 +137,7 @@ def instruct( @overload def instruct( - description: str, + description: str | CBlock, context: Context, backend: Backend, *, @@ -157,7 +157,7 @@ def instruct( def instruct( - description: str, + description: str | CBlock, context: Context, backend: Backend, *, diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 784923d9f..ca1d12c2e 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -435,7 +435,7 @@ def act( @overload def instruct( self, - description: str, + description: str | CBlock, *, images: list[ImageBlock] | list[PILImage.Image] | None = None, requirements: list[Requirement | str] | None = None, @@ -454,7 +454,7 @@ def instruct( @overload def instruct( self, - description: str, + description: str | CBlock, *, images: list[ImageBlock] | list[PILImage.Image] | None = None, requirements: list[Requirement | str] | None = None, @@ -472,7 +472,7 @@ def instruct( def instruct( self, - description: str, + description: str | CBlock, *, images: list[ImageBlock] | list[PILImage.Image] | None = None, requirements: list[Requirement | str] | None = None, @@ -490,7 +490,7 @@ def instruct( """Generates from an instruction. Args: - description: The description of the instruction. + description: The description of the instruction (str or CBlock). requirements: A list of requirements that the instruction can be validated against. icl_examples: A list of in-context-learning examples that the instruction can be validated against. grounding_context: A list of grounding contexts that the instruction can use. They can bind as variables using a (key: str, value: str | ContentBlock) tuple. diff --git a/test/stdlib/test_security_comprehensive.py b/test/stdlib/test_security_comprehensive.py new file mode 100644 index 000000000..14bfecbd4 --- /dev/null +++ b/test/stdlib/test_security_comprehensive.py @@ -0,0 +1,507 @@ +"""Comprehensive security tests for mellea thread security features.""" + +import pytest + +from mellea.security import ( + AccessType, + SecLevel, + SecurityError, + declassify, + privileged, + taint_sources, +) +from mellea.stdlib.components import CBlock, ModelOutputThunk, SimpleComponent +from mellea.stdlib.components.instruction import Instruction +from mellea.stdlib.context import ChatContext + + +class TestAccessType: + """Test AccessType functionality.""" + + def test_access_type_interface(self): + """Test that AccessType is an abstract base class.""" + with pytest.raises(TypeError): + AccessType() # Should not be instantiable directly + + def test_access_type_implementation(self): + """Test implementing AccessType.""" + + class TestAccess(AccessType[str]): + def has_access(self, entitlement: str | None) -> bool: + return entitlement == "admin" + + access = TestAccess() + assert access.has_access("admin") + assert not access.has_access("user") + assert not access.has_access(None) + + +class TestSecLevel: + """Test SecLevel functionality.""" + + def test_sec_level_none(self): + """Test SecLevel.none() creates safe level.""" + from mellea.security.core import SecLevelType + + sec_level = SecLevel.none() + assert sec_level.level_type == SecLevelType.NONE + assert not sec_level.is_tainted() + assert not sec_level.is_classified() + assert sec_level.get_access_type() is None + + def test_sec_level_tainted_by(self): + """Test SecLevel.tainted_by() creates tainted level.""" + from mellea.security.core import SecLevelType + + source = CBlock("source content") + sec_level = SecLevel.tainted_by(source) + assert sec_level.level_type == SecLevelType.TAINTED_BY + assert sec_level.is_tainted() + assert not sec_level.is_classified() + assert sec_level.get_taint_sources() == [source] + assert sec_level.get_access_type() is None + + def test_sec_level_classified(self): + """Test SecLevel.classified() creates classified level.""" + from mellea.security.core import SecLevelType + + class TestAccess(AccessType[str]): + def has_access(self, entitlement: str | None) -> bool: + return entitlement == "admin" + + access = TestAccess() + sec_level = SecLevel.classified(access) + assert sec_level.level_type == SecLevelType.CLASSIFIED + assert not sec_level.is_tainted() + assert sec_level.is_classified() + assert sec_level.get_access_type() is access + assert sec_level.get_access_type().has_access("admin") + assert not sec_level.get_access_type().has_access("user") + assert not sec_level.get_access_type().has_access(None) + + +class TestCBlockSecurity: + """Test CBlock security functionality.""" + + def test_cblock_mark_tainted(self): + """Test marking CBlock as tainted.""" + cblock = CBlock("test content", sec_level=SecLevel.tainted_by(None)) + + assert cblock.sec_level is not None + assert cblock.sec_level.is_tainted() + assert not cblock.sec_level.is_classified() + assert cblock.sec_level.get_access_type() is None + + def test_cblock_mark_tainted_by_source(self): + """Test marking CBlock as tainted by another source.""" + source = CBlock("source content") + cblock = CBlock("test content", sec_level=SecLevel.tainted_by(source)) + + assert cblock.sec_level.is_tainted() + assert cblock.sec_level.get_taint_sources() == [source] + + def test_cblock_default_safe(self): + """Test that CBlock defaults to safe when no security metadata.""" + cblock = CBlock("test content") + assert cblock.sec_level is None or ( + not cblock.sec_level.is_tainted() and not cblock.sec_level.is_classified() + ) + + def test_cblock_with_classified_metadata(self): + """Test CBlock with classified security metadata.""" + + class TestAccess(AccessType[str]): + def has_access(self, entitlement: str | None) -> bool: + return entitlement == "admin" + + access = TestAccess() + sec_level = SecLevel.classified(access) + + cblock = CBlock("classified content", sec_level=sec_level) + + assert cblock.sec_level.is_classified() + access_type = cblock.sec_level.get_access_type() + assert access_type is not None + assert access_type.has_access("admin") + assert not access_type.has_access("user") + assert not access_type.has_access(None) + + +class TestDeclassify: + """Test declassify function.""" + + def test_declassify_creates_new_object(self): + """Test that declassify creates a new object without mutating original.""" + from mellea.security.core import SecLevelType + + original = CBlock("test content", sec_level=SecLevel.tainted_by(None)) + + declassified = declassify(original) + + # Objects are different + assert original is not declassified + assert id(original) != id(declassified) + + # Content is preserved + assert original.value == declassified.value + + # Security levels are different + assert original.sec_level.is_tainted() + assert not declassified.sec_level.is_tainted() + assert not declassified.sec_level.is_classified() + assert declassified.sec_level.level_type == SecLevelType.NONE + + # Original is unchanged + assert original.sec_level.is_tainted() + + def test_declassify_preserves_other_metadata(self): + """Test that declassify preserves other metadata.""" + from mellea.security.core import SecLevelType + + original = CBlock( + "test content", + meta={"custom": "value", "other": 123}, + sec_level=SecLevel.tainted_by(None), + ) + + declassified = declassify(original) + + assert declassified._meta["custom"] == "value" + assert declassified._meta["other"] == 123 + assert declassified.sec_level.level_type == SecLevelType.NONE + + +class TestPrivilegedDecorator: + """Test @privileged decorator functionality.""" + + def test_privileged_accepts_safe_input(self): + """Test that privileged functions accept safe input.""" + + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + # CBlock with no security metadata defaults to safe + safe_cblock = CBlock("safe content") + + result = safe_function(safe_cblock) + assert result == "Processed: safe content" + + def test_privileged_accepts_declassified_input(self): + """Test that privileged functions accept declassified input.""" + + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + tainted_cblock = CBlock("tainted content", sec_level=SecLevel.tainted_by(None)) + declassified_cblock = declassify(tainted_cblock) + + result = safe_function(declassified_cblock) + assert result == "Processed: tainted content" + + def test_privileged_rejects_tainted_input(self): + """Test that privileged functions reject tainted input.""" + + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + tainted_cblock = CBlock("tainted content", sec_level=SecLevel.tainted_by(None)) + + with pytest.raises(SecurityError, match="requires safe input"): + safe_function(tainted_cblock) + + def test_privileged_rejects_classified_input(self): + """Test that privileged functions reject classified input without proper entitlement.""" + + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + class TestAccess(AccessType[str]): + def has_access(self, entitlement: str | None) -> bool: + return entitlement == "admin" + + access = TestAccess() + sec_level = SecLevel.classified(access) + + classified_cblock = CBlock("classified content", sec_level=sec_level) + + with pytest.raises(SecurityError, match="requires safe input"): + safe_function(classified_cblock) + + def test_privileged_accepts_no_security_metadata(self): + """Test that privileged functions accept input with no security metadata.""" + + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + # CBlock with no security metadata defaults to safe + cblock = CBlock("content") + + result = safe_function(cblock) + assert result == "Processed: content" + + def test_privileged_with_kwargs(self): + """Test privileged function with keyword arguments.""" + + @privileged + def safe_function(data: CBlock, prefix: str = "Processed: ") -> str: + return f"{prefix}{data.value}" + + tainted_cblock = CBlock("tainted content", sec_level=SecLevel.tainted_by(None)) + + with pytest.raises(SecurityError, match="argument 'data'"): + safe_function(data=tainted_cblock) + + def test_privileged_rejects_nested_tainted_content(self): + """Test that privileged rejects Components containing tainted parts recursively.""" + tainted_inner = CBlock("sensitive", sec_level=SecLevel.tainted_by(None)) + # Component with safe top-level but tainted nested CBlock + nested = SimpleComponent(data=tainted_inner) + instruction = Instruction( + description=CBlock("Process the data"), grounding_context={"ctx": nested} + ) + + @privileged + def safe_function(comp): + return "ok" + + with pytest.raises(SecurityError, match="requires safe input"): + safe_function(instruction) + with pytest.raises(SecurityError, match="tainted content"): + safe_function(instruction) + + def test_privileged_rejects_nested_classified_content(self): + """Test that privileged rejects Components containing classified parts recursively.""" + + class TestAccess(AccessType[str]): + def has_access(self, entitlement: str | None) -> bool: + return entitlement == "admin" + + classified_inner = CBlock("secret", sec_level=SecLevel.classified(TestAccess())) + nested = SimpleComponent(data=classified_inner) + instruction = Instruction( + description=CBlock("Process"), grounding_context={"ctx": nested} + ) + + @privileged + def safe_function(comp): + return "ok" + + with pytest.raises(SecurityError, match="classified content"): + safe_function(instruction) + with pytest.raises(SecurityError, match=r"sources:.*CBlock"): + safe_function(instruction) + + +class TestTaintSources: + """Test taint source computation.""" + + def test_taint_sources_from_tainted_action(self): + """Test taint sources from tainted action.""" + action = CBlock("tainted action", sec_level=SecLevel.tainted_by(None)) + + sources = taint_sources(action, None) + assert len(sources) == 1 + assert sources[0] is action + + def test_taint_sources_from_safe_action(self): + """Test taint sources from safe action.""" + action = CBlock("safe action") + # No security metadata - defaults to safe + + sources = taint_sources(action, None) + assert len(sources) == 0 + + def test_taint_sources_from_context(self): + """Test taint sources from context.""" + action = CBlock("safe action") + + # Create context with tainted content + ctx = ChatContext() + tainted_cblock = CBlock("tainted context", sec_level=SecLevel.tainted_by(None)) + ctx = ctx.add(tainted_cblock) + + sources = taint_sources(action, ctx) + assert len(sources) == 1 + assert sources[0] is tainted_cblock + + def test_taint_sources_empty(self): + """Test taint sources with no tainted content.""" + action = CBlock("safe action") + ctx = ChatContext() + safe_cblock = CBlock("safe context") + # No security metadata - defaults to safe + ctx = ctx.add(safe_cblock) + + sources = taint_sources(action, ctx) + assert len(sources) == 0 + + def test_taint_sources_from_component_parts(self): + """Test taint sources from Component parts.""" + # Create Instruction with tainted description + tainted_desc = CBlock( + "tainted description", sec_level=SecLevel.tainted_by(None) + ) + instruction = Instruction(description=tainted_desc) + + sources = taint_sources(instruction, None) + assert len(sources) == 1 + assert sources[0] is tainted_desc + + def test_taint_sources_from_nested_component_with_tainted_cblocks(self): + """Test taint sources from nested Components containing tainted CBlocks.""" + # Create tainted CBlocks + tainted_data = CBlock( + "sensitive user data", sec_level=SecLevel.tainted_by(None) + ) + tainted_config = CBlock("secret config", sec_level=SecLevel.tainted_by(None)) + safe_info = CBlock("public info") # Safe CBlock + + # Create a SimpleComponent with mixed tainted and safe CBlocks + nested_component = SimpleComponent( + data=tainted_data, config=tainted_config, info=safe_info + ) + + # Create an Instruction with the nested Component in grounding_context + instruction = Instruction( + description="Process the data", + grounding_context={"context": nested_component}, + ) + + # taint_sources should find both tainted CBlocks through the nested Component + sources = taint_sources(instruction, None) + + # Should find both tainted CBlocks + assert len(sources) == 2 + assert tainted_data in sources + assert tainted_config in sources + assert safe_info not in sources # Safe CBlock should not be included + + def test_taint_sources_shallow_search_limit(self): + """Test that shallow search only checks last 5 components.""" + action = CBlock("safe action") + + # Create context with 7 items: tainted at positions 0 and 5 + ctx = ChatContext() + tainted_early = CBlock("tainted early", sec_level=SecLevel.tainted_by(None)) + ctx = ctx.add(tainted_early) # Position 0 - outside last 5 + + # Add 4 safe items + for i in range(4): + ctx = ctx.add(CBlock(f"safe {i}")) + + tainted_late = CBlock("tainted late", sec_level=SecLevel.tainted_by(None)) + ctx = ctx.add(tainted_late) # Position 5 - within last 5 + + # Add one more safe item + ctx = ctx.add(CBlock("safe final")) # Position 6 + + sources = taint_sources(action, ctx) + # Should only find tainted_late (position 5), not tainted_early (position 0) + assert len(sources) == 1 + assert sources[0] is tainted_late + + +class TestModelOutputThunkSecurity: + """Test ModelOutputThunk security functionality.""" + + def test_from_generation_with_taint_sources(self): + """Test ModelOutputThunk creation with taint sources.""" + taint_source = CBlock("taint source", sec_level=SecLevel.tainted_by(None)) + + sec_level = SecLevel.tainted_by([taint_source]) + mot = ModelOutputThunk( + value="generated content", sec_level=sec_level, meta={"custom": "value"} + ) + + assert mot.value == "generated content" + assert mot._meta["custom"] == "value" + assert mot.sec_level is not None + assert mot.sec_level.is_tainted() + assert not mot.sec_level.is_classified() + assert mot.sec_level.get_taint_sources() == [taint_source] + + def test_from_generation_without_taint_sources(self): + """Test ModelOutputThunk creation without taint sources.""" + from mellea.security.core import SecLevelType + + mot = ModelOutputThunk( + value="generated content", + sec_level=SecLevel.none(), + meta={"custom": "value"}, + ) + + assert mot.value == "generated content" + assert mot._meta["custom"] == "value" + assert mot.sec_level is not None + assert mot.sec_level.level_type == SecLevelType.NONE + assert not mot.sec_level.is_tainted() + assert not mot.sec_level.is_classified() + + def test_from_generation_empty_taint_sources(self): + """Test ModelOutputThunk creation with empty taint sources.""" + from mellea.security.core import SecLevelType + + mot = ModelOutputThunk( + value="generated content", + sec_level=SecLevel.none(), + meta={"custom": "value"}, + ) + + assert mot.sec_level.level_type == SecLevelType.NONE + assert not mot.sec_level.is_tainted() + assert not mot.sec_level.is_classified() + + +class TestSecurityIntegration: + """Test integration between security components.""" + + def test_security_flow_through_generation(self): + """Test security metadata flows through generation pipeline.""" + from mellea.security.core import SecLevelType + + # Create tainted input + tainted_input = CBlock("user input", sec_level=SecLevel.tainted_by(None)) + + # Simulate generation with taint sources + sources = taint_sources(tainted_input, None) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + mot = ModelOutputThunk(value="model response", sec_level=sec_level) + + # Verify output is tainted + assert mot.sec_level.is_tainted() + + # Declassify the output + safe_mot = declassify(mot) + assert not safe_mot.sec_level.is_tainted() + assert not safe_mot.sec_level.is_classified() + assert safe_mot.sec_level.level_type == SecLevelType.NONE + + # Verify original is unchanged + assert mot.sec_level.is_tainted() + + def test_privileged_function_with_generated_content(self): + """Test privileged function with generated content.""" + + @privileged + def process_response(mot: ModelOutputThunk) -> str: + return f"Processed: {mot.value}" + + # Generate tainted content + taint_source = CBlock("taint source", sec_level=SecLevel.tainted_by(None)) + + sec_level = SecLevel.tainted_by([taint_source]) + mot = ModelOutputThunk(value="tainted response", sec_level=sec_level) + + # Privileged function should reject tainted content + with pytest.raises(SecurityError): + process_response(mot) + + # Declassify and try again + safe_mot = declassify(mot) + result = process_response(safe_mot) + assert result == "Processed: tainted response"