diff --git a/.claude/skills/pr-review/SKILL.md b/.claude/skills/pr-review/SKILL.md new file mode 100644 index 0000000..42a0add --- /dev/null +++ b/.claude/skills/pr-review/SKILL.md @@ -0,0 +1,129 @@ +--- +name: pr-review +description: Review code changes on the current branch for quality, bugs, performance, and security +disable-model-invocation: true +argument-hint: "[optional: LINEAR-TICKET-ID]" +allowed-tools: Read, Grep, Glob, Bash(git diff:*), Bash(git log:*), Bash(git show:*), Bash(git branch:*), Bash(gh pr:*), Bash(gh api:*), Bash(~/.claude/scripts/fetch-github-pr.sh:*), Bash(~/.claude/scripts/fetch-sentry-data.sh:*), Bash(~/.claude/scripts/fetch-slack-thread.sh:*) +--- + +# Code Review + +You are reviewing code changes on the current branch. Your review must be based on the **current state of the code right now**, not on anything you've seen earlier in this conversation. + +## CRITICAL: Always Use Fresh Data + +**IGNORE any file contents, diffs, or line numbers you may have seen earlier in this conversation.** They may be stale. You MUST re-fetch everything from scratch using the commands below. 
+ +## Step 1: Get the Current Diff and PR Context + +Run ALL of these commands to get a fresh view: + +```bash +# The authoritative diff -- only review what's in HERE +git diff main...HEAD + +# Recent commits on this branch +git log --oneline main..HEAD + +# PR description and comments +gh pr view --json number,title,body,comments,reviews,reviewRequests +``` + +Also fetch PR review comments (inline code comments): + +```bash +# Get the PR number +PR_NUMBER=$(gh pr view --json number -q '.number') + +# Fetch all review comments (inline comments on specific lines) +gh api repos/{owner}/{repo}/pulls/$PR_NUMBER/comments --jq '.[] | {path: .path, line: .line, body: .body, user: .user.login, created_at: .created_at}' + +# Fetch review-level comments (general review comments) +gh api repos/{owner}/{repo}/pulls/$PR_NUMBER/reviews --jq '.[] | {state: .state, body: .body, user: .user.login}' +``` + +## Step 2: Understand Context from PR Comments + +Before reviewing, read through the PR comments and review comments. Note **who** said what (by username). + +- **Already-addressed feedback**: If a reviewer pointed out an issue and the author has already fixed it (the fix is visible in the current diff), do NOT re-raise it. +- **Ongoing discussions**: Note any unresolved threads -- your review should take these into account. +- **Previous approvals/requests for changes**: Understand what reviewers have already looked at. + +**IMPORTANT**: Your review is YOUR independent review. Do not take credit for or reference other reviewers' findings as if they were yours. If another reviewer already flagged something, you can note "as [reviewer] pointed out" but do not present their feedback as your own prior review. Your verdict should be based solely on your own analysis of the current code. + +## Step 3: Get Requirements Context + +Check if a Linear ticket ID was provided as an argument ($ARGUMENTS). 
If not, try to extract it from the branch name (pattern: `{username}/{linear-ticket}-{title}`). + +If a Linear ticket is found: +- Use Linear MCP tools (`get_issue`) to get the issue details and comments +- **Check for a parent ticket**: If the issue has a parent issue, fetch the parent too. Our pattern is to have a parent ticket with project-wide requirements and sub-tickets for specific tasks (often one per repo/PR). The parent ticket will contain the full scope of the project, while the sub-ticket scopes what this specific PR should cover. Use both to assess completeness — the PR should fulfill the sub-ticket's scope, and that scope should be a reasonable subset of the parent's backend-related requirements. +- Look for Sentry links in the description/comments; if found, use Sentry MCP tools to get error details +- Assess whether the changes fulfill the ticket requirements + +If no ticket is found, check the PR description for context on what the changes are meant to accomplish. + +## Step 4: Review the Code + +Review ONLY the changed lines (from `git diff main...HEAD`). Do not comment on unchanged code. + +**When referencing code, always use the file path and quote the actual code snippet** rather than citing line numbers, since line numbers shift as the branch evolves. + +### Code Quality +- Is the code well-structured and maintainable? +- Does it follow CLAUDE.md conventions? (import grouping, error handling with lib/errors, naming, alphabetization, etc.) +- Any AI-generated slop? 
(excessive comments, unnecessary abstractions, over-engineering) + +### Performance +- N+1 queries, inefficient loops, missing indexes for new queries +- Unbuffered writes in hot paths (especially ClickHouse) +- Missing LIMIT clauses on potentially large result sets + +### Bugs +- Nil pointer risks (especially on struct pointer params and optional relations) +- Functions returning `nil, nil` (violates convention) +- Missing error handling +- Race conditions in concurrent code paths + +### Security +- Hardcoded secrets or sensitive data exposure +- Missing input validation on service request structs + +### Tests +- Are there tests for the new/changed code? +- Do the tests cover edge cases and error paths? +- Are test assertions specific (not just "no error")? + +## Step 5: Present the Review + +Structure your review as: + +``` +## Summary +[1-2 sentences: what this PR does and overall assessment] + +## Requirements Check +[Does the PR fulfill the Linear ticket / PR description requirements? Any gaps?] + +## Issues +### Critical (must fix before merge) +- [blocking issues] + +### Suggestions (nice to have) +- [non-blocking improvements] + +## Prior Review Activity +[Summarize what other reviewers have flagged, attributed by name. Note which of their concerns have been addressed in the current code and which remain open.] + +## Verdict +[LGTM / Needs changes / Needs discussion -- based on YOUR analysis, not other reviewers' findings] +``` + +## Guidelines + +- Be concise. Don't pad with praise or filler. +- Only raise issues that matter. Don't nitpick formatting (that's what linters are for). +- Quote code snippets rather than referencing line numbers. +- If PR comments show a discussion was already resolved, don't reopen it. +- If you're unsure about something, flag it as a question rather than a definitive issue. 
diff --git a/.fernignore b/.fernignore index 35c0aa1..a2a4fc3 100644 --- a/.fernignore +++ b/.fernignore @@ -7,14 +7,18 @@ README.md src/schematic/client.py # Additional custom code +.claude/ .github/CODEOWNERS scripts/ -src/schematic/cache.py +src/schematic/cache/ src/schematic/event_buffer.py src/schematic/http_client.py src/schematic/logging.py +src/schematic/datastream/ src/schematic/webhook_utils/ src/schematic/webhooks/verification.py tests/custom/ +tests/datastream/ tests/webhook_utils/ CLAUDE.md +WASM_VERSION diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e4588ce..8baf77a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,5 +1,7 @@ name: ci on: [push] +env: + GH_TOKEN: ${{ secrets.GH_TOKEN }} jobs: compile: runs-on: ubuntu-latest @@ -13,8 +15,10 @@ jobs: - name: Bootstrap poetry run: | curl -sSL https://install.python-poetry.org | python - -y --version 1.5.1 + - name: Download WASM binary + run: ./scripts/download-wasm.sh - name: Install dependencies - run: poetry install + run: poetry install --extras datastream - name: Compile run: poetry run mypy . test: @@ -29,14 +33,42 @@ jobs: - name: Bootstrap poetry run: | curl -sSL https://install.python-poetry.org | python - -y --version 1.5.1 + - name: Download WASM binary + run: ./scripts/download-wasm.sh - name: Install dependencies - run: poetry install + run: poetry install --extras datastream - name: Test run: poetry run pytest -rP -n auto . + verify-package: + runs-on: ubuntu-latest + steps: + - name: Checkout repo + uses: actions/checkout@v4 + - name: Set up python + uses: actions/setup-python@v4 + with: + python-version: 3.9 + - name: Bootstrap poetry + run: | + curl -sSL https://install.python-poetry.org | python - -y --version 1.5.1 + - name: Download WASM binary + run: ./scripts/download-wasm.sh + - name: Build package + run: poetry build + - name: Verify WASM in wheel + run: | + if ! 
zipinfo dist/*.whl | grep -q 'rulesengine.wasm'; then + echo "ERROR: rulesengine.wasm not found in wheel" + echo "Wheel contents:" + zipinfo dist/*.whl + exit 1 + fi + echo "Verified: rulesengine.wasm is included in the wheel" + publish: - needs: [compile, test] + needs: [compile, test, verify-package] if: github.event_name == 'push' && contains(github.ref, 'refs/tags/') runs-on: ubuntu-latest steps: @@ -49,6 +81,8 @@ jobs: - name: Bootstrap poetry run: | curl -sSL https://install.python-poetry.org | python - -y --version 1.5.1 + - name: Download WASM binary + run: ./scripts/download-wasm.sh - name: Install dependencies run: poetry install - name: Publish to pypi diff --git a/.gitignore b/.gitignore index d2e4ca8..9c767ac 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,7 @@ __pycache__/ dist/ poetry.toml + +# WASM binary (downloaded from rulesengine-rust GitHub Releases) +src/schematic/datastream/wasm/rulesengine.wasm +src/schematic/datastream/wasm/.wasm_version diff --git a/README.md b/README.md index bba4b80..17070ee 100644 --- a/README.md +++ b/README.md @@ -285,6 +285,129 @@ client.check_flag( ) ``` +## DataStream + +The DataStream functionality provides real-time updates for flags, companies, and users. It uses WebSocket connections to receive updates from the Schematic backend and evaluates feature flags locally using a WASM rules engine, reducing the number of network calls. + +### Installation + +DataStream requires additional dependencies for WebSocket connections and local flag evaluation. Install them with the `datastream` extra: + +```bash +pip install 'schematichq[datastream]' +# or +poetry add schematichq -E datastream +``` + +### Key Features + +- **Real-Time Updates**: Automatically updates cached data when changes occur on the backend. +- **Local Flag Evaluation**: Flag checks are evaluated locally via WASM, eliminating per-check network requests. +- **Configurable Caching**: Supports both in-memory caching and custom cache providers (e.g. 
Redis). + +### How to Enable DataStream + +To enable DataStream, set `use_datastream=True` in your `AsyncSchematicConfig`: + +```python +import asyncio +from schematic.client import AsyncSchematic, AsyncSchematicConfig, DataStreamConfig + +async def main(): + config = AsyncSchematicConfig( + use_datastream=True, + ) + + async with AsyncSchematic("YOUR_API_KEY", config) as client: + is_enabled = await client.check_flag( + "some-flag-key", + company={"id": "your-company-id"}, + user={"id": "your-user-id"}, + ) + +asyncio.run(main()) +``` + +### Configuring Cache TTL + +You can customize the cache TTL (in milliseconds) via the `DataStreamConfig`: + +```python +config = AsyncSchematicConfig( + use_datastream=True, + datastream=DataStreamConfig( + cache_ttl=300_000, # 5 minutes + ), +) +``` + +### Replicator Mode + +When running the `schematic-datastream-replicator` service, configure the client to operate in Replicator Mode. In this mode, the client skips establishing its own WebSocket connection and instead relies on a shared cache populated by the external replicator service. + +```python +import asyncio +from schematic.client import AsyncSchematic, AsyncSchematicConfig, DataStreamConfig + +async def main(): + config = AsyncSchematicConfig( + use_datastream=True, + datastream=DataStreamConfig( + replicator_mode=True, + ), + ) + + async with AsyncSchematic("YOUR_API_KEY", config) as client: + is_enabled = await client.check_flag( + "some-flag-key", + company={"id": "your-company-id"}, + ) + +asyncio.run(main()) +``` + +#### Cache TTL Configuration + +When using Replicator Mode, you should set the SDK's cache TTL to match the replicator's cache TTL. The replicator defaults to an unlimited cache TTL. If the SDK uses a shorter TTL (the default is 24 hours), locally updated cache entries will be written back with the shorter TTL and eventually evicted from the shared cache. 
+ +To match the replicator's default unlimited TTL: + +```python +config = AsyncSchematicConfig( + use_datastream=True, + datastream=DataStreamConfig( + replicator_mode=True, + cache_ttl=None, # Unlimited, matching the replicator default + ), +) +``` + +#### Advanced Configuration + +```python +config = AsyncSchematicConfig( + use_datastream=True, + datastream=DataStreamConfig( + replicator_mode=True, + cache_ttl=None, + replicator_health_url="http://my-replicator:8090/ready", + replicator_health_check=60_000, # 60 seconds, in milliseconds + ), +) +``` + +#### Default Configuration + +- **Replicator Health URL**: `http://localhost:8090/ready` +- **Health Check Interval**: 30 seconds +- **Cache TTL**: 24 hours (SDK default; should be set to match the replicator's TTL, which defaults to unlimited) + +When running in Replicator Mode, the client will: +- Skip establishing WebSocket connections +- Periodically check if the replicator service is ready +- Use cached data populated by the external replicator service +- Fall back to direct API calls if the replicator is not available + ## Webhook Verification Schematic can send webhooks to notify your application of events. To ensure the security of these webhooks, Schematic signs each request using HMAC-SHA256. The Python SDK provides utility functions to verify these signatures. diff --git a/WASM_VERSION b/WASM_VERSION new file mode 100644 index 0000000..6e8bf73 --- /dev/null +++ b/WASM_VERSION @@ -0,0 +1 @@ +0.1.0 diff --git a/poetry.lock b/poetry.lock index 83fdc8b..107ce04 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.3.2 and should not be changed by hand. 
[[package]] name = "annotated-types" @@ -6,6 +6,7 @@ version = "0.7.0" description = "Reusable constraint types to use with typing.Annotated" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, @@ -20,6 +21,7 @@ version = "4.5.2" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "anyio-4.5.2-py3-none-any.whl", hash = "sha256:c011ee36bc1e8ba40e5a81cb9df91925c218fe9b778554e0b56a21e1b5d4716f"}, {file = "anyio-4.5.2.tar.gz", hash = "sha256:23009af4ed04ce05991845451e11ef02fc7c5ed29179ac9a420e5ad0ac7ddc5b"}, @@ -33,7 +35,7 @@ typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} [package.extras] doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21.0b1)"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1) ; python_version >= \"3.10\"", "uvloop (>=0.21.0b1) ; platform_python_implementation == \"CPython\" and platform_system != \"Windows\""] trio = ["trio (>=0.26.1)"] [[package]] @@ -42,6 +44,7 @@ version = "2026.2.25" description = "Python package for providing Mozilla's CA Bundle." 
optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "certifi-2026.2.25-py3-none-any.whl", hash = "sha256:027692e4402ad994f1c42e52a4997a9763c646b73e4096e4d5d6db8af1d6f0fa"}, {file = "certifi-2026.2.25.tar.gz", hash = "sha256:e887ab5cee78ea814d3472169153c2d12cd43b14bd03329a39a9c6e2e80bfba7"}, @@ -53,6 +56,8 @@ version = "0.4.6" description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["dev"] +markers = "sys_platform == \"win32\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, @@ -64,6 +69,8 @@ version = "1.3.1" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" +groups = ["main", "dev"] +markers = "python_version < \"3.11\"" files = [ {file = "exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598"}, {file = "exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219"}, @@ -81,6 +88,7 @@ version = "2.1.2" description = "execnet: rapid multi-Python deployment" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "execnet-2.1.2-py3-none-any.whl", hash = "sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec"}, {file = "execnet-2.1.2.tar.gz", hash = "sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd"}, @@ -95,6 +103,7 @@ version = "0.16.0" description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86"}, 
{file = "h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1"}, @@ -106,6 +115,7 @@ version = "1.0.9" description = "A minimal low-level HTTP client." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55"}, {file = "httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8"}, @@ -127,6 +137,7 @@ version = "0.28.1" description = "The next generation HTTP client." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"}, {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"}, @@ -139,7 +150,7 @@ httpcore = "==1.*" idna = "*" [package.extras] -brotli = ["brotli", "brotlicffi"] +brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""] cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] @@ -151,6 +162,7 @@ version = "3.11" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea"}, {file = "idna-3.11.tar.gz", hash = "sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902"}, @@ -159,12 +171,37 @@ files = [ [package.extras] all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] +[[package]] +name = "importlib-resources" +version = "6.4.5" +description = "Read resources from Python packages" +optional = true +python-versions = ">=3.8" +groups = ["main"] +markers = "extra 
== \"datastream\" or extra == \"rulesengine\"" +files = [ + {file = "importlib_resources-6.4.5-py3-none-any.whl", hash = "sha256:ac29d5f956f01d5e4bb63102a5a19957f1b9175e45649977264a1416783bb717"}, + {file = "importlib_resources-6.4.5.tar.gz", hash = "sha256:980862a1d16c9e147a59603677fa2aa5fd82b87f223b6cb870695bcfce830065"}, +] + +[package.dependencies] +zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["jaraco.test (>=5.4)", "pytest (>=6,!=8.1.*)", "zipp (>=3.17)"] +type = ["pytest-mypy"] + [[package]] name = "iniconfig" version = "2.1.0" description = "brain-dead simple config-ini parsing" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"}, {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"}, @@ -176,6 +213,7 @@ version = "1.13.0" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "mypy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6607e0f1dd1fb7f0aca14d936d13fd19eba5e17e1cd2a14f808fa5f8f6d8f60a"}, {file = "mypy-1.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a21be69bd26fa81b1f80a61ee7ab05b076c674d9b18fb56239d72e21d9f4c80"}, @@ -229,6 +267,7 @@ version = "1.1.0" description = "Type system extensions for programs checked with the mypy type checker." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505"}, {file = "mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558"}, @@ -240,6 +279,7 @@ version = "26.0" description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529"}, {file = "packaging-26.0.tar.gz", hash = "sha256:00243ae351a257117b6a241061796684b084ed1c516a08c48a3f7e147a9d80b4"}, @@ -251,6 +291,7 @@ version = "1.5.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, @@ -266,6 +307,7 @@ version = "2.10.6" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "pydantic-2.10.6-py3-none-any.whl", hash = "sha256:427d664bf0b8a2b34ff5dd0f5a18df00591adcee7198fbd71981054cef37b584"}, {file = "pydantic-2.10.6.tar.gz", hash = "sha256:ca5daa827cce33de7a42be142548b0096bf05a7e7b365aebfa5f8eeec7128236"}, @@ -278,7 +320,7 @@ typing-extensions = ">=4.12.2" [package.extras] email = ["email-validator (>=2.0.0)"] -timezone = ["tzdata"] +timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows\""] [[package]] name = "pydantic-core" @@ -286,6 +328,7 @@ version = "2.27.2" description = "Core functionality for Pydantic validation and serialization" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = 
"pydantic_core-2.27.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2d367ca20b2f14095a8f4fa1210f5a7b78b8a20009ecced6b12818f455b1e9fa"}, {file = "pydantic_core-2.27.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:491a2b73db93fab69731eaee494f320faa4e093dbed776be1a829c2eb222c34c"}, @@ -398,6 +441,7 @@ version = "7.4.4" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"}, {file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"}, @@ -420,6 +464,7 @@ version = "0.23.8" description = "Pytest support for asyncio" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pytest_asyncio-0.23.8-py3-none-any.whl", hash = "sha256:50265d892689a5faefb84df80819d1ecef566eb3549cf915dfb33569359d1ce2"}, {file = "pytest_asyncio-0.23.8.tar.gz", hash = "sha256:759b10b33a6dc61cce40a8bd5205e302978bbbcc00e279a8b61d9a6a3c82e4d3"}, @@ -438,6 +483,7 @@ version = "3.6.1" description = "pytest xdist plugin for distributed testing, most importantly across multiple CPUs" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pytest_xdist-3.6.1-py3-none-any.whl", hash = "sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7"}, {file = "pytest_xdist-3.6.1.tar.gz", hash = "sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d"}, @@ -458,6 +504,7 @@ version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["dev"] files = [ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = 
"sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, @@ -472,6 +519,7 @@ version = "0.11.5" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "ruff-0.11.5-py3-none-linux_armv6l.whl", hash = "sha256:2561294e108eb648e50f210671cc56aee590fb6167b594144401532138c66c7b"}, {file = "ruff-0.11.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ac12884b9e005c12d0bd121f56ccf8033e1614f736f766c118ad60780882a077"}, @@ -499,6 +547,7 @@ version = "1.17.0" description = "Python 2 and 3 compatibility utilities" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["dev"] files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, @@ -510,6 +559,7 @@ version = "1.3.1" description = "Sniff out which async library your code is running under" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, @@ -521,6 +571,8 @@ version = "2.4.0" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version < \"3.11\"" files = [ {file = "tomli-2.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b5ef256a3fd497d4973c11bf142e9ed78b150d36f5773f1ca6088c230ffc5867"}, {file = "tomli-2.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5572e41282d5268eb09a697c89a7bee84fae66511f87533a6f88bd2f7b652da9"}, @@ -577,6 +629,7 @@ version = "2.9.0.20241206" description = "Typing stubs for python-dateutil" optional = false python-versions = ">=3.8" 
+groups = ["dev"] files = [ {file = "types_python_dateutil-2.9.0.20241206-py3-none-any.whl", hash = "sha256:e248a4bc70a486d3e3ec84d0dc30eec3a5f979d6e7ee4123ae043eedbb987f53"}, {file = "types_python_dateutil-2.9.0.20241206.tar.gz", hash = "sha256:18f493414c26ffba692a72369fea7a154c502646301ebfe3d56a04b3767284cb"}, @@ -588,12 +641,159 @@ version = "4.13.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c"}, {file = "typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef"}, ] +[[package]] +name = "wasmtime" +version = "25.0.0" +description = "A WebAssembly runtime powered by Wasmtime" +optional = true +python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"datastream\" or extra == \"rulesengine\"" +files = [ + {file = "wasmtime-25.0.0-py3-none-any.whl", hash = "sha256:22aa59fc6e01deec8a6703046f82466090d5811096a3bb5c169907e36c842af1"}, + {file = "wasmtime-25.0.0-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:13e9a718e9d580c1738782cc19f4dcb9fb068f7e51778ea621fd664f4433525b"}, + {file = "wasmtime-25.0.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5bdf1214ee3ee78a4a8a92da339f4c4c8c109e65af881b37f4adfc05d02af426"}, + {file = "wasmtime-25.0.0-py3-none-manylinux1_x86_64.whl", hash = "sha256:b4364e14d44e3b7afe6a40bf608e9d0d2c40b09dece441d20f4f6e31906b729c"}, + {file = "wasmtime-25.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:a07445073cf36a6e5d1dc28246a897dcbdaa537ba8be8805be65422ecca297eb"}, + {file = "wasmtime-25.0.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:53d5f614348a28aabdf80ae4f6fdfa803031af1f74ada03826fd4fd43aeee6c8"}, + {file = "wasmtime-25.0.0-py3-none-win_amd64.whl", hash = 
"sha256:f8a2a213b9179965db2d2eedececd69a37e287e902330509afae51c71a3a6842"}, +] + +[package.dependencies] +importlib-resources = ">=5.10" + +[package.extras] +testing = ["componentize-py", "coverage", "pycparser", "pytest", "pytest-mypy"] + +[[package]] +name = "websockets" +version = "13.1" +description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" +optional = true +python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"datastream\"" +files = [ + {file = "websockets-13.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f48c749857f8fb598fb890a75f540e3221d0976ed0bf879cf3c7eef34151acee"}, + {file = "websockets-13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c7e72ce6bda6fb9409cc1e8164dd41d7c91466fb599eb047cfda72fe758a34a7"}, + {file = "websockets-13.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f779498eeec470295a2b1a5d97aa1bc9814ecd25e1eb637bd9d1c73a327387f6"}, + {file = "websockets-13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4676df3fe46956fbb0437d8800cd5f2b6d41143b6e7e842e60554398432cf29b"}, + {file = "websockets-13.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7affedeb43a70351bb811dadf49493c9cfd1ed94c9c70095fd177e9cc1541fa"}, + {file = "websockets-13.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1971e62d2caa443e57588e1d82d15f663b29ff9dfe7446d9964a4b6f12c1e700"}, + {file = "websockets-13.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5f2e75431f8dc4a47f31565a6e1355fb4f2ecaa99d6b89737527ea917066e26c"}, + {file = "websockets-13.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:58cf7e75dbf7e566088b07e36ea2e3e2bd5676e22216e4cad108d4df4a7402a0"}, + {file = "websockets-13.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c90d6dec6be2c7d03378a574de87af9b1efea77d0c52a8301dd831ece938452f"}, + {file = 
"websockets-13.1-cp310-cp310-win32.whl", hash = "sha256:730f42125ccb14602f455155084f978bd9e8e57e89b569b4d7f0f0c17a448ffe"}, + {file = "websockets-13.1-cp310-cp310-win_amd64.whl", hash = "sha256:5993260f483d05a9737073be197371940c01b257cc45ae3f1d5d7adb371b266a"}, + {file = "websockets-13.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:61fc0dfcda609cda0fc9fe7977694c0c59cf9d749fbb17f4e9483929e3c48a19"}, + {file = "websockets-13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ceec59f59d092c5007e815def4ebb80c2de330e9588e101cf8bd94c143ec78a5"}, + {file = "websockets-13.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c1dca61c6db1166c48b95198c0b7d9c990b30c756fc2923cc66f68d17dc558fd"}, + {file = "websockets-13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:308e20f22c2c77f3f39caca508e765f8725020b84aa963474e18c59accbf4c02"}, + {file = "websockets-13.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62d516c325e6540e8a57b94abefc3459d7dab8ce52ac75c96cad5549e187e3a7"}, + {file = "websockets-13.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87c6e35319b46b99e168eb98472d6c7d8634ee37750d7693656dc766395df096"}, + {file = "websockets-13.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5f9fee94ebafbc3117c30be1844ed01a3b177bb6e39088bc6b2fa1dc15572084"}, + {file = "websockets-13.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:7c1e90228c2f5cdde263253fa5db63e6653f1c00e7ec64108065a0b9713fa1b3"}, + {file = "websockets-13.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6548f29b0e401eea2b967b2fdc1c7c7b5ebb3eeb470ed23a54cd45ef078a0db9"}, + {file = "websockets-13.1-cp311-cp311-win32.whl", hash = "sha256:c11d4d16e133f6df8916cc5b7e3e96ee4c44c936717d684a94f48f82edb7c92f"}, + {file = "websockets-13.1-cp311-cp311-win_amd64.whl", hash = 
"sha256:d04f13a1d75cb2b8382bdc16ae6fa58c97337253826dfe136195b7f89f661557"}, + {file = "websockets-13.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9d75baf00138f80b48f1eac72ad1535aac0b6461265a0bcad391fc5aba875cfc"}, + {file = "websockets-13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9b6f347deb3dcfbfde1c20baa21c2ac0751afaa73e64e5b693bb2b848efeaa49"}, + {file = "websockets-13.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:de58647e3f9c42f13f90ac7e5f58900c80a39019848c5547bc691693098ae1bd"}, + {file = "websockets-13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1b54689e38d1279a51d11e3467dd2f3a50f5f2e879012ce8f2d6943f00e83f0"}, + {file = "websockets-13.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf1781ef73c073e6b0f90af841aaf98501f975d306bbf6221683dd594ccc52b6"}, + {file = "websockets-13.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d23b88b9388ed85c6faf0e74d8dec4f4d3baf3ecf20a65a47b836d56260d4b9"}, + {file = "websockets-13.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3c78383585f47ccb0fcf186dcb8a43f5438bd7d8f47d69e0b56f71bf431a0a68"}, + {file = "websockets-13.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:d6d300f8ec35c24025ceb9b9019ae9040c1ab2f01cddc2bcc0b518af31c75c14"}, + {file = "websockets-13.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a9dcaf8b0cc72a392760bb8755922c03e17a5a54e08cca58e8b74f6902b433cf"}, + {file = "websockets-13.1-cp312-cp312-win32.whl", hash = "sha256:2f85cf4f2a1ba8f602298a853cec8526c2ca42a9a4b947ec236eaedb8f2dc80c"}, + {file = "websockets-13.1-cp312-cp312-win_amd64.whl", hash = "sha256:38377f8b0cdeee97c552d20cf1865695fcd56aba155ad1b4ca8779a5b6ef4ac3"}, + {file = "websockets-13.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a9ab1e71d3d2e54a0aa646ab6d4eebfaa5f416fe78dfe4da2839525dc5d765c6"}, + {file = 
"websockets-13.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b9d7439d7fab4dce00570bb906875734df13d9faa4b48e261c440a5fec6d9708"}, + {file = "websockets-13.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:327b74e915cf13c5931334c61e1a41040e365d380f812513a255aa804b183418"}, + {file = "websockets-13.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:325b1ccdbf5e5725fdcb1b0e9ad4d2545056479d0eee392c291c1bf76206435a"}, + {file = "websockets-13.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:346bee67a65f189e0e33f520f253d5147ab76ae42493804319b5716e46dddf0f"}, + {file = "websockets-13.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91a0fa841646320ec0d3accdff5b757b06e2e5c86ba32af2e0815c96c7a603c5"}, + {file = "websockets-13.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:18503d2c5f3943e93819238bf20df71982d193f73dcecd26c94514f417f6b135"}, + {file = "websockets-13.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:a9cd1af7e18e5221d2878378fbc287a14cd527fdd5939ed56a18df8a31136bb2"}, + {file = "websockets-13.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:70c5be9f416aa72aab7a2a76c90ae0a4fe2755c1816c153c1a2bcc3333ce4ce6"}, + {file = "websockets-13.1-cp313-cp313-win32.whl", hash = "sha256:624459daabeb310d3815b276c1adef475b3e6804abaf2d9d2c061c319f7f187d"}, + {file = "websockets-13.1-cp313-cp313-win_amd64.whl", hash = "sha256:c518e84bb59c2baae725accd355c8dc517b4a3ed8db88b4bc93c78dae2974bf2"}, + {file = "websockets-13.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c7934fd0e920e70468e676fe7f1b7261c1efa0d6c037c6722278ca0228ad9d0d"}, + {file = "websockets-13.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:149e622dc48c10ccc3d2760e5f36753db9cacf3ad7bc7bbbfd7d9c819e286f23"}, + {file = "websockets-13.1-cp38-cp38-macosx_11_0_arm64.whl", hash = 
"sha256:a569eb1b05d72f9bce2ebd28a1ce2054311b66677fcd46cf36204ad23acead8c"}, + {file = "websockets-13.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:95df24ca1e1bd93bbca51d94dd049a984609687cb2fb08a7f2c56ac84e9816ea"}, + {file = "websockets-13.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8dbb1bf0c0a4ae8b40bdc9be7f644e2f3fb4e8a9aca7145bfa510d4a374eeb7"}, + {file = "websockets-13.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:035233b7531fb92a76beefcbf479504db8c72eb3bff41da55aecce3a0f729e54"}, + {file = "websockets-13.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:e4450fc83a3df53dec45922b576e91e94f5578d06436871dce3a6be38e40f5db"}, + {file = "websockets-13.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:463e1c6ec853202dd3657f156123d6b4dad0c546ea2e2e38be2b3f7c5b8e7295"}, + {file = "websockets-13.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6d6855bbe70119872c05107e38fbc7f96b1d8cb047d95c2c50869a46c65a8e96"}, + {file = "websockets-13.1-cp38-cp38-win32.whl", hash = "sha256:204e5107f43095012b00f1451374693267adbb832d29966a01ecc4ce1db26faf"}, + {file = "websockets-13.1-cp38-cp38-win_amd64.whl", hash = "sha256:485307243237328c022bc908b90e4457d0daa8b5cf4b3723fd3c4a8012fce4c6"}, + {file = "websockets-13.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9b37c184f8b976f0c0a231a5f3d6efe10807d41ccbe4488df8c74174805eea7d"}, + {file = "websockets-13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:163e7277e1a0bd9fb3c8842a71661ad19c6aa7bb3d6678dc7f89b17fbcc4aeb7"}, + {file = "websockets-13.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4b889dbd1342820cc210ba44307cf75ae5f2f96226c0038094455a96e64fb07a"}, + {file = "websockets-13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:586a356928692c1fed0eca68b4d1c2cbbd1ca2acf2ac7e7ebd3b9052582deefa"}, + {file = 
"websockets-13.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7bd6abf1e070a6b72bfeb71049d6ad286852e285f146682bf30d0296f5fbadfa"}, + {file = "websockets-13.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2aad13a200e5934f5a6767492fb07151e1de1d6079c003ab31e1823733ae79"}, + {file = "websockets-13.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:df01aea34b6e9e33572c35cd16bae5a47785e7d5c8cb2b54b2acdb9678315a17"}, + {file = "websockets-13.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:e54affdeb21026329fb0744ad187cf812f7d3c2aa702a5edb562b325191fcab6"}, + {file = "websockets-13.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:9ef8aa8bdbac47f4968a5d66462a2a0935d044bf35c0e5a8af152d58516dbeb5"}, + {file = "websockets-13.1-cp39-cp39-win32.whl", hash = "sha256:deeb929efe52bed518f6eb2ddc00cc496366a14c726005726ad62c2dd9017a3c"}, + {file = "websockets-13.1-cp39-cp39-win_amd64.whl", hash = "sha256:7c65ffa900e7cc958cd088b9a9157a8141c991f8c53d11087e6fb7277a03f81d"}, + {file = "websockets-13.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:5dd6da9bec02735931fccec99d97c29f47cc61f644264eb995ad6c0c27667238"}, + {file = "websockets-13.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:2510c09d8e8df777177ee3d40cd35450dc169a81e747455cc4197e63f7e7bfe5"}, + {file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1c3cf67185543730888b20682fb186fc8d0fa6f07ccc3ef4390831ab4b388d9"}, + {file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bcc03c8b72267e97b49149e4863d57c2d77f13fae12066622dc78fe322490fe6"}, + {file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:004280a140f220c812e65f36944a9ca92d766b6cc4560be652a0a3883a79ed8a"}, + {file = "websockets-13.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e2620453c075abeb0daa949a292e19f56de518988e079c36478bacf9546ced23"}, + {file = "websockets-13.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9156c45750b37337f7b0b00e6248991a047be4aa44554c9886fe6bdd605aab3b"}, + {file = "websockets-13.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:80c421e07973a89fbdd93e6f2003c17d20b69010458d3a8e37fb47874bd67d51"}, + {file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82d0ba76371769d6a4e56f7e83bb8e81846d17a6190971e38b5de108bde9b0d7"}, + {file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9875a0143f07d74dc5e1ded1c4581f0d9f7ab86c78994e2ed9e95050073c94d"}, + {file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a11e38ad8922c7961447f35c7b17bffa15de4d17c70abd07bfbe12d6faa3e027"}, + {file = "websockets-13.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:4059f790b6ae8768471cddb65d3c4fe4792b0ab48e154c9f0a04cefaabcd5978"}, + {file = "websockets-13.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:25c35bf84bf7c7369d247f0b8cfa157f989862c49104c5cf85cb5436a641d93e"}, + {file = "websockets-13.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:83f91d8a9bb404b8c2c41a707ac7f7f75b9442a0a876df295de27251a856ad09"}, + {file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a43cfdcddd07f4ca2b1afb459824dd3c6d53a51410636a2c7fc97b9a8cf4842"}, + {file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48a2ef1381632a2f0cb4efeff34efa97901c9fbc118e01951ad7cfc10601a9bb"}, + {file = 
"websockets-13.1-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:459bf774c754c35dbb487360b12c5727adab887f1622b8aed5755880a21c4a20"}, + {file = "websockets-13.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:95858ca14a9f6fa8413d29e0a585b31b278388aa775b8a81fa24830123874678"}, + {file = "websockets-13.1-py3-none-any.whl", hash = "sha256:a9a396a6ad26130cdae92ae10c36af09d9bfe6cafe69670fd3b6da9b07b4044f"}, + {file = "websockets-13.1.tar.gz", hash = "sha256:a3b3366087c1bc0a2795111edcadddb8b3b59509d5db5d7ea3fdd69f954a8878"}, +] + +[[package]] +name = "zipp" +version = "3.20.2" +description = "Backport of pathlib-compatible object wrapper for zip files" +optional = true +python-versions = ">=3.8" +groups = ["main"] +markers = "(extra == \"datastream\" or extra == \"rulesengine\") and python_version < \"3.10\"" +files = [ + {file = "zipp-3.20.2-py3-none-any.whl", hash = "sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350"}, + {file = "zipp-3.20.2.tar.gz", hash = "sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29"}, +] + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["big-O", "importlib-resources ; python_version < \"3.9\"", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] +type = ["pytest-mypy"] + +[extras] +datastream = ["wasmtime", "websockets"] +rulesengine = ["wasmtime"] + [metadata] -lock-version = "2.0" +lock-version = "2.1" python-versions = "^3.8" -content-hash = "bcf31a142c86d9e556553c8c260a93b563ac64a043076dbd48b26111d422c26e" +content-hash = "fe5ba1a2a6da427a37d4a3e628adbeb1f5a433ae8d8ba4c71d7b0b694c90b16c" diff --git 
a/pyproject.toml b/pyproject.toml index 0ef3755..1fa35ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,10 @@ packages = [ { include = "schematic", from = "src"} ] +include = [ + { path = "src/schematic/datastream/wasm/*.wasm", format = ["sdist", "wheel"] } +] + [tool.poetry.urls] Repository = 'https://github.com/schematichq/schematic-python' @@ -41,6 +45,12 @@ httpx = ">=0.21.2" pydantic = ">= 1.9.2" pydantic-core = ">=2.18.2" typing_extensions = ">= 4.0.0" +websockets = { version = ">=10.0", optional = true } +wasmtime = { version = ">=19.0.0", optional = true } + +[tool.poetry.extras] +datastream = ["websockets", "wasmtime"] +rulesengine = ["wasmtime"] [tool.poetry.group.dev.dependencies] mypy = "==1.13.0" diff --git a/scripts/download-wasm.sh b/scripts/download-wasm.sh new file mode 100755 index 0000000..24e8c52 --- /dev/null +++ b/scripts/download-wasm.sh @@ -0,0 +1,66 @@ +#!/bin/bash +set -e + +# Downloads the rules engine WASM binary from the schematic-api GitHub Release. +# Reads the pinned version from WASM_VERSION at the repo root. + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +WASM_DIR="$REPO_ROOT/src/schematic/datastream/wasm" +VERSION_FILE="$REPO_ROOT/WASM_VERSION" +TARGET_FILE="$WASM_DIR/rulesengine.wasm" + +GITHUB_REPO="SchematicHQ/schematic-api" + +if [ ! -f "$VERSION_FILE" ]; then + echo "ERROR: WASM_VERSION file not found at $VERSION_FILE" + exit 1 +fi + +VERSION=$(tr -d '[:space:]' < "$VERSION_FILE") +TAG="rulesengine/v${VERSION}" + +# Skip download if binary already exists and version matches +if [ -f "$TARGET_FILE" ] && [ -f "$WASM_DIR/.wasm_version" ]; then + CURRENT=$(tr -d '[:space:]' < "$WASM_DIR/.wasm_version") + if [ "$CURRENT" = "$VERSION" ]; then + echo "WASM binary already at version $VERSION, skipping download." + exit 0 + fi +fi + +ASSET_NAME="rulesengine-wasm-python-v${VERSION}.tar.gz" + +echo "Downloading rules engine WASM v${VERSION}..." 
+mkdir -p "$WASM_DIR" + +TMPDIR=$(mktemp -d) +trap 'rm -rf "$TMPDIR"' EXIT + +if ! gh release download "$TAG" \ + -R "$GITHUB_REPO" \ + -p "$ASSET_NAME" \ + -D "$TMPDIR" 2>/dev/null; then + echo "ERROR: Failed to download WASM binary" + echo "Tag: $TAG" + echo "Asset: $ASSET_NAME" + echo "" + echo "If this is a new version, ensure a release exists at:" + echo " https://github.com/${GITHUB_REPO}/releases/tag/${TAG}" + echo "" + echo "Ensure the GitHub CLI is authenticated: gh auth status" + exit 1 +fi + +tar -xzf "$TMPDIR/$ASSET_NAME" -C "$TMPDIR" + +if [ ! -f "$TMPDIR/rulesengine.wasm" ]; then + echo "ERROR: rulesengine.wasm not found in release archive" + ls -la "$TMPDIR" + exit 1 +fi + +cp "$TMPDIR/rulesengine.wasm" "$TARGET_FILE" +echo "$VERSION" > "$WASM_DIR/.wasm_version" + +echo "Downloaded rules engine WASM v${VERSION} to $TARGET_FILE" diff --git a/src/schematic/cache.py b/src/schematic/cache.py deleted file mode 100644 index d538c8e..0000000 --- a/src/schematic/cache.py +++ /dev/null @@ -1,71 +0,0 @@ -import time -from collections import OrderedDict -from typing import Generic, Optional -from typing import OrderedDict as OrderedDictType -from typing import TypeVar - -T = TypeVar("T") - -DEFAULT_CACHE_SIZE = 1000 # 1000 items -DEFAULT_CACHE_TTL = 5000 # 5 seconds - - -class CacheProvider(Generic[T]): - def get(self, key: str) -> Optional[T]: - pass - - def set(self, key: str, val: T, ttl_override: Optional[int] = None) -> None: - pass - - -class CachedItem(Generic[T]): - def __init__(self, value: T, expiration: float): - self.value = value - self.expiration = expiration - - -class LocalCache(CacheProvider[T]): - def __init__(self, max_size: int, ttl: int): - self.cache: OrderedDictType[str, CachedItem[T]] = OrderedDict() - self.max_size = max_size - self.ttl = ttl - - def get(self, key: str) -> Optional[T]: - if self.max_size == 0 or key not in self.cache: - return None - - item = self.cache[key] - current_time = time.time() * 1000 - - if current_time > 
item.expiration: - del self.cache[key] - return None - - # Move the accessed item to the end (most recently used) - self.cache.move_to_end(key) - return item.value - - def set(self, key: str, val: T, ttl_override: Optional[int] = None) -> None: - if self.max_size == 0: - return - - ttl = self.ttl if ttl_override is None else ttl_override - expiration = time.time() * 1000 + ttl - - # If the key already exists, update it and move it to the end - if key in self.cache: - self.cache[key] = CachedItem(val, expiration) - self.cache.move_to_end(key) - else: - # If we're at capacity, remove the least recently used item - if len(self.cache) >= self.max_size: - self.cache.popitem(last=False) - - # Add the new item - self.cache[key] = CachedItem(val, expiration) - - def clean_expired(self): - current_time = time.time() * 1000 - self.cache = OrderedDict( - (k, v) for k, v in self.cache.items() if v.expiration > current_time - ) diff --git a/src/schematic/cache/__init__.py b/src/schematic/cache/__init__.py new file mode 100644 index 0000000..dc53a89 --- /dev/null +++ b/src/schematic/cache/__init__.py @@ -0,0 +1,26 @@ +from .local import ( + DEFAULT_CACHE_SIZE, + DEFAULT_CACHE_TTL, + DEFAULT_MAX_ITEMS, + DEFAULT_TTL_MS, + AsyncLocalCache, + LocalCache, +) +from .provider import AsyncCacheProvider, CacheProvider +from .redis import RedisCache + +__all__ = [ + # Providers + "AsyncCacheProvider", + "CacheProvider", + # Local cache + "AsyncLocalCache", + "LocalCache", + # Redis cache + "RedisCache", + # Constants + "DEFAULT_CACHE_SIZE", + "DEFAULT_CACHE_TTL", + "DEFAULT_MAX_ITEMS", + "DEFAULT_TTL_MS", +] diff --git a/src/schematic/cache/local.py b/src/schematic/cache/local.py new file mode 100644 index 0000000..abc6717 --- /dev/null +++ b/src/schematic/cache/local.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import time +from collections import OrderedDict +from typing import Any, Generic, List, Optional, Set, TypeVar + +from .provider import AsyncCacheProvider, 
CacheProvider + +T = TypeVar("T") + +DEFAULT_CACHE_SIZE = 1000 +DEFAULT_CACHE_TTL = 5000 # 5 seconds + +# Aliases for backwards compatibility +DEFAULT_MAX_ITEMS = DEFAULT_CACHE_SIZE +DEFAULT_TTL_MS = DEFAULT_CACHE_TTL + + +class CachedItem(Generic[T]): + __slots__ = ("value", "expiration") + + def __init__(self, value: T, expiration: float) -> None: + self.value = value + self.expiration = expiration + + +class LocalCache(CacheProvider[T]): + """In-memory synchronous cache with LRU eviction and TTL support.""" + + def __init__(self, max_size: int = DEFAULT_CACHE_SIZE, ttl: int = DEFAULT_CACHE_TTL) -> None: + self.cache: OrderedDict[str, CachedItem[Any]] = OrderedDict() + self.max_size = max_size + self.ttl = ttl + + def get(self, key: str) -> Optional[T]: + if self.max_size == 0 or key not in self.cache: + return None + + item = self.cache[key] + current_time = time.time() * 1000 + + if current_time > item.expiration: + del self.cache[key] + return None + + self.cache.move_to_end(key) + return item.value + + def set(self, key: str, val: T, ttl_override: Optional[int] = None) -> None: + if self.max_size == 0: + return + + ttl = self.ttl if ttl_override is None else ttl_override + expiration = time.time() * 1000 + ttl + + if key in self.cache: + self.cache[key] = CachedItem(val, expiration) + self.cache.move_to_end(key) + else: + if len(self.cache) >= self.max_size: + self.cache.popitem(last=False) + self.cache[key] = CachedItem(val, expiration) + + def clean_expired(self) -> None: + current_time = time.time() * 1000 + self.cache = OrderedDict( + (k, v) for k, v in self.cache.items() if v.expiration > current_time + ) + + +class _AsyncCacheItem(Generic[T]): + __slots__ = ("value", "expiration") + + def __init__(self, value: T, expiration: float) -> None: + self.value = value + self.expiration = expiration + + +class AsyncLocalCache(AsyncCacheProvider[T]): + """In-memory async cache with LRU eviction and TTL support.""" + + def __init__(self, *, max_items: int = 
DEFAULT_MAX_ITEMS, ttl: int = DEFAULT_TTL_MS) -> None: + self._cache: OrderedDict[str, _AsyncCacheItem[Any]] = OrderedDict() + self._max_items = max_items + self._default_ttl = ttl + + async def get(self, key: str) -> Optional[T]: + item = self._cache.get(key) + if item is None: + return None + + now_ms = time.time() * 1000 + if now_ms >= item.expiration: + del self._cache[key] + return None + + self._cache.move_to_end(key) + return item.value + + async def set(self, key: str, value: T, ttl: Optional[int] = None) -> None: + if self._max_items == 0: + return + + effective_ttl = ttl if ttl is not None else self._default_ttl + + if key in self._cache: + del self._cache[key] + + self._evict_expired() + + while len(self._cache) >= self._max_items: + self._cache.popitem(last=False) + + now_ms = time.time() * 1000 + self._cache[key] = _AsyncCacheItem( + value=value, + expiration=now_ms + effective_ttl, + ) + + async def delete(self, key: str) -> None: + self._cache.pop(key, None) + + async def delete_missing(self, keys_to_keep: List[str], *, scan_pattern: Optional[str] = None) -> None: + keep_set: Set[str] = set(keys_to_keep) + to_delete = [k for k in self._cache if k not in keep_set] + for k in to_delete: + del self._cache[k] + + def _evict_expired(self) -> None: + now_ms = time.time() * 1000 + expired = [k for k, item in self._cache.items() if now_ms >= item.expiration] + for k in expired: + del self._cache[k] diff --git a/src/schematic/cache/provider.py b/src/schematic/cache/provider.py new file mode 100644 index 0000000..d61f138 --- /dev/null +++ b/src/schematic/cache/provider.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from typing import Generic, List, Optional, TypeVar + +T = TypeVar("T") + + +class CacheProvider(Generic[T]): + """Synchronous cache provider interface.""" + + def get(self, key: str) -> Optional[T]: + raise NotImplementedError + + def set(self, key: str, val: T, ttl_override: Optional[int] = None) -> None: + raise NotImplementedError + 
+ +class AsyncCacheProvider(Generic[T]): + """Async cache provider interface for storing and retrieving entities.""" + + async def get(self, key: str) -> Optional[T]: + raise NotImplementedError + + async def set(self, key: str, value: T, ttl: Optional[int] = None) -> None: + raise NotImplementedError + + async def delete(self, key: str) -> None: + raise NotImplementedError + + async def delete_missing(self, keys_to_keep: List[str], *, scan_pattern: Optional[str] = None) -> None: + """Delete all keys not in keys_to_keep. Optional for bulk operations.""" + raise NotImplementedError diff --git a/src/schematic/cache/redis.py b/src/schematic/cache/redis.py new file mode 100644 index 0000000..895199e --- /dev/null +++ b/src/schematic/cache/redis.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import json +import logging +from typing import Any, List, Optional, TypeVar + +from .provider import AsyncCacheProvider + +T = TypeVar("T") + +logger = logging.getLogger(__name__) + +class RedisCache(AsyncCacheProvider[T]): + """Async Redis cache provider using redis.asyncio. + + Stores values as JSON strings. Supports TTL and pattern-based bulk deletion. 
+ + Usage:: + + import redis.asyncio as redis + + pool = redis.ConnectionPool.from_url("redis://localhost:6379") + client = redis.Redis(connection_pool=pool) + cache = RedisCache(client, prefix="schematic:flags") + """ + + def __init__( + self, + client: Any, + *, + prefix: str = "schematic", + default_ttl_ms: Optional[int] = None, + ) -> None: + self._client = client + self._prefix = prefix + self._default_ttl_ms = default_ttl_ms + + def _prefixed(self, key: str) -> str: + return f"{self._prefix}:{key}" + + async def get(self, key: str) -> Optional[T]: + prefixed_key = self._prefixed(key) + raw = await self._client.get(prefixed_key) + logger.debug("Redis GET %s -> %s", prefixed_key, "hit" if raw is not None else "miss") + if raw is None: + return None + try: + return json.loads(raw) + except (json.JSONDecodeError, TypeError): + logger.warning("Failed to deserialize cache value for key: %s", key) + return None + + async def set(self, key: str, value: T, ttl: Optional[int] = None) -> None: + effective_ttl = ttl if ttl is not None else self._default_ttl_ms + serialized = json.dumps(value, default=_json_default) + prefixed = self._prefixed(key) + if effective_ttl is not None: + await self._client.psetex(prefixed, effective_ttl, serialized) + else: + await self._client.set(prefixed, serialized) + + async def delete(self, key: str) -> None: + await self._client.delete(self._prefixed(key)) + + async def delete_missing(self, keys_to_keep: List[str], *, scan_pattern: Optional[str] = None) -> None: + """Delete all keys matching scan_pattern that are not in keys_to_keep. + + Uses SCAN to iterate keys without blocking the server. 
+ """ + pattern = self._prefixed(scan_pattern or "*") + keep_set = {self._prefixed(k) for k in keys_to_keep} + to_delete: list[str] = [] + + async for key in self._client.scan_iter(match=pattern, count=100): + key_str = key.decode("utf-8") if isinstance(key, bytes) else key + if key_str not in keep_set: + to_delete.append(key_str) + + if to_delete: + await self._client.delete(*to_delete) + + +def _json_default(obj: Any) -> Any: + """Default JSON serializer for Pydantic models and datetime objects.""" + if hasattr(obj, "model_dump"): + return obj.model_dump(mode="json") + if hasattr(obj, "dict"): + return obj.dict() + if hasattr(obj, "isoformat"): + return obj.isoformat() + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") diff --git a/src/schematic/client.py b/src/schematic/client.py index 1a649b5..65d3b30 100644 --- a/src/schematic/client.py +++ b/src/schematic/client.py @@ -1,24 +1,52 @@ import atexit import logging from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union import httpx from .base_client import AsyncBaseSchematic, BaseSchematic -from .cache import DEFAULT_CACHE_SIZE, DEFAULT_CACHE_TTL, CacheProvider, LocalCache +from .cache import DEFAULT_CACHE_SIZE, DEFAULT_CACHE_TTL, AsyncCacheProvider, CacheProvider, LocalCache +from .datastream import DataStreamClient, DataStreamClientOptions from .event_buffer import AsyncEventBuffer, EventBuffer from .http_client import AsyncOfflineHTTPClient, OfflineHTTPClient from .logging import get_default_logger from .types import ( + CheckFlagRequestBody, + CheckFlagResponseData, CreateEventRequestBody, EventBody, + EventBodyFlagCheck, EventBodyIdentify, EventBodyIdentifyCompany, EventBodyTrack, + FeatureEntitlement, ) +@dataclass +class CheckFlagOptions: + """Options for flag check methods.""" + + default_value: Optional[Union[bool, Callable[[], bool]]] = None + timeout: Optional[float] = None + + 
+@dataclass +class DataStreamConfig: + """Configuration for DataStream real-time flag evaluation.""" + + cache_ttl: Optional[int] = None + company_cache: Optional[AsyncCacheProvider[Any]] = None + company_lookup_cache: Optional[AsyncCacheProvider[str]] = None + user_cache: Optional[AsyncCacheProvider[Any]] = None + user_lookup_cache: Optional[AsyncCacheProvider[str]] = None + flag_cache: Optional[AsyncCacheProvider[Any]] = None + replicator_mode: bool = False + replicator_health_url: Optional[str] = None + replicator_health_check: Optional[int] = None + + @dataclass class SchematicConfig: base_url: Optional[str] = None @@ -29,7 +57,7 @@ class SchematicConfig: logger: Optional[logging.Logger] = None offline: bool = False timeout: Optional[float] = None - cache_providers: Optional[List[CacheProvider[bool]]] = None + cache_providers: Optional[List[CacheProvider[CheckFlagResponseData]]] = None class Schematic(BaseSchematic): @@ -52,9 +80,10 @@ def __init__(self, api_key: str, config: Optional[SchematicConfig] = None): logger=self.logger, period=self.event_buffer_period, ) - self.flag_check_cache_providers = config.cache_providers or [ - LocalCache[bool](DEFAULT_CACHE_SIZE, DEFAULT_CACHE_TTL) - ] + self.flag_check_cache_providers: List[CacheProvider[CheckFlagResponseData]] = ( + config.cache_providers if config.cache_providers is not None + else [LocalCache[CheckFlagResponseData](DEFAULT_CACHE_SIZE, DEFAULT_CACHE_TTL)] + ) self.offline = config.offline atexit.register(self.shutdown) @@ -70,16 +99,38 @@ def check_flag( flag_key: str, company: Optional[Dict[str, str]] = None, user: Optional[Dict[str, str]] = None, + options: Optional[CheckFlagOptions] = None, ) -> bool: + resp = self.check_flag_with_entitlement(flag_key, company=company, user=user, options=options) + return resp.value + + def check_flag_with_entitlement( + self, + flag_key: str, + company: Optional[Dict[str, str]] = None, + user: Optional[Dict[str, str]] = None, + options: Optional[CheckFlagOptions] = 
None, + ) -> CheckFlagResponseData: + default_value = self._resolve_default(flag_key, options) + if self.offline: - return self._get_flag_default(flag_key) + return CheckFlagResponseData( + flag=flag_key, + reason="flag default", + value=default_value, + ) + return self._check_flag_via_api(flag_key, company, user, default_value) + + def _check_flag_via_api( + self, + flag_key: str, + company: Optional[Dict[str, str]], + user: Optional[Dict[str, str]], + default_value: bool, + ) -> CheckFlagResponseData: try: - cache_key = ( - flag_key + ":" + str(company) + ":" + str(user) - if (company or user) - else flag_key - ) + cache_key = _build_cache_key(flag_key, company, user) for provider in self.flag_check_cache_providers: cached_value = provider.get(cache_key) @@ -87,16 +138,24 @@ def check_flag( return cached_value resp = self.features.check_flag(flag_key, company=company, user=user) - if resp is None: - return self._get_flag_default(flag_key) + if resp is None or resp.data.value is None: + return CheckFlagResponseData( + flag=flag_key, + reason="flag default", + value=default_value, + ) for provider in self.flag_check_cache_providers: - provider.set(cache_key, resp.data.value) + provider.set(cache_key, resp.data) - return resp.data.value + return resp.data except Exception as e: self.logger.error(e) - return self._get_flag_default(flag_key) + return CheckFlagResponseData( + flag=flag_key, + reason="flag default", + value=default_value, + ) def identify( self, @@ -146,6 +205,13 @@ def _enqueue_event(self, event_type: str, body: EventBody) -> None: def _get_flag_default(self, flag_key: str) -> bool: return self.flag_defaults.get(flag_key, False) + def _resolve_default(self, flag_key: str, options: Optional[CheckFlagOptions] = None) -> bool: + if options and options.default_value is not None: + if callable(options.default_value): + return options.default_value() + return options.default_value + return self._get_flag_default(flag_key) + @dataclass class 
AsyncSchematicConfig: @@ -157,36 +223,38 @@ class AsyncSchematicConfig: logger: Optional[logging.Logger] = None offline: bool = False timeout: Optional[float] = None - cache_providers: Optional[List[CacheProvider[bool]]] = None + cache_providers: Optional[List[CacheProvider[CheckFlagResponseData]]] = None + use_datastream: bool = False + datastream: Optional[DataStreamConfig] = None class AsyncSchematic(AsyncBaseSchematic): """Async Schematic client for feature flags and event tracking. - + This client provides async methods for checking feature flags and tracking events. - It automatically initializes on first use and maintains background tasks for + It automatically initializes on first use and maintains background tasks for event buffering that require proper cleanup. - + IMPORTANT: Always call shutdown() when done, or use as a context manager: - + # Recommended patterns: - + # 1. Context manager (automatic cleanup): async with AsyncSchematic(api_key, config) as client: result = await client.check_flag("my-flag") # Auto-initializes - + # 2. Manual (explicit cleanup): client = AsyncSchematic(api_key, config) try: result = await client.check_flag("my-flag") # Auto-initializes finally: await client.shutdown() # REQUIRED for proper cleanup - + # 3. 
Web framework (lifecycle managed): # In startup: client = AsyncSchematic(api_key, config) # In shutdown: await client.shutdown() """ - + def __init__(self, api_key: str, config: Optional[AsyncSchematicConfig] = None): self._initialized = False config = config or AsyncSchematicConfig() @@ -209,38 +277,147 @@ def __init__(self, api_key: str, config: Optional[AsyncSchematicConfig] = None): logger=self.logger, period=self.event_buffer_period, ) - self.flag_check_cache_providers = config.cache_providers or [ - LocalCache[bool](DEFAULT_CACHE_SIZE, DEFAULT_CACHE_TTL) - ] + self.flag_check_cache_providers: List[CacheProvider[CheckFlagResponseData]] = ( + config.cache_providers if config.cache_providers is not None + else [LocalCache[CheckFlagResponseData](DEFAULT_CACHE_SIZE, DEFAULT_CACHE_TTL)] + ) self.offline = config.offline self._shutdown_requested = False self._is_shutting_down = False + + # DataStream client + self._datastream_client: Optional[DataStreamClient] = None + if config.use_datastream and not config.offline: + ds = config.datastream or DataStreamConfig() + ds_opts = DataStreamClientOptions( + api_key=api_key, + base_url=config.base_url, + logger=self.logger, + ) + if ds.cache_ttl is not None: + ds_opts.cache_ttl = ds.cache_ttl + if ds.company_cache is not None: + ds_opts.company_cache = ds.company_cache + if ds.company_lookup_cache is not None: + ds_opts.company_lookup_cache = ds.company_lookup_cache + if ds.user_cache is not None: + ds_opts.user_cache = ds.user_cache + if ds.user_lookup_cache is not None: + ds_opts.user_lookup_cache = ds.user_lookup_cache + if ds.flag_cache is not None: + ds_opts.flag_cache = ds.flag_cache + ds_opts.replicator_mode = ds.replicator_mode + if ds.replicator_health_url is not None: + ds_opts.replicator_health_url = ds.replicator_health_url + if ds.replicator_health_check is not None: + ds_opts.replicator_health_check = ds.replicator_health_check + + self._datastream_client = DataStreamClient(ds_opts) + self._initialized = True 
async def __aenter__(self): + await self._start_datastream() return self async def __aexit__(self, exc_type, exc_val, exc_tb): await self.shutdown() + async def _start_datastream(self) -> None: + if self._datastream_client is not None: + try: + await self._datastream_client.start() + except Exception as e: + self.logger.error(f"Failed to start DataStream client: {e}") + self._datastream_client = None + async def initialize(self) -> None: - pass + await self._start_datastream() + + def _get_datastream(self) -> Optional[DataStreamClient]: + return self._datastream_client async def check_flag( self, flag_key: str, company: Optional[Dict[str, str]] = None, user: Optional[Dict[str, str]] = None, + options: Optional[CheckFlagOptions] = None, ) -> bool: + resp = await self.check_flag_with_entitlement(flag_key, company=company, user=user, options=options) + return resp.value + + async def check_flag_with_entitlement( + self, + flag_key: str, + company: Optional[Dict[str, str]] = None, + user: Optional[Dict[str, str]] = None, + options: Optional[CheckFlagOptions] = None, + ) -> CheckFlagResponseData: + default_value = self._resolve_default(flag_key, options) + if self.offline: - return self._get_flag_default(flag_key) + return CheckFlagResponseData( + flag=flag_key, + reason="flag default", + value=default_value, + ) + # Try DataStream first if available + ds = self._get_datastream() + if ds is not None: + try: + resp = await ds.check_flag( + CheckFlagRequestBody(company=company, user=user), + flag_key, + ) + + # Enqueue flag_check event + await self._enqueue_event( + "flag_check", + EventBodyFlagCheck( + flag_key=flag_key, + value=resp.value if resp.value is not None else False, + reason=resp.reason if resp.reason else "unknown", + rule_id=resp.rule_id, + company_id=resp.company_id, + user_id=resp.user_id, + flag_id=resp.flag_id, + req_company=company, + req_user=user, + ), + ) + + entitlement = ( + 
FeatureEntitlement.model_validate(resp.entitlement.model_dump(mode="json")) + if resp.entitlement is not None else None + ) + return CheckFlagResponseData( + company_id=resp.company_id, + entitlement=entitlement, + error=resp.err, + flag=resp.flag_key, + flag_id=resp.flag_id, + reason=resp.reason, + rule_id=resp.rule_id, + rule_type=resp.rule_type, + user_id=resp.user_id, + value=resp.value if resp.value is not None else self._get_flag_default(flag_key), + ) + except Exception as e: + self.logger.debug(f"Datastream flag check failed ({e}), falling back to API") + + return await self._check_flag_via_api(flag_key, company, user, default_value) + + async def _check_flag_via_api( + self, + flag_key: str, + company: Optional[Dict[str, str]], + user: Optional[Dict[str, str]], + default_value: bool, + ) -> CheckFlagResponseData: try: - cache_key = ( - flag_key + ":" + str(company) + ":" + str(user) - if (company or user) - else flag_key - ) + cache_key = _build_cache_key(flag_key, company, user) for provider in self.flag_check_cache_providers: cached_value = provider.get(cache_key) @@ -248,16 +425,24 @@ async def check_flag( return cached_value resp = await self.features.check_flag(flag_key, company=company, user=user) - if resp is None: - return self._get_flag_default(flag_key) + if resp is None or resp.data.value is None: + return CheckFlagResponseData( + flag=flag_key, + reason="flag default", + value=default_value, + ) for provider in self.flag_check_cache_providers: - provider.set(cache_key, resp.data.value) + provider.set(cache_key, resp.data) - return resp.data.value + return resp.data except Exception as e: self.logger.error(e) - return self._get_flag_default(flag_key) + return CheckFlagResponseData( + flag=flag_key, + reason="flag default", + value=default_value, + ) async def identify( self, @@ -265,7 +450,7 @@ async def identify( company: Optional[EventBodyIdentifyCompany] = None, name: Optional[str] = None, traits: Optional[Dict[str, Any]] = None, - ) -> None: 
+ ) -> None: await self._enqueue_event( "identify", EventBodyIdentify( @@ -283,7 +468,7 @@ async def track( user: Optional[Dict[str, str]] = None, traits: Optional[Dict[str, Any]] = None, quantity: Optional[int] = None, - ) -> None: + ) -> None: await self._enqueue_event( "track", EventBodyTrack( @@ -295,6 +480,18 @@ async def track( ), ) + # Update company metrics in DataStream if available and connected + ds = self._get_datastream() + if company and ds is not None and ds.is_connected(): + try: + await ds.update_company_metrics( + company, + event, + quantity or 1, + ) + except Exception as e: + self.logger.error(f"Failed to update company metrics: {e}") + async def _enqueue_event(self, event_type: str, body: EventBody) -> None: if self.offline: return @@ -307,31 +504,44 @@ async def _enqueue_event(self, event_type: str, body: EventBody) -> None: def _get_flag_default(self, flag_key: str) -> bool: return self.flag_defaults.get(flag_key, False) + def _resolve_default(self, flag_key: str, options: Optional[CheckFlagOptions] = None) -> bool: + if options and options.default_value is not None: + if callable(options.default_value): + return options.default_value() + return options.default_value + return self._get_flag_default(flag_key) + async def shutdown(self) -> None: """Properly shutdown the client, flushing any pending events. - + This method should be called when you're done using the client to ensure: - All pending events are flushed to the server - Background tasks are properly terminated - Resources are cleaned up - + It's safe to call this method multiple times, even if the client was never used. 
""" # Only do the shutdown once if self._is_shutting_down: self.logger.debug("Shutdown already in progress, skipping") return - + self._is_shutting_down = True - + # If we were never initialized, there's nothing to clean up if not self._initialized: self.logger.debug("Client was never initialized, nothing to clean up") return - + self.logger.info("Shutting down AsyncSchematic...") - + try: + if self._datastream_client is not None: + try: + await self._datastream_client.close() + except Exception as e: + self.logger.error(f"Error closing DataStream client: {e}") + # Flush and stop the event buffer await self.event_buffer.stop() self.logger.info("Shutdown complete.") @@ -339,4 +549,16 @@ async def shutdown(self) -> None: self.logger.error(f"Error during shutdown: {e}") finally: self._shutdown_requested = True - + + +def _build_cache_key( + flag_key: str, + company: Optional[Dict[str, str]] = None, + user: Optional[Dict[str, str]] = None, +) -> str: + parts = [flag_key] + if company: + parts.append("company:" + ";".join(f"{k}={v}" for k, v in sorted(company.items()))) + if user: + parts.append("user:" + ";".join(f"{k}={v}" for k, v in sorted(user.items()))) + return ":".join(parts) diff --git a/src/schematic/datastream/__init__.py b/src/schematic/datastream/__init__.py new file mode 100644 index 0000000..7ed02c5 --- /dev/null +++ b/src/schematic/datastream/__init__.py @@ -0,0 +1,34 @@ +from ..cache import AsyncCacheProvider, AsyncLocalCache +from .datastream_client import DataStreamClient, DataStreamClientOptions +from .merge import deep_copy_company, deep_copy_user, extract_id, partial_company, partial_user +from .rules_engine import RulesEngineClient +from .types import DataStreamBaseReq, DataStreamError, DataStreamReq, DataStreamResp, EntityType, MessageType +from .websocket_client import ClientOptions, DatastreamWSClient, convert_api_url_to_websocket_url + +__all__ = [ + # Cache + "AsyncCacheProvider", + "AsyncLocalCache", + # Datastream client + 
"DataStreamClient", + "DataStreamClientOptions", + # Merge utilities + "deep_copy_company", + "deep_copy_user", + "extract_id", + "partial_company", + "partial_user", + # Rules engine + "RulesEngineClient", + # Types + "DataStreamBaseReq", + "DataStreamError", + "DataStreamReq", + "DataStreamResp", + "EntityType", + "MessageType", + # WebSocket client + "ClientOptions", + "DatastreamWSClient", + "convert_api_url_to_websocket_url", +] diff --git a/src/schematic/datastream/datastream_client.py b/src/schematic/datastream/datastream_client.py new file mode 100644 index 0000000..28ca895 --- /dev/null +++ b/src/schematic/datastream/datastream_client.py @@ -0,0 +1,1027 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import typing +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Set, Union + +from ..types.check_flag_request_body import CheckFlagRequestBody +from ..types.rulesengine_check_flag_result import RulesengineCheckFlagResult +from ..types.rulesengine_company import RulesengineCompany +from ..types.rulesengine_flag import RulesengineFlag +from ..types.rulesengine_user import RulesengineUser +from ..cache import AsyncCacheProvider, AsyncLocalCache +from .merge import partial_company, partial_user +from .rules_engine import RulesEngineClient +from .types import DataStreamBaseReq, DataStreamReq, DataStreamResp, EntityType, MessageType +from .websocket_client import ClientOptions as WSClientOptions, DatastreamWSClient + + +_hints_cache: Dict[type, Dict[str, Any]] = {} + + +def _get_type_hints(model_cls: type) -> Dict[str, Any]: + """Cached wrapper around typing.get_type_hints to avoid repeated introspection.""" + if model_cls not in _hints_cache: + _hints_cache[model_cls] = typing.get_type_hints(model_cls) + return _hints_cache[model_cls] + + +def _coerce_nulls(data: dict, model_cls: type) -> dict: + """Convert null values to empty lists for required list fields. 
+ + Go serializes nil slices as JSON null, but our Pydantic models require + lists. This recursively fixes nulls before model_validate. + """ + if not hasattr(model_cls, "__annotations__"): + return data + + hints = _get_type_hints(model_cls) + result = dict(data) + for field_name, field_type in hints.items(): + if field_name not in result: + continue + origin = getattr(field_type, "__origin__", None) + args = getattr(field_type, "__args__", ()) + + # typing.List[X] — coerce None to [] + if origin is list and result[field_name] is None: + result[field_name] = [] + # typing.List[Model] — recurse into each element + elif origin is list and args and isinstance(result[field_name], list): + inner = args[0] + if hasattr(inner, "__annotations__"): + result[field_name] = [ + _coerce_nulls(item, inner) if isinstance(item, dict) else item + for item in result[field_name] + ] + return result + + +def _validate(model_cls: type, raw: Any) -> Any: + """Validate raw data into a Pydantic model, coercing Go-style nulls first.""" + if isinstance(raw, dict): + return model_cls.model_validate(_coerce_nulls(raw, model_cls)) # type: ignore[attr-defined] + return raw + + +# Cache key prefixes +_PREFIX_COMPANY = "company" +_PREFIX_USER = "user" +_PREFIX_FLAGS = "flags" + +# Timing constants (milliseconds) +RESOURCE_TIMEOUT_MS = 30_000 # 30 seconds +DEFAULT_TTL_MS = 24 * 60 * 60 * 1000 # 24 hours +MAX_CACHE_TTL_MS = 30 * 24 * 60 * 60 * 1000 # 30 days +DEFAULT_REPLICATOR_HEALTH_CHECK_MS = 30_000 # 30 seconds +REPLICATOR_HEALTH_TIMEOUT_S = 5.0 +REPLICATOR_CACHE_VERSION_TIMEOUT_S = 2.0 + + +@dataclass +class DataStreamClientOptions: + """Configuration for the DataStream client.""" + + api_key: str + logger: logging.Logger + base_url: Optional[str] = None + cache_ttl: Optional[int] = DEFAULT_TTL_MS + + # Custom cache providers (override defaults) + company_cache: Optional[AsyncCacheProvider[Any]] = None + company_lookup_cache: Optional[AsyncCacheProvider[str]] = None + user_cache: 
Optional[AsyncCacheProvider[Any]] = None + user_lookup_cache: Optional[AsyncCacheProvider[str]] = None + flag_cache: Optional[AsyncCacheProvider[Any]] = None + + # Replicator mode + replicator_mode: bool = False + replicator_health_url: Optional[str] = "http://localhost:8090/ready" + replicator_health_check: int = DEFAULT_REPLICATOR_HEALTH_CHECK_MS + + # Event callbacks + on_connected: Optional[Callable[[], None]] = None + on_disconnected: Optional[Callable[[], None]] = None + on_ready: Optional[Callable[[], None]] = None + on_not_ready: Optional[Callable[[], None]] = None + on_error: Optional[Callable[[Exception], None]] = None + on_replicator_health_changed: Optional[Callable[[bool], None]] = None + + +class DataStreamClient: + """Datastream client with caching, WASM flag evaluation, and replicator support. + + Manages a WebSocket connection to Schematic's datastream, caches entities + locally, and evaluates feature flags using the WASM rules engine. + + In **replicator mode** no WebSocket connection is established — the client + relies entirely on a shared cache populated by an external replicator + service and performs health checks against a configurable URL. 
+ + Usage:: + + client = DataStreamClient(DataStreamClientOptions( + api_key="your-api-key", + base_url="https://api.schematichq.com", + logger=logging.getLogger(__name__), + )) + await client.start() + result = await client.check_flag(CheckFlagRequestBody(company={"id": "co_123"}), "premium-feature") + await client.close() + """ + + def __init__(self, options: DataStreamClientOptions) -> None: + self._api_key = options.api_key + self._base_url = options.base_url + self._logger = options.logger + self._cache_ttl = options.cache_ttl + + # Callbacks + self._on_connected = options.on_connected + self._on_disconnected = options.on_disconnected + self._on_ready = options.on_ready + self._on_not_ready = options.on_not_ready + self._on_error = options.on_error + self._on_replicator_health_changed = options.on_replicator_health_changed + + # Replicator mode + self._replicator_mode = options.replicator_mode + self._replicator_health_url = options.replicator_health_url + self._replicator_health_check_ms = options.replicator_health_check + self._replicator_ready = False + self._replicator_health_task: Optional[asyncio.Task[None]] = None + self._replicator_cache_version: Optional[str] = None + + if self._replicator_mode: + caches = [ + options.company_cache, options.company_lookup_cache, + options.user_cache, options.user_lookup_cache, + options.flag_cache, + ] + if not all(caches): + raise ValueError( + "Replicator mode requires custom cache providers for company, company_lookup, " + "user, user_lookup, and flag caches" + ) + for c in caches: + if isinstance(c, AsyncLocalCache): + raise TypeError( + "Replicator mode requires shared cache providers (e.g. 
RedisCache), " + "not AsyncLocalCache, to ensure shared state across processes" + ) + + # Cache providers + local_ttl = self._cache_ttl if self._cache_ttl is not None else DEFAULT_TTL_MS + flag_ttl = max(MAX_CACHE_TTL_MS, local_ttl) + self._company_cache: AsyncCacheProvider[RulesengineCompany] = options.company_cache or AsyncLocalCache(ttl=local_ttl) + self._user_cache: AsyncCacheProvider[RulesengineUser] = options.user_cache or AsyncLocalCache(ttl=local_ttl) + self._flag_cache: AsyncCacheProvider[RulesengineFlag] = options.flag_cache or AsyncLocalCache(ttl=flag_ttl) + + # Key -> ID mapping caches (two-level caching) + self._company_key_cache: AsyncCacheProvider[str] = options.company_lookup_cache or AsyncLocalCache(ttl=local_ttl) + self._user_key_cache: AsyncCacheProvider[str] = options.user_lookup_cache or AsyncLocalCache(ttl=local_ttl) + + # WebSocket client + self._ws_client: Optional[DatastreamWSClient] = None + + # Rules engine + self._rules_engine = RulesEngineClient() + + # Pending requests — maps cache key to list of asyncio Futures + self._pending_company: Dict[str, List[asyncio.Future[RulesengineCompany]]] = {} + self._pending_user: Dict[str, List[asyncio.Future[RulesengineUser]]] = {} + self._pending_flags: Optional[asyncio.Future[bool]] = None + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def start(self) -> None: + """Initialise and start the datastream client.""" + # Initialise the rules engine + try: + await self._rules_engine.initialize() + self._logger.debug("Rules engine initialized successfully") + except Exception as exc: + self._logger.warning("Failed to initialize rules engine: %s", exc) + + # Replicator mode — no WebSocket + if self._replicator_mode: + self._logger.info("Replicator mode enabled — skipping WebSocket connection") + if self._replicator_health_url: + # Run an initial health check synchronously so the cache 
version + # is available before the first flag check arrives. + await self._check_replicator_health() + self._start_replicator_health_check() + return + + if not self._base_url: + raise ValueError("base_url is required when not in replicator mode") + + self._logger.info("Starting DataStream client") + + self._ws_client = DatastreamWSClient(WSClientOptions( + url=self._base_url, + api_key=self._api_key, + message_handler=self._handle_message, + logger=self._logger, + connection_ready_handler=self._handle_connection_ready, + on_connected=self._on_ws_connected, + on_disconnected=self._on_ws_disconnected, + on_ready=self._on_ready, + on_not_ready=self._on_not_ready, + on_error=self._on_ws_error, + )) + self._ws_client.start() + + def is_connected(self) -> bool: + if self._replicator_mode: + return self._replicator_ready + return self._ws_client.is_connected() if self._ws_client else False + + def is_replicator_ready(self) -> bool: + return self._replicator_ready + + def is_replicator_mode(self) -> bool: + return self._replicator_mode + + @property + def replicator_cache_version(self) -> Optional[str]: + return self._replicator_cache_version + + async def get_company(self, keys: Dict[str, str]) -> RulesengineCompany: + """Retrieve a company by keys, using cache or datastream.""" + cached = await self._get_company_from_cache(keys) + if cached is not None: + self._logger.debug("Company found in cache for keys: %s", keys) + return cached + + if self._replicator_mode: + raise RuntimeError("Company not found in cache and replicator mode is enabled") + + if not self.is_connected(): + raise RuntimeError("DataStream client is not connected") + + cache_keys = self._generate_company_cache_keys(keys) + existing = any(k in self._pending_company for k in cache_keys) + + loop = asyncio.get_running_loop() + future: asyncio.Future[Any] = loop.create_future() + + for ck in cache_keys: + self._pending_company.setdefault(ck, []).append(future) + + if not existing: + try: + await 
self._send_request(DataStreamReq(entity_type=EntityType.COMPANY, keys=keys)) + except Exception as exc: + self._cleanup_pending_company(cache_keys, future) + raise + + try: + return await asyncio.wait_for(asyncio.shield(future), timeout=RESOURCE_TIMEOUT_MS / 1000) + except asyncio.TimeoutError: + self._cleanup_pending_company(cache_keys, future) + raise TimeoutError("Timeout while waiting for company data") + + async def get_user(self, keys: Dict[str, str]) -> RulesengineUser: + """Retrieve a user by keys, using cache or datastream.""" + cached = await self._get_user_from_cache(keys) + if cached is not None: + self._logger.debug("User found in cache for keys: %s", keys) + return cached + + if self._replicator_mode: + raise RuntimeError("User not found in cache and replicator mode is enabled") + + if not self.is_connected(): + raise RuntimeError("DataStream client is not connected") + + cache_keys = self._generate_user_cache_keys(keys) + existing = any(k in self._pending_user for k in cache_keys) + + loop = asyncio.get_running_loop() + future: asyncio.Future[Any] = loop.create_future() + + for ck in cache_keys: + self._pending_user.setdefault(ck, []).append(future) + + if not existing: + try: + await self._send_request(DataStreamReq(entity_type=EntityType.USER, keys=keys)) + except Exception as exc: + self._cleanup_pending_user(cache_keys, future) + raise + + try: + return await asyncio.wait_for(asyncio.shield(future), timeout=RESOURCE_TIMEOUT_MS / 1000) + except asyncio.TimeoutError: + self._cleanup_pending_user(cache_keys, future) + raise TimeoutError("Timeout while waiting for user data") + + async def get_flag(self, flag_key: str) -> Optional[RulesengineFlag]: + """Retrieve a flag by key from cache.""" + cache_key = self._flag_cache_key(flag_key) + self._logger.debug("Looking up flag cache key: %s (version=%s)", cache_key, self._get_version_key()) + try: + raw = await self._flag_cache.get(cache_key) + if raw is None: + self._logger.debug("Flag cache miss for 
key: %s", cache_key) + return None + return _validate(RulesengineFlag, raw) + except Exception as exc: + self._logger.warning("Failed to retrieve flag from cache: %s", exc) + return None + + async def get_all_flags(self) -> None: + """Request a refresh of all flags from the datastream.""" + if self._pending_flags is not None and not self._pending_flags.done(): + # Wait for existing request + await asyncio.wait_for(asyncio.shield(self._pending_flags), timeout=RESOURCE_TIMEOUT_MS / 1000) + return + + loop = asyncio.get_running_loop() + self._pending_flags = loop.create_future() + + try: + await self._send_request(DataStreamReq(entity_type=EntityType.FLAGS)) + except Exception: + fut = self._pending_flags + self._pending_flags = None + if not fut.done(): + fut.set_result(False) + raise + + try: + await asyncio.wait_for(asyncio.shield(self._pending_flags), timeout=RESOURCE_TIMEOUT_MS / 1000) + except asyncio.TimeoutError: + self._pending_flags = None + raise TimeoutError("Timeout while waiting for flags data") + + async def check_flag( + self, + eval_ctx: CheckFlagRequestBody, + flag_key: str, + ) -> RulesengineCheckFlagResult: + """Evaluate a flag for a company and/or user context.""" + flag = await self.get_flag(flag_key) + if flag is None: + raise RuntimeError(f"Flag not found: {flag_key}") + + company_keys = eval_ctx.company + user_keys = eval_ctx.user + needs_company = bool(company_keys) + needs_user = bool(user_keys) + + cached_company: Optional[Any] = None + cached_user: Optional[Any] = None + + if needs_company: + cached_company = await self._get_company_from_cache(company_keys) # type: ignore[arg-type] + if needs_user: + cached_user = await self._get_user_from_cache(user_keys) # type: ignore[arg-type] + + # Replicator mode — evaluate with whatever is cached + if self._replicator_mode: + return self._evaluate_flag(flag, cached_company, cached_user) + + # If we have all required entities cached, evaluate immediately + if (not needs_company or cached_company) and 
(not needs_user or cached_user): + return self._evaluate_flag(flag, cached_company, cached_user) + + if not self.is_connected(): + raise RuntimeError("Datastream not connected and required entities not in cache") + + # Fetch missing data in parallel + tasks = [] + if needs_company and not cached_company: + tasks.append(self.get_company(company_keys)) # type: ignore[arg-type] + else: + tasks.append(_resolved(cached_company)) + + if needs_user and not cached_user: + tasks.append(self.get_user(user_keys)) # type: ignore[arg-type] + else: + tasks.append(_resolved(cached_user)) + + results: list = await asyncio.gather(*tasks) + return self._evaluate_flag(flag, results[0], results[1]) + + async def update_company_metrics(self, keys: Dict[str, str], event: str, quantity: int) -> None: + """Update company metrics locally in cache (for track events).""" + company = await self._get_company_from_cache(keys) + if company is None: + return + + updated = company.model_copy(deep=True) + if updated.metrics: + new_metrics = [ + metric.model_copy(update={"value": (metric.value or 0) + quantity}) + if metric.event_subtype == event else metric + for metric in updated.metrics + ] + updated = updated.model_copy(update={"metrics": new_metrics}) + + await self._cache_company(updated) + + async def close(self) -> None: + """Gracefully close the datastream client.""" + self._logger.info("Closing DataStream client") + + if self._replicator_health_task is not None: + self._replicator_health_task.cancel() + self._replicator_health_task = None + + self._clear_pending_requests() + + if self._ws_client is not None: + await self._ws_client.close() + self._ws_client = None + + self._logger.info("DataStream client closed") + + # ------------------------------------------------------------------ + # WebSocket callbacks + # ------------------------------------------------------------------ + + def _on_ws_connected(self) -> None: + if self._on_connected: + self._on_connected() + + def 
_on_ws_disconnected(self) -> None: + self._clear_pending_requests() + if self._on_disconnected: + self._on_disconnected() + + def _on_ws_error(self, error: Exception) -> None: + if self._on_error: + self._on_error(error) + + # ------------------------------------------------------------------ + # Message handling + # ------------------------------------------------------------------ + + async def _handle_message(self, message: DataStreamResp) -> None: + self._logger.debug( + "Processing datastream message: EntityType=%s, MessageType=%s", + message.entity_type, message.message_type, + ) + try: + if message.message_type == MessageType.ERROR.value: + await self._handle_error_message(message) + return + + et = message.entity_type + if et in (EntityType.COMPANY.value, "rulesengine.Company"): + await self._handle_company_message(message) + elif et in (EntityType.USER.value, "rulesengine.User"): + await self._handle_user_message(message) + elif et in (EntityType.FLAGS.value, "rulesengine.Flags"): + await self._handle_flags_message(message) + elif et in (EntityType.FLAG.value, "rulesengine.Flag"): + await self._handle_flag_message(message) + else: + self._logger.warning("Unknown entity type: %s", et) + except Exception as exc: + self._logger.error("Error processing datastream message: %s", exc) + if self._on_error: + self._on_error(exc if isinstance(exc, Exception) else Exception(str(exc))) + + async def _handle_company_message(self, message: DataStreamResp) -> None: + raw = message.data + if not raw: + return + + # For partial updates, we need the raw dict to merge into the existing model + if message.message_type == MessageType.PARTIAL.value: + partial_data = raw if isinstance(raw, dict) else raw.model_dump() + entity_id = partial_data.get("id") if isinstance(partial_data, dict) else getattr(partial_data, "id", None) + if not entity_id: + self._logger.warning("Partial company message missing id") + return + + rk = self._resource_id_cache_key(_PREFIX_COMPANY, entity_id) + 
raw_existing = await self._company_cache.get(rk) + if raw_existing is None: + self._logger.warning("Partial company update for unknown entity: %s", entity_id) + return + + existing = _validate(RulesengineCompany, raw_existing) + try: + company = partial_company(existing, partial_data) + except Exception as exc: + self._logger.error("Failed to merge partial company: %s", exc) + return + else: + company = _validate(RulesengineCompany, raw) + + if message.message_type == MessageType.DELETE.value: + await self._delete_entity( + company.id, company.keys, _PREFIX_COMPANY, self._company_cache, self._company_key_cache, + ) + return + + await self._cache_company(company) + self._notify_pending_company(company.keys or {}, company) + + async def _handle_user_message(self, message: DataStreamResp) -> None: + raw = message.data + if not raw: + return + + # For partial updates, we need the raw dict to merge into the existing model + if message.message_type == MessageType.PARTIAL.value: + partial_data = raw if isinstance(raw, dict) else raw.model_dump() + entity_id = partial_data.get("id") if isinstance(partial_data, dict) else getattr(partial_data, "id", None) + if not entity_id: + self._logger.warning("Partial user message missing id") + return + + rk = self._resource_id_cache_key(_PREFIX_USER, entity_id) + raw_existing = await self._user_cache.get(rk) + if raw_existing is None: + self._logger.warning("Partial user update for unknown entity: %s", entity_id) + return + + existing = _validate(RulesengineUser, raw_existing) + try: + user = partial_user(existing, partial_data) + except Exception as exc: + self._logger.error("Failed to merge partial user: %s", exc) + return + else: + user = _validate(RulesengineUser, raw) + + if message.message_type == MessageType.DELETE.value: + await self._delete_entity( + user.id, user.keys, _PREFIX_USER, self._user_cache, self._user_key_cache, + ) + return + + await self._cache_user(user) + self._notify_pending_user(user.keys or {}, user) + + 
async def _handle_flags_message(self, message: DataStreamResp) -> None:
+        """Process a bulk flags message: cache every flag, then prune stale ones.
+
+        Resolves the ``_pending_flags`` future so callers awaiting the initial
+        flag load are unblocked.
+        """
+        raw_flags = message.data
+        if not isinstance(raw_flags, list):
+            self._logger.warning("Expected flags array in bulk flags message")
+            return
+
+        cached_keys: List[str] = []
+        for raw_flag in raw_flags:
+            flag = _validate(RulesengineFlag, raw_flag)
+            flag_key = flag.key
+            if not flag_key:
+                continue
+            ck = self._flag_cache_key(flag_key)
+            try:
+                await self._flag_cache.set(ck, flag)
+                cached_keys.append(ck)
+            except Exception as exc:
+                self._logger.warning("Failed to cache flag: %s", exc)
+
+        # Delete flags not in the response. NotImplementedError is a subclass
+        # of Exception, so catching the tuple (NotImplementedError, Exception)
+        # was redundant — a plain Exception catch covers both cases.
+        try:
+            await self._flag_cache.delete_missing(cached_keys, scan_pattern="flags:*")
+        except Exception as exc:
+            self._logger.debug("delete_missing not supported or failed: %s", exc)
+
+        if self._pending_flags is not None and not self._pending_flags.done():
+            self._pending_flags.set_result(True)
+
+    async def _handle_flag_message(self, message: DataStreamResp) -> None:
+        """Process a single-flag update (full replace or delete) from the stream."""
+        raw = message.data
+        flag = _validate(RulesengineFlag, raw)
+        flag_key = flag.key
+        if not flag_key:
+            return
+
+        ck = self._flag_cache_key(flag_key)
+        try:
+            if message.message_type == MessageType.DELETE.value:
+                await self._flag_cache.delete(ck)
+            elif message.message_type == MessageType.FULL.value:
+                await self._flag_cache.set(ck, flag)
+        except Exception as exc:
+            self._logger.warning("Failed to update flag cache: %s", exc)
+
+        if self._pending_flags is not None and not self._pending_flags.done():
+            self._pending_flags.set_result(True)
+
+    async def _handle_error_message(self, message: DataStreamResp) -> None:
+        error_data = message.data
+        if isinstance(error_data, dict):
+            keys = error_data.get("keys")
+            entity_type = error_data.get("entity_type")
+            if keys and entity_type:
+                if entity_type in (EntityType.COMPANY.value, "rulesengine.Company"):
+                    self._notify_pending_company(keys, None)
+                elif entity_type in (EntityType.USER.value, "rulesengine.User"):
+
self._notify_pending_user(keys, None) + + error_msg = error_data.get("error", "Unknown datastream error") + self._logger.warning("DataStream error received: %s", error_msg) + + async def _handle_connection_ready(self) -> None: + self._logger.info("DataStream connection is ready") + try: + # Only send the flags request — don't await the response here. + # The response will be processed by the message loop, which hasn't + # started yet (it begins after this handler returns). + loop = asyncio.get_running_loop() + self._pending_flags = loop.create_future() + await self._send_request(DataStreamReq(entity_type=EntityType.FLAGS)) + self._logger.debug("Sent initial flag data request") + except Exception as exc: + self._logger.error("Failed to request initial flag data: %s", exc) + self._pending_flags = None + raise + + # ------------------------------------------------------------------ + # Request sending + # ------------------------------------------------------------------ + + async def _send_request(self, request: DataStreamReq) -> None: + if self._ws_client is None or not self._ws_client.is_connected(): + raise RuntimeError("DataStream client is not connected") + + self._logger.debug( + "Sending datastream request: EntityType=%s, Keys=%s", + request.entity_type, request.keys, + ) + await self._ws_client.send_message(DataStreamBaseReq(data=request)) + + # ------------------------------------------------------------------ + # Cache helpers + # ------------------------------------------------------------------ + + async def _get_company_from_cache(self, keys: Dict[str, str]) -> Optional[RulesengineCompany]: + for key, value in keys.items(): + ck = self._resource_key_to_cache_key(_PREFIX_COMPANY, key, value) + try: + company_id = await self._company_key_cache.get(ck) + self._logger.debug("Company lookup key %s -> %s", ck, company_id) + if company_id: + rk = self._resource_id_cache_key(_PREFIX_COMPANY, company_id) + raw = await self._company_cache.get(rk) + 
self._logger.debug("Company ID key %s -> %s", rk, "hit" if raw is not None else "miss") + if raw is not None: + company = _validate(RulesengineCompany, raw) + return company.model_copy(deep=True) + except Exception as exc: + self._logger.warning("Failed to retrieve company from cache: %s", exc) + return None + + async def _get_user_from_cache(self, keys: Dict[str, str]) -> Optional[RulesengineUser]: + for key, value in keys.items(): + ck = self._resource_key_to_cache_key(_PREFIX_USER, key, value) + try: + user_id = await self._user_key_cache.get(ck) + if user_id: + rk = self._resource_id_cache_key(_PREFIX_USER, user_id) + raw = await self._user_cache.get(rk) + if raw is not None: + user = _validate(RulesengineUser, raw) + return user.model_copy(deep=True) + except Exception as exc: + self._logger.warning("Failed to retrieve user from cache: %s", exc) + return None + + async def _cache_company(self, company: RulesengineCompany) -> None: + keys = company.keys + company_id = company.id + if not keys: + return + + rk = self._resource_id_cache_key(_PREFIX_COMPANY, company_id) + + # Clean up stale lookup keys by diffing old vs new + raw = await self._company_cache.get(rk) + if raw is not None: + old = _validate(RulesengineCompany, raw) + old_keys = old.keys or {} + await self._delete_stale_lookup_keys( + self._company_key_cache, _PREFIX_COMPANY, old_keys, keys, + ) + + await self._company_cache.set(rk, company, self._cache_ttl) + + for key, value in keys.items(): + ck = self._resource_key_to_cache_key(_PREFIX_COMPANY, key, value) + try: + await self._company_key_cache.set(ck, company_id, self._cache_ttl) + except Exception as exc: + self._logger.warning("Failed to cache company key mapping '%s': %s", ck, exc) + + async def _cache_user(self, user: RulesengineUser) -> None: + keys = user.keys + user_id = user.id + if not keys: + return + + rk = self._resource_id_cache_key(_PREFIX_USER, user_id) + + # Clean up stale lookup keys by diffing old vs new + raw = await 
self._user_cache.get(rk) + if raw is not None: + old = _validate(RulesengineUser, raw) + old_keys = old.keys or {} + await self._delete_stale_lookup_keys( + self._user_key_cache, _PREFIX_USER, old_keys, keys, + ) + + await self._user_cache.set(rk, user, self._cache_ttl) + + for key, value in keys.items(): + ck = self._resource_key_to_cache_key(_PREFIX_USER, key, value) + try: + await self._user_key_cache.set(ck, user_id, self._cache_ttl) + except Exception as exc: + self._logger.warning("Failed to cache user key mapping '%s': %s", ck, exc) + + async def _delete_entity( + self, + entity_id: Optional[str], + message_keys: Optional[Dict[str, str]], + prefix: str, + entity_cache: AsyncCacheProvider[Any], + lookup_cache: AsyncCacheProvider[str], + ) -> None: + """Delete an entity and all its lookup keys from cache. + + Fetches the cached entity first to discover all lookup keys, + since the delete message may not include them all. + """ + all_keys = message_keys + if entity_id: + rk = self._resource_id_cache_key(prefix, entity_id) + cached = await entity_cache.get(rk) + if cached is not None: + if isinstance(cached, dict): + all_keys = cached.get("keys") or message_keys + else: + all_keys = cached.keys or message_keys + + if all_keys: + for key, value in all_keys.items(): + ck = self._resource_key_to_cache_key(prefix, key, value) + try: + await lookup_cache.delete(ck) + except Exception as exc: + self._logger.warning("Failed to delete %s key mapping: %s", prefix, exc) + + if entity_id: + rk = self._resource_id_cache_key(prefix, entity_id) + try: + await entity_cache.delete(rk) + except Exception as exc: + self._logger.warning("Failed to delete %s resource: %s", prefix, exc) + + async def _delete_stale_lookup_keys( + self, + lookup_cache: AsyncCacheProvider[str], + prefix: str, + old_keys: Dict[str, str], + new_keys: Dict[str, str], + ) -> None: + """Delete lookup cache entries for keys that are no longer present.""" + old_set = {(k, v) for k, v in old_keys.items()} + 
new_set = {(k, v) for k, v in new_keys.items()} + for key, value in old_set - new_set: + ck = self._resource_key_to_cache_key(prefix, key, value) + try: + await lookup_cache.delete(ck) + except Exception as exc: + self._logger.warning("Failed to delete stale lookup key '%s': %s", ck, exc) + + # ------------------------------------------------------------------ + # Cache key generation + # ------------------------------------------------------------------ + + def _flag_cache_key(self, key: str) -> str: + version = self._get_version_key() + return f"{_PREFIX_FLAGS}:{version}:{key.lower()}" + + def _resource_id_cache_key(self, resource_type: str, entity_id: str) -> str: + version = self._get_version_key() + return f"{resource_type}:{version}:{entity_id}" + + def _resource_key_to_cache_key(self, resource_type: str, key: str, value: str) -> str: + version = self._get_version_key() + return f"{resource_type}:{version}:{key.lower()}:{value.lower()}" + + def _get_version_key(self) -> str: + if self._replicator_mode and self._replicator_cache_version: + return self._replicator_cache_version + try: + if self._rules_engine.is_initialized(): + return self._rules_engine.get_version_key() + except Exception: + pass + if self._replicator_mode: + self._logger.warning( + "Replicator mode active but cache version unknown — " + "cache lookups will use fallback version '1' and likely miss" + ) + return "1" + + def _generate_company_cache_keys(self, keys: Dict[str, str]) -> List[str]: + return [self._resource_key_to_cache_key(_PREFIX_COMPANY, k, v) for k, v in keys.items()] + + def _generate_user_cache_keys(self, keys: Dict[str, str]) -> List[str]: + return [self._resource_key_to_cache_key(_PREFIX_USER, k, v) for k, v in keys.items()] + + # ------------------------------------------------------------------ + # Pending request management + # ------------------------------------------------------------------ + + def _notify_pending_company(self, keys: Dict[str, str], company: Any) -> 
None: + for key, value in keys.items(): + ck = self._resource_key_to_cache_key(_PREFIX_COMPANY, key, value) + futures = self._pending_company.pop(ck, []) + for fut in futures: + if not fut.done(): + if company is not None: + fut.set_result(company) + else: + fut.set_exception(RuntimeError("Company not found")) + + def _notify_pending_user(self, keys: Dict[str, str], user: Any) -> None: + for key, value in keys.items(): + ck = self._resource_key_to_cache_key(_PREFIX_USER, key, value) + futures = self._pending_user.pop(ck, []) + for fut in futures: + if not fut.done(): + if user is not None: + fut.set_result(user) + else: + fut.set_exception(RuntimeError("User not found")) + + def _cleanup_pending_company(self, cache_keys: List[str], future: asyncio.Future[Any]) -> None: + for ck in cache_keys: + futures = self._pending_company.get(ck) + if futures: + try: + futures.remove(future) + except ValueError: + pass + if not futures: + del self._pending_company[ck] + + def _cleanup_pending_user(self, cache_keys: List[str], future: asyncio.Future[Any]) -> None: + for ck in cache_keys: + futures = self._pending_user.get(ck) + if futures: + try: + futures.remove(future) + except ValueError: + pass + if not futures: + del self._pending_user[ck] + + def _clear_pending_requests(self) -> None: + for futures in self._pending_company.values(): + for fut in futures: + if not fut.done(): + fut.set_exception(RuntimeError("DataStream client disconnected")) + self._pending_company.clear() + + for user_futures in self._pending_user.values(): + for user_fut in user_futures: + if not user_fut.done(): + user_fut.set_exception(RuntimeError("DataStream client disconnected")) + self._pending_user.clear() + + if self._pending_flags is not None and not self._pending_flags.done(): + self._pending_flags.set_result(False) + self._pending_flags = None + + # ------------------------------------------------------------------ + # Flag evaluation + # 
------------------------------------------------------------------ + + def _evaluate_flag( + self, + flag: RulesengineFlag, + company: Optional[RulesengineCompany], + user: Optional[RulesengineUser], + ) -> RulesengineCheckFlagResult: + default_value = flag.default_value + + try: + if self._rules_engine.is_initialized(): + return self._rules_engine.check_flag(flag, company, user) + else: + self._logger.warning("Rules engine not initialized, using default flag value") + return self._make_default_result(flag, company, user, default_value, "RULES_ENGINE_UNAVAILABLE") + except Exception as exc: + self._logger.error("Rules engine evaluation failed: %s", exc) + return self._make_default_result(flag, company, user, default_value, "RULES_ENGINE_ERROR") + + @staticmethod + def _make_default_result( + flag: RulesengineFlag, + company: Optional[RulesengineCompany], + user: Optional[RulesengineUser], + value: bool, + reason: str, + ) -> RulesengineCheckFlagResult: + return RulesengineCheckFlagResult( + value=value, + reason=reason, + flag_key=flag.key, + flag_id=flag.id, + company_id=company.id if company else None, + user_id=user.id if user else None, + ) + + # ------------------------------------------------------------------ + # Replicator health checking + # ------------------------------------------------------------------ + + def _start_replicator_health_check(self) -> None: + if not self._replicator_health_url: + return + self._logger.info( + "Starting replicator health check: url=%s, interval=%dms", + self._replicator_health_url, self._replicator_health_check_ms, + ) + self._replicator_health_task = asyncio.ensure_future(self._replicator_health_loop()) + + async def _replicator_health_loop(self) -> None: + interval_s = self._replicator_health_check_ms / 1000 + while True: + await self._check_replicator_health() + try: + await asyncio.sleep(interval_s) + except asyncio.CancelledError: + break + + async def _check_replicator_health(self) -> None: + if not 
self._replicator_health_url: + return + try: + import httpx + + async with httpx.AsyncClient(timeout=REPLICATOR_HEALTH_TIMEOUT_S) as client: + resp = await client.get(self._replicator_health_url) + resp.raise_for_status() + health_data = resp.json() + + self._logger.debug("Replicator health response: %s", health_data) + + was_ready = self._replicator_ready + self._replicator_ready = health_data.get("ready", False) + + new_version = health_data.get("cache_version") or health_data.get("cacheVersion") + if new_version and new_version != self._replicator_cache_version: + old = self._replicator_cache_version + self._replicator_cache_version = new_version + self._logger.info("Cache version changed from %s to %s", old, new_version) + + if self._replicator_ready and not was_ready: + self._logger.info("External replicator is now ready") + if self._on_replicator_health_changed: + self._on_replicator_health_changed(True) + elif not self._replicator_ready and was_ready: + self._logger.info("External replicator is no longer ready") + if self._on_replicator_health_changed: + self._on_replicator_health_changed(False) + + except Exception as exc: + if self._replicator_ready: + self._replicator_ready = False + self._logger.info("External replicator is no longer ready") + if self._on_replicator_health_changed: + self._on_replicator_health_changed(False) + self._logger.debug("Replicator health check failed: %s", exc) + + async def get_replicator_cache_version_async(self, timeout_s: float = REPLICATOR_CACHE_VERSION_TIMEOUT_S) -> Optional[str]: + """Attempt to fetch cache version immediately if not already available.""" + if self._replicator_cache_version: + return self._replicator_cache_version + + if self._replicator_mode and self._replicator_health_url: + try: + import httpx + + async with httpx.AsyncClient(timeout=timeout_s) as client: + resp = await client.get(self._replicator_health_url) + if resp.status_code == 200: + data = resp.json() + version = data.get("cache_version") or 
data.get("cacheVersion") + if version: + self._replicator_cache_version = version + return version + except Exception as exc: + self._logger.debug("Failed to fetch replicator cache version: %s", exc) + + return None + + +async def _resolved(value: Any) -> Any: + """Helper that acts like an immediately-resolved coroutine.""" + return value diff --git a/src/schematic/datastream/merge.py b/src/schematic/datastream/merge.py new file mode 100644 index 0000000..139fe0a --- /dev/null +++ b/src/schematic/datastream/merge.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Tuple + +from ..types.rulesengine_company import RulesengineCompany +from ..types.rulesengine_company_metric import RulesengineCompanyMetric +from ..types.rulesengine_user import RulesengineUser + + +def extract_id(data: Any) -> Optional[str]: + """Extract the 'id' field from a dict or Pydantic model.""" + if isinstance(data, dict): + return data.get("id") + return getattr(data, "id", None) + + +def partial_company(existing: RulesengineCompany, partial: Dict[str, Any]) -> RulesengineCompany: + """Merge a partial update dict into an existing Company. + + Only fields present in `partial` are applied. Maps (keys, credit_balances) + merge additively. Metrics are upserted by (event_subtype, period, month_reset). + All other fields replace the existing value. The original is not mutated. 
+ """ + if "id" not in partial: + raise ValueError("partial company message missing required field: id") + + updates: Dict[str, Any] = {} + + for key, value in partial.items(): + if key == "keys": + merged_keys = dict(existing.keys) if existing.keys else {} + merged_keys.update(value or {}) + updates["keys"] = merged_keys + elif key == "credit_balances": + merged_cb = dict(existing.credit_balances) if existing.credit_balances else {} + merged_cb.update(value or {}) + updates["credit_balances"] = merged_cb + elif key == "metrics": + incoming = _parse_metrics(value) + existing_metrics = [m.model_dump() for m in (existing.metrics or [])] + updates["metrics"] = _upsert_metrics(existing_metrics, incoming) + else: + updates[key] = value + + base = existing.model_dump() + base.update(updates) + return RulesengineCompany.model_validate(base) + + +def partial_user(existing: RulesengineUser, partial: Dict[str, Any]) -> RulesengineUser: + """Merge a partial update dict into an existing User. + + Only fields present in `partial` are applied. Keys map merges additively. + All other fields replace the existing value. The original is not mutated. + """ + if "id" not in partial: + raise ValueError("partial user message missing required field: id") + + updates: Dict[str, Any] = {} + + for key, value in partial.items(): + if key == "keys": + merged_keys = dict(existing.keys) if existing.keys else {} + merged_keys.update(value or {}) + updates["keys"] = merged_keys + else: + updates[key] = value + + base = existing.model_dump() + base.update(updates) + return RulesengineUser.model_validate(base) + + +def deep_copy_company(company: Optional[RulesengineCompany]) -> Optional[RulesengineCompany]: + """Create a deep copy of a Company. Returns None if input is None.""" + if company is None: + return None + return company.model_copy(deep=True) + + +def deep_copy_user(user: Optional[RulesengineUser]) -> Optional[RulesengineUser]: + """Create a deep copy of a User. 
Returns None if input is None.""" + if user is None: + return None + return user.model_copy(deep=True) + + +def _metric_key(metric: Any) -> Tuple[str, str, str]: + """Build the composite key used for metric upsert matching.""" + if isinstance(metric, dict): + return ( + metric.get("event_subtype", ""), + metric.get("period", ""), + metric.get("month_reset", ""), + ) + return ( + getattr(metric, "event_subtype", ""), + str(getattr(metric, "period", "")), + str(getattr(metric, "month_reset", "")), + ) + + +def _parse_metrics(raw: Any) -> List[Any]: + """Normalise incoming metrics to a list.""" + if raw is None: + return [] + if isinstance(raw, list): + return raw + return [raw] + + +def _upsert_metrics(existing: List[Any], incoming: List[Any]) -> List[Any]: + """Merge incoming metrics into existing ones. + + Metrics are matched by (event_subtype, period, month_reset). + Matches are replaced in place; new metrics are appended. + """ + result = list(existing) + index: Dict[Tuple[str, str, str], int] = {} + for i, m in enumerate(result): + if m is not None: + index[_metric_key(m)] = i + + for m in incoming: + if m is None: + continue + k = _metric_key(m) + if k in index: + result[index[k]] = m + else: + result.append(m) + + return result diff --git a/src/schematic/datastream/rules_engine.py b/src/schematic/datastream/rules_engine.py new file mode 100644 index 0000000..e86520e --- /dev/null +++ b/src/schematic/datastream/rules_engine.py @@ -0,0 +1,193 @@ +from __future__ import annotations + +import json +import logging +import re +from pathlib import Path +from typing import Any, Optional + +from ..types.rulesengine_check_flag_result import RulesengineCheckFlagResult +from ..types.rulesengine_company import RulesengineCompany +from ..types.rulesengine_flag import RulesengineFlag +from ..types.rulesengine_user import RulesengineUser + +logger = logging.getLogger(__name__) + +_CAMEL_RE = re.compile(r"([A-Z])") + + +def _camel_to_snake(name: str) -> str: + """Convert a 
camelCase string to snake_case.""" + return _CAMEL_RE.sub(r"_\1", name).lower().lstrip("_") + + +def _deep_camel_to_snake(obj: Any) -> Any: + """Recursively convert all dict keys from camelCase to snake_case.""" + if isinstance(obj, dict): + return {_camel_to_snake(k): _deep_camel_to_snake(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_deep_camel_to_snake(item) for item in obj] + return obj + +# Path to the WASM binary shipped alongside this module +_WASM_PATH = Path(__file__).parent / "wasm" / "rulesengine.wasm" + + +class RulesEngineClient: + """Wrapper around the Rust WASM rules engine for local flag evaluation. + + Uses ``wasmtime`` to load and execute the raw WASM binary. The WASM + module exposes a ``checkFlagCombined`` function that accepts a single + JSON envelope ``{"flag": ..., "company": ..., "user": ...}`` and returns + the evaluation result. All WASM memory management is handled internally + via the module's ``alloc``/``dealloc`` exports. + + Usage:: + + engine = RulesEngineClient() + await engine.initialize() + result = engine.check_flag(flag, company, user) + """ + + def __init__(self, *, wasm_path: Optional[str] = None) -> None: + self._wasm_path = Path(wasm_path) if wasm_path else _WASM_PATH + self._initialized = False + # wasmtime objects — set during initialize() + self._store: Any = None + self._instance: Any = None + self._memory: Any = None + self._alloc_fn: Any = None + self._dealloc_fn: Any = None + self._check_flag_fn: Any = None + self._get_result_json_fn: Any = None + self._get_result_json_length_fn: Any = None + self._get_version_key_fn: Any = None + + async def initialize(self) -> None: + """Load and instantiate the WASM module. Safe to call multiple times.""" + if self._initialized: + return + + try: + import wasmtime # type: ignore[import-untyped] + except ImportError: + raise ImportError( + "wasmtime is required for the rules engine. 
" + "Install it with: pip install 'schematichq[datastream]' or pip install wasmtime" + ) + + if not self._wasm_path.exists(): + raise FileNotFoundError( + f"WASM binary not found at {self._wasm_path}. " + "Ensure the rules engine WASM has been deployed to this SDK." + ) + + engine = wasmtime.Engine() + module = wasmtime.Module.from_file(engine, str(self._wasm_path)) + linker = wasmtime.Linker(engine) + linker.define_wasi() + + wasi_config = wasmtime.WasiConfig() + self._store = wasmtime.Store(engine) + self._store.set_wasi(wasi_config) + + self._instance = linker.instantiate(self._store, module) + exports = self._instance.exports(self._store) + + self._memory = exports.get("memory") + self._alloc_fn = exports.get("alloc") + self._dealloc_fn = exports.get("dealloc") + self._check_flag_fn = exports.get("checkFlagCombined") + self._get_result_json_fn = exports.get("getResultJson") + self._get_result_json_length_fn = exports.get("getResultJsonLength") + self._get_version_key_fn = exports.get("get_version_key_wasm") + + if self._memory is None: + raise RuntimeError("WASM module does not export 'memory'") + if self._alloc_fn is None: + raise RuntimeError("WASM module does not export 'alloc'") + if self._check_flag_fn is None: + raise RuntimeError("WASM module does not export 'checkFlagCombined'") + + self._initialized = True + logger.debug("Rules engine WASM initialized (version: %s)", self.get_version_key()) + + def is_initialized(self) -> bool: + return self._initialized + + def check_flag( + self, + flag: RulesengineFlag, + company: Optional[RulesengineCompany] = None, + user: Optional[RulesengineUser] = None, + ) -> RulesengineCheckFlagResult: + """Evaluate a flag using the WASM rules engine. + + Accepts Fern-generated Pydantic models (or plain dicts). Serialises + them into a single JSON envelope, passes it to the WASM module, and + returns a ``RulesengineCheckFlagResult``. 
+ """ + self._ensure_initialized() + + envelope = { + "flag": flag.model_dump(exclude_none=True, mode="json"), + "company": company.model_dump(exclude_none=True, mode="json") if company else None, + "user": user.model_dump(exclude_none=True, mode="json") if user else None, + } + + result_json = self._call_wasm(json.dumps(envelope)) + result_data = _deep_camel_to_snake(json.loads(result_json)) + return RulesengineCheckFlagResult(**result_data) + + def get_version_key(self) -> str: + """Get the version key from the WASM rules engine. + + Used for cache key generation to ensure cache invalidation on engine updates. + """ + self._ensure_initialized() + + if self._get_version_key_fn is None: + return "1" + + ptr = self._get_version_key_fn(self._store) + return self._read_null_terminated_string(ptr) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _ensure_initialized(self) -> None: + if not self._initialized: + raise RuntimeError("Rules engine not initialized. Call initialize() first.") + + def _call_wasm(self, input_json: str) -> str: + """Write *input_json* into WASM memory, invoke the engine, and return + the result JSON string. 
All memory is allocated/freed via the WASM + module's own ``alloc``/``dealloc`` exports.""" + + data = input_json.encode("utf-8") + length = len(data) + + # Allocate a buffer inside WASM memory and copy our JSON into it + ptr = self._alloc_fn(self._store, length) + try: + self._memory.write(self._store, data, ptr) + + result_len = self._check_flag_fn(self._store, ptr, length) + if result_len < 0: + raise RuntimeError("WASM checkFlagCombined returned error code") + finally: + self._dealloc_fn(self._store, ptr, length) + + # Read the result (owned by WASM thread-local, no need to free) + result_ptr = self._get_result_json_fn(self._store) + actual_len = self._get_result_json_length_fn(self._store) + return bytes(self._memory.read(self._store, result_ptr, result_ptr + actual_len)).decode("utf-8") + + def _read_null_terminated_string(self, ptr: int, max_length: int = 256) -> str: + """Read a null-terminated UTF-8 string from WASM linear memory.""" + raw = self._memory.read(self._store, ptr, ptr + max_length) + null_idx = raw.find(0) + if null_idx >= 0: + raw = raw[:null_idx] + return raw.decode("utf-8") diff --git a/src/schematic/datastream/types.py b/src/schematic/datastream/types.py new file mode 100644 index 0000000..a7a6c97 --- /dev/null +++ b/src/schematic/datastream/types.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import enum +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + + +class EntityType(str, enum.Enum): + COMPANY = "company" + USER = "user" + FLAG = "flag" + FLAGS = "flags" + + +class MessageType(str, enum.Enum): + FULL = "full" + PARTIAL = "partial" + DELETE = "delete" + ERROR = "error" + UNKNOWN = "unknown" + + +class Action(str, enum.Enum): + START = "start" + STOP = "stop" + + +@dataclass +class DataStreamReq: + """Request message sent to the datastream.""" + + entity_type: EntityType + keys: Optional[Dict[str, str]] = None + action: Action = Action.START + + def to_dict(self) -> Dict[str, Any]: + d: 
Dict[str, Any] = { + "action": self.action.value, + "entity_type": self.entity_type.value, + } + if self.keys is not None: + d["keys"] = self.keys + return d + + +@dataclass +class DataStreamBaseReq: + """Wrapper around DataStreamReq — the wire format expected by the server.""" + + data: DataStreamReq + + def to_dict(self) -> Dict[str, Any]: + return {"data": self.data.to_dict()} + + +@dataclass +class DataStreamResp: + """Response message received from the datastream.""" + + data: Any + entity_type: str + message_type: str + + +@dataclass +class DataStreamError: + """Error message received from the datastream.""" + + error: str + keys: Optional[Dict[str, str]] = None + entity_type: Optional[EntityType] = None diff --git a/src/schematic/datastream/websocket_client.py b/src/schematic/datastream/websocket_client.py new file mode 100644 index 0000000..c71387b --- /dev/null +++ b/src/schematic/datastream/websocket_client.py @@ -0,0 +1,423 @@ +from __future__ import annotations + +# Note: This client is designed for server-side async environments only. 
import asyncio
import json
import logging
import random
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, Optional, Union
from urllib.parse import urlparse, urlunparse

try:
    import websockets  # type: ignore[import-untyped]
except ImportError:
    websockets = None  # type: ignore[assignment]

from .types import DataStreamBaseReq, DataStreamResp


# Connection timing constants (seconds)
WRITE_WAIT = 10.0
PONG_WAIT = 60.0
PING_PERIOD = (PONG_WAIT * 9) / 10  # 54 seconds
CONNECTION_TIMEOUT = 30.0

# Reconnection constants
MAX_RECONNECT_ATTEMPTS = 10
MIN_RECONNECT_DELAY = 1.0  # seconds
MAX_RECONNECT_DELAY = 30.0  # seconds

MessageHandlerFunc = Callable[[DataStreamResp], Awaitable[None]]
ConnectionReadyHandlerFunc = Callable[[], Awaitable[None]]


def convert_api_url_to_websocket_url(api_url: str) -> str:
    """Convert an HTTP API URL to a WebSocket datastream URL.

    Examples:
        https://api.schematichq.com -> wss://datastream.schematichq.com/datastream
        https://api.staging.x.com -> wss://datastream.staging.x.com/datastream
        https://custom.example.com -> wss://custom.example.com/datastream
        http://localhost:8080 -> ws://localhost:8080/datastream

    Raises:
        ValueError: if the scheme is not http/https, or the URL has no hostname.
    """
    parsed = urlparse(api_url)

    if parsed.scheme == "https":
        scheme = "wss"
    elif parsed.scheme == "http":
        scheme = "ws"
    else:
        raise ValueError(f"Unsupported scheme: {parsed.scheme!r} (must be http or https)")

    hostname = parsed.hostname or ""
    if not hostname:
        # Fail fast on e.g. "https://" instead of emitting "wss:///datastream",
        # which would only surface later as an opaque connect error.
        raise ValueError(f"URL has no hostname: {api_url!r}")

    # Replace a leading 'api' subdomain with 'datastream' if present
    parts = hostname.split(".")
    if len(parts) > 1 and parts[0] == "api":
        parts[0] = "datastream"
    hostname = ".".join(parts)

    netloc = f"{hostname}:{parsed.port}" if parsed.port else hostname
    return urlunparse((scheme, netloc, "/datastream", "", "", ""))


@dataclass
class ClientOptions:
    """Configuration for DatastreamWSClient."""

    # Required
    url: str
    api_key: str
    message_handler: MessageHandlerFunc
    logger: logging.Logger

    # Optional
    connection_ready_handler: Optional[ConnectionReadyHandlerFunc] = None
    max_reconnect_attempts: int = MAX_RECONNECT_ATTEMPTS
    min_reconnect_delay: float = MIN_RECONNECT_DELAY
    max_reconnect_delay: float = MAX_RECONNECT_DELAY

    # Event callbacks — called on state transitions
    on_connected: Optional[Callable[[], None]] = None
    on_disconnected: Optional[Callable[[], None]] = None
    on_ready: Optional[Callable[[], None]] = None
    on_not_ready: Optional[Callable[[], None]] = None
    on_error: Optional[Callable[[Exception], None]] = None


class DatastreamWSClient:
    """WebSocket client for the Schematic datastream with automatic reconnection.

    Connects to the Schematic datastream, delivers incoming messages to the
    provided ``message_handler``, and transparently reconnects with exponential
    backoff whenever the connection drops.

    Usage::

        async def handle_message(msg: DataStreamResp) -> None:
            print(msg)

        client = DatastreamWSClient(ClientOptions(
            url="https://api.schematichq.com",
            api_key="your-api-key",
            message_handler=handle_message,
            logger=logging.getLogger(__name__),
        ))
        client.start()
        # ... later:
        await client.close()
    """

    def __init__(self, options: ClientOptions) -> None:
        if not options.url:
            raise ValueError("url is required")
        if not options.api_key:
            raise ValueError("api_key is required")
        if options.message_handler is None:  # type: ignore[operator]
            raise ValueError("message_handler is required")

        # Auto-convert HTTP(S) URLs to WebSocket URLs
        if options.url.startswith(("http://", "https://")):
            self._url = convert_api_url_to_websocket_url(options.url)
        else:
            self._url = options.url

        self._headers: Dict[str, str] = {"X-Schematic-Api-Key": options.api_key}
        self._logger = options.logger
        self._message_handler = options.message_handler
        self._connection_ready_handler = options.connection_ready_handler
        self._max_reconnect_attempts = options.max_reconnect_attempts
        self._min_reconnect_delay = options.min_reconnect_delay
        self._max_reconnect_delay = options.max_reconnect_delay

        # Event callbacks
        self._on_connected = options.on_connected
        self._on_disconnected = options.on_disconnected
        self._on_ready = options.on_ready
        self._on_not_ready = options.on_not_ready
        self._on_error = options.on_error

        # Connection state
        self._ws: Any = None
        self._connected: bool = False
        self._ready: bool = False

        # Control state
        self._should_reconnect: bool = False
        self._reconnect_attempts: int = 0
        self._task: Optional[asyncio.Task[None]] = None
        self._ping_task: Optional[asyncio.Task[None]] = None

    def start(self) -> None:
        """Begin the WebSocket connection loop as a background asyncio task."""
        self._should_reconnect = True
        self._reconnect_attempts = 0
        self._task = asyncio.ensure_future(self._connect_and_read())

    def is_connected(self) -> bool:
        """Return whether the WebSocket is currently connected."""
        return self._connected

    def is_ready(self) -> bool:
        """Return whether the client is connected and fully initialised."""
        return self._ready and self._connected

    async def send_message(self, message: DataStreamBaseReq) -> None:
        """Send a message over the WebSocket connection.

        Raises ``RuntimeError`` if the connection is not available or the send
        times out after ``WRITE_WAIT`` seconds.
        """
        if not self.is_connected() or self._ws is None:
            raise RuntimeError("WebSocket connection is not available")

        payload = json.dumps(message.to_dict())
        try:
            await asyncio.wait_for(self._ws.send(payload), timeout=WRITE_WAIT)
        except asyncio.TimeoutError as err:
            # Chain the cause so the timeout shows up in tracebacks.
            raise RuntimeError("Write timeout") from err

    async def close(self) -> None:
        """Gracefully close the connection and stop the reconnection loop."""
        self._logger.info("Closing WebSocket connection")

        self._should_reconnect = False
        self._set_ready(False)
        self._set_connected(False)

        if self._ws is not None:
            try:
                await self._ws.close()
            except Exception:
                pass
            self._ws = None

        # Cancel and await the ping task.  CancelledError is a BaseException
        # in Python 3.8+, so it must be listed alongside Exception here.
        if self._ping_task is not None:
            self._ping_task.cancel()
            try:
                await self._ping_task
            except (asyncio.CancelledError, Exception):
                pass
            self._ping_task = None

        # Cancel and await the main connection task
        if self._task is not None:
            self._task.cancel()
            try:
                await self._task
            except (asyncio.CancelledError, Exception):
                pass
            self._task = None

    # ------------------------------------------------------------------
    # Internal connection loop
    # ------------------------------------------------------------------

    async def _connect_and_read(self) -> None:
        """Main loop: connect, read until the connection drops, reconnect with backoff."""
        while self._should_reconnect:
            try:
                if websockets is None:
                    raise ImportError(
                        "websockets is required for DataStream. "
                        "Install it with: pip install 'schematichq[datastream]'"
                    )
                async with websockets.connect(
                    self._url,
                    additional_headers=self._headers,
                    open_timeout=CONNECTION_TIMEOUT,
                    # Disable the library's built-in keepalive — we manage
                    # ping/pong ourselves to match the Node SDK behaviour and
                    # avoid conflicts with the Go server's gorilla/websocket.
                    ping_interval=None,
                    ping_timeout=None,
                ) as ws:
                    self._ws = ws
                    self._reconnect_attempts = 0
                    self._set_connected(True)

                    # Run the ready handler before marking the client ready
                    if self._connection_ready_handler is not None:
                        try:
                            await self._connection_ready_handler()
                            self._logger.debug("Connection ready handler completed successfully")
                        except Exception as err:
                            self._logger.error(f"Connection ready handler failed: {err}")
                            # BUGFIX: this branch previously incremented the
                            # attempt counter itself and `continue`d, which
                            # skipped the backoff sleep at the bottom of the
                            # loop and produced a tight reconnect loop.
                            # Re-raise instead so the outer except counts the
                            # attempt, enforces the max, and applies backoff.
                            raise RuntimeError(f"Connection ready handler failed: {err}") from err

                    self._set_ready(True)
                    self._logger.debug("WebSocket client is ready")

                    self._start_ping_pong()
                    try:
                        async for raw_message in ws:
                            await self._handle_message(raw_message)
                    finally:
                        self._stop_ping_pong()

                self._logger.info("WebSocket connection closed")

            except asyncio.CancelledError:
                break

            except Exception as err:
                self._reconnect_attempts += 1
                self._logger.warning(
                    f"WebSocket connection failed: {err}, "
                    f"attempt {self._reconnect_attempts}/{self._max_reconnect_attempts}"
                )

                if self._reconnect_attempts >= self._max_reconnect_attempts:
                    self._logger.error("Max reconnection attempts reached")
                    if self._on_error is not None:
                        self._on_error(Exception("Max reconnection attempts reached"))
                    break

            finally:
                # Clear ready before connected so the not-ready transition
                # is observed first (matches the ordering used by close()).
                self._set_ready(False)
                self._set_connected(False)
                self._ws = None

            if not self._should_reconnect:
                break

            delay = self._calculate_backoff_delay(self._reconnect_attempts)
            self._logger.debug(f"Waiting {delay:.1f}s before reconnecting...")
            try:
                await asyncio.sleep(delay)
            except asyncio.CancelledError:
                break

    # ------------------------------------------------------------------
    # Keepalive ping/pong
    # ------------------------------------------------------------------

    def _start_ping_pong(self) -> None:
        """Start the application-level ping/pong keepalive loop."""
        self._logger.debug("Starting ping/pong keepalive mechanism")
        self._ping_task = asyncio.ensure_future(self._ping_loop())

    def _stop_ping_pong(self) -> None:
        """Stop the keepalive loop."""
        if self._ping_task is not None:
            self._ping_task.cancel()
            self._ping_task = None

    async def _ping_loop(self) -> None:
        """Send pings every PING_PERIOD seconds, close if multiple consecutive pongs are missed.

        Pong frames are only processed by the websockets library during recv(),
        so if the message handler is busy, a single pong may be delayed. We
        tolerate up to 2 consecutive missed pongs before closing.
        """
        missed = 0
        max_missed = 2

        while self._connected and self._ws is not None:
            try:
                await asyncio.sleep(PING_PERIOD)
            except asyncio.CancelledError:
                return

            if not self._connected or self._ws is None:
                return

            try:
                pong_waiter = await self._ws.ping()
                await asyncio.wait_for(pong_waiter, timeout=PONG_WAIT)
                missed = 0
                self._logger.debug("Pong received from server")
            except asyncio.TimeoutError:
                missed += 1
                self._logger.warning("Pong timeout (%d/%d)", missed, max_missed)
                if missed >= max_missed:
                    self._logger.warning("Max pong timeouts reached — closing connection")
                    self._set_connected(False)
                    if self._ws is not None:
                        # Closing the socket ends the recv loop in
                        # _connect_and_read, which triggers reconnection.
                        await self._ws.close()
                    return
            except asyncio.CancelledError:
                return
            except Exception as exc:
                self._logger.debug("Ping failed: %s", exc)
                return

    # ------------------------------------------------------------------
    # Message handling
    # ------------------------------------------------------------------

    async def _handle_message(self, raw_message: Union[str, bytes]) -> None:
        """Decode one frame, parse JSON, and dispatch to the message handler.

        All failures are reported through on_error rather than raised, so a
        bad message never kills the read loop.
        """
        try:
            if isinstance(raw_message, bytes):
                message_str = raw_message.decode("utf-8")
            else:
                message_str = str(raw_message)

            try:
                data = json.loads(message_str)
            except json.JSONDecodeError as err:
                if self._on_error is not None:
                    self._on_error(Exception(f"Failed to parse datastream message: {err}"))
                return

            message = DataStreamResp(
                data=data.get("data"),
                entity_type=data.get("entity_type", ""),
                message_type=data.get("message_type", ""),
            )

            try:
                await self._message_handler(message)
            except Exception as err:
                if self._on_error is not None:
                    self._on_error(Exception(f"Message handler error: {err}"))

        except Exception as err:
            if self._on_error is not None:
                self._on_error(err)

    # ------------------------------------------------------------------
    # State management
    # ------------------------------------------------------------------

    def _set_connected(self, connected: bool) -> None:
        """Update the connected flag and fire transition callbacks."""
        was_connected = self._connected
        self._connected = connected

        if not connected:
            # BUGFIX: route through _set_ready so on_not_ready fires when the
            # connection drops.  Previously self._ready was cleared directly,
            # so the later _set_ready(False) saw no transition and the
            # callback was silently skipped on every disconnect.
            self._set_ready(False)

        if was_connected != connected:
            self._logger.debug(f"Connection state changed: {connected}")
            if connected:
                if self._on_connected is not None:
                    self._on_connected()
            else:
                if self._on_disconnected is not None:
                    self._on_disconnected()

    def _set_ready(self, ready: bool) -> None:
        """Update the ready flag and fire transition callbacks (idempotent)."""
        was_ready = self._ready
        self._ready = ready

        if was_ready != ready:
            self._logger.debug(f"Ready state changed: {ready}")
            if ready:
                if self._on_ready is not None:
                    self._on_ready()
            else:
                if self._on_not_ready is not None:
                    self._on_not_ready()

    def _calculate_backoff_delay(self, attempt: int) -> float:
        """Exponential backoff with jitter, capped at max_reconnect_delay."""
        jitter = random.uniform(0, self._min_reconnect_delay)
        delay = (2 ** (attempt - 1)) * self._min_reconnect_delay + jitter
        return min(delay, self._max_reconnect_delay + jitter)


# ===========================================================================
# tests/custom/test_cache.py — import header (next file in this chunk)
# ===========================================================================
import time
import threading
import unittest

from schematic.cache import LocalCache
    def test_cache_size_limit(self):
        self.assertEqual(val3, "value3")


class TestLocalCacheGetSet(unittest.TestCase):
    """Corresponds to Go TestLocalCache_Get_Set."""

    def test_basic_set_and_get(self):
        cache = LocalCache(max_size=10, ttl=5000)
        cache.set("key1", "value1")
        self.assertEqual(cache.get("key1"), "value1")

    def test_get_nonexistent_key(self):
        cache = LocalCache(max_size=10, ttl=5000)
        self.assertIsNone(cache.get("missing"))

    def test_overwrite_existing_key(self):
        cache = LocalCache(max_size=10, ttl=5000)
        cache.set("key1", "value1")
        cache.set("key1", "value2")
        self.assertEqual(cache.get("key1"), "value2")

    def test_set_with_custom_ttl(self):
        # Per-entry TTL override wins over the cache-level default (5000ms).
        cache = LocalCache(max_size=10, ttl=5000)
        cache.set("key1", "value1", ttl_override=1)  # 1ms TTL
        time.sleep(0.05)
        self.assertIsNone(cache.get("key1"))


class TestLocalCacheDelete(unittest.TestCase):
    """Corresponds to Go TestLocalCache_Delete."""

    def test_delete_existing_key(self):
        cache = LocalCache(max_size=10, ttl=5000)
        cache.set("key1", "value1")
        cache.set("key2", "value2")
        # Remove key1 by overwriting the cache internals
        # NOTE(review): reaches into LocalCache's private `cache` dict —
        # confirm there is no public delete API this should use instead.
        del cache.cache["key1"]
        self.assertIsNone(cache.get("key1"))
        self.assertEqual(cache.get("key2"), "value2")

    def test_delete_nonexistent_key(self):
        """Deleting a nonexistent key should not raise."""
        cache = LocalCache(max_size=10, ttl=5000)
        # Accessing a missing key via cache internals should not error
        cache.cache.pop("missing", None)


class TestLocalCacheLRU(unittest.TestCase):
    """Corresponds to Go TestLocalCache_LRU."""

    def test_lru_eviction(self):
        cache = LocalCache(max_size=3, ttl=5000)
        cache.set("key1", "value1")
        cache.set("key2", "value2")
        cache.set("key3", "value3")

        # Access key1 to make it most recently used
        cache.get("key1")

        # Adding key4 should evict key2 (least recently used)
        cache.set("key4", "value4")

        self.assertEqual(cache.get("key1"), "value1")
        self.assertIsNone(cache.get("key2"))
        self.assertEqual(cache.get("key3"), "value3")
        self.assertEqual(cache.get("key4"), "value4")


class TestLocalCacheExpiration(unittest.TestCase):
    """Corresponds to Go TestLocalCache_Expiration."""

    def test_items_expire_after_ttl(self):
        cache = LocalCache(max_size=10, ttl=50)  # 50ms TTL
        cache.set("key1", "value1")

        # Immediately available
        self.assertEqual(cache.get("key1"), "value1")

        # Gone after TTL
        time.sleep(0.1)
        self.assertIsNone(cache.get("key1"))

    def test_new_items_work_after_expiration(self):
        cache = LocalCache(max_size=10, ttl=50)
        cache.set("key1", "value1")
        time.sleep(0.1)
        self.assertIsNone(cache.get("key1"))

        # New items should still work
        cache.set("key2", "value2")
        self.assertEqual(cache.get("key2"), "value2")


class TestLocalCacheCleanExpired(unittest.TestCase):
    """Corresponds to Go TestLocalCache_CleanupRoutine."""

    def test_clean_expired_removes_stale_items(self):
        cache = LocalCache(max_size=10, ttl=50)
        for i in range(5):
            cache.set(f"key{i}", f"value{i}")

        time.sleep(0.1)
        cache.clean_expired()

        for i in range(5):
            self.assertIsNone(cache.get(f"key{i}"))

    def test_clean_expired_keeps_valid_items(self):
        cache = LocalCache(max_size=10, ttl=5000)
        cache.set("key1", "value1")
        cache.clean_expired()
        self.assertEqual(cache.get("key1"), "value1")


class TestLocalCacheZeroSize(unittest.TestCase):
    """Corresponds to Go TestLocalCache_NilSafety (zero-size cache acts as disabled)."""

    def test_get_returns_none(self):
        cache = LocalCache(max_size=0, ttl=5000)
        self.assertIsNone(cache.get("key1"))

    def test_set_is_noop(self):
        # A zero-capacity cache silently drops writes rather than raising.
        cache = LocalCache(max_size=0, ttl=5000)
        cache.set("key1", "value1")
        self.assertIsNone(cache.get("key1"))


class TestLocalCacheDefaults(unittest.TestCase):
    """Corresponds to Go TestLocalCache_DefaultCache."""

    def test_default_cache_has_correct_defaults(self):
        from schematic.cache.local import DEFAULT_CACHE_SIZE, DEFAULT_CACHE_TTL

        cache = LocalCache()
        self.assertEqual(cache.max_size, DEFAULT_CACHE_SIZE)
        self.assertEqual(cache.ttl, DEFAULT_CACHE_TTL)


class TestLocalCacheDifferentTypes(unittest.TestCase):
    """Corresponds to Go TestLocalCache_DifferentTypes."""

    def test_string_cache(self):
        cache = LocalCache(max_size=10, ttl=5000)
        cache.set("key", "hello")
        self.assertEqual(cache.get("key"), "hello")

    def test_int_cache(self):
        cache = LocalCache(max_size=10, ttl=5000)
        cache.set("key", 42)
        self.assertEqual(cache.get("key"), 42)

    def test_dict_cache(self):
        cache = LocalCache(max_size=10, ttl=5000)
        val = {"name": "test", "items": [1, 2, 3]}
        cache.set("key", val)
        self.assertEqual(cache.get("key"), val)


class TestLocalCacheConcurrency(unittest.TestCase):
    """Corresponds to Go TestLocalCache_Concurrency and TestLocalCache_ConcurrentSafe."""

    def test_concurrent_reads_and_writes(self):
        cache = LocalCache(max_size=100, ttl=5000)
        errors = []

        def worker(worker_id):
            # Mix of own-key writes/reads plus reads of a neighbouring
            # worker's keys, to exercise cross-thread access patterns.
            try:
                for i in range(20):
                    key = f"key-{worker_id}-{i}"
                    cache.set(key, f"value-{i}")
                    cache.get(key)
                    cache.get(f"key-{(worker_id + 1) % 5}-{i}")
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        self.assertEqual(len(errors), 0, f"Concurrent access errors: {errors}")


class TestLocalCacheEdgeCases(unittest.TestCase):
    """Corresponds to Go TestLocalCache_EdgeCases."""

    def test_very_short_ttl(self):
        cache = LocalCache(max_size=10, ttl=1)  # 1ms
        cache.set("key", "value")
        time.sleep(0.01)
        self.assertIsNone(cache.get("key"))

    def test_very_large_ttl(self):
        cache = LocalCache(max_size=10, ttl=100 * 365 * 24 * 60 * 60 * 1000)  # 100 years
        cache.set("key", "value")
        self.assertEqual(cache.get("key"), "value")

    def test_zero_ttl(self):
        """Zero TTL means items expire immediately."""
        cache = LocalCache(max_size=10, ttl=0)
        cache.set("key", "value")
        # With 0 TTL, expiration = now, so item is already expired
        self.assertIsNone(cache.get("key"))

    def test_ttl_override_shorter_than_default(self):
        cache = LocalCache(max_size=10, ttl=5000)
        cache.set("key", "value", ttl_override=1)  # 1ms override
        time.sleep(0.05)
        self.assertIsNone(cache.get("key"))

    def test_max_size_enforcement(self):
        cache = LocalCache(max_size=3, ttl=5000)
        for i in range(10):
            cache.set(f"key{i}", f"value{i}")
        # Only the last 3 should remain
        count = sum(1 for i in range(10) if cache.get(f"key{i}") is not None)
        self.assertEqual(count, 3)


if __name__ == "__main__":
    unittest.main()


# ===========================================================================
# tests/custom/test_cache_key.py (new file in this chunk)
# ===========================================================================
"""Tests for flag check cache key generation.

Corresponds to Go flags/flags_test.go TestFlagCheckCacheKey.
"""

import unittest

from schematic.client import _build_cache_key


class TestBuildCacheKey(unittest.TestCase):
    """Corresponds to Go TestFlagCheckCacheKey."""

    def test_empty_context_and_flag_key(self):
        result = _build_cache_key("")
        self.assertEqual(result, "")

    def test_flag_key_only(self):
        result = _build_cache_key("feature_flag_1")
        self.assertEqual(result, "feature_flag_1")

    def test_with_company_and_user(self):
        result = _build_cache_key(
            "feature_flag_1",
            company={"id": "123", "name": "ACME Inc."},
            user={"id": "456", "email": "john@example.com"},
        )
        # Should include flag key, company, and user in the cache key
        self.assertIn("feature_flag_1", result)
        self.assertIn("123", result)
        self.assertIn("ACME Inc.", result)
        self.assertIn("456", result)
        self.assertIn("john@example.com", result)

    def test_with_company_only(self):
        result = _build_cache_key(
            "feature_flag_2",
            company={"id": "789", "name": "XYZ Corp."},
        )
        self.assertIn("feature_flag_2", result)
        self.assertIn("789", result)
        self.assertIn("XYZ Corp.", result)

    def test_with_user_only(self):
        result = _build_cache_key(
            "feature_flag_3",
            user={"id": "abc", "email": "jane@example.com"},
        )
        self.assertIn("feature_flag_3", result)
        self.assertIn("abc", result)
        self.assertIn("jane@example.com", result)

    def test_different_contexts_produce_different_keys(self):
        """Different company/user contexts should produce different cache keys."""
        key1 = _build_cache_key("flag", company={"id": "comp-1"})
        key2 = _build_cache_key("flag", company={"id": "comp-2"})
        key3 = _build_cache_key("flag", user={"id": "user-1"})
        self.assertNotEqual(key1, key2)
        self.assertNotEqual(key1, key3)

    def test_same_context_produces_same_key(self):
        """Same inputs should always produce the same cache key."""
        key1 = _build_cache_key("flag", company={"id": "comp-1"}, user={"id": "user-1"})
        key2 = _build_cache_key("flag", company={"id": "comp-1"}, user={"id": "user-1"})
        self.assertEqual(key1, key2)

    def test_deterministic_regardless_of_insertion_order(self):
        """Dict insertion order must not affect the cache key."""
        key1 = _build_cache_key("flag", company={"id": "comp-1", "name": "ACME"})
        key2 = _build_cache_key("flag", company={"name": "ACME", "id": "comp-1"})
        self.assertEqual(key1, key2)

    def test_deterministic_with_multiple_user_keys(self):
        """Multiple user keys in different order produce the same cache key."""
        key1 = _build_cache_key("flag", user={"email": "a@b.com", "id": "u1", "phone": "555"})
        key2 = _build_cache_key("flag", user={"phone": "555", "email": "a@b.com", "id": "u1"})
        self.assertEqual(key1, key2)

    def test_exact_format(self):
        """Verify the canonical cache key format.

        Keys within each context are sorted and joined with ';'; contexts are
        joined with ':' — pins the exact wire format of the cache key.
        """
        result = _build_cache_key(
            "my-flag",
            company={"id": "co1"},
            user={"email": "a@b.com", "id": "u1"},
        )
        self.assertEqual(result, "my-flag:company:id=co1:user:email=a@b.com;id=u1")


if __name__ == "__main__":
    unittest.main()
a/tests/custom/test_client.py b/tests/custom/test_client.py index c3fdec5..e646693 100644 --- a/tests/custom/test_client.py +++ b/tests/custom/test_client.py @@ -1,15 +1,19 @@ +import time import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from httpx import AsyncClient, Client +from schematic.cache import LocalCache from schematic.client import ( AsyncSchematic, AsyncSchematicConfig, + CheckFlagOptions, Schematic, SchematicConfig, ) +from schematic.types import CheckFlagResponseData, FeatureEntitlement class TestSchematic(unittest.TestCase): @@ -44,6 +48,89 @@ def test_check_flag_online(self): ) self.assertTrue(result) + def test_check_flag_with_entitlement_offline(self): + self.schematic.offline = True + self.schematic.flag_defaults = {"test_flag": True} + result = self.schematic.check_flag_with_entitlement( + "test_flag", + company={"id": "company_id"}, + user={"id": "user_id"}, + ) + self.assertIsInstance(result, CheckFlagResponseData) + self.assertTrue(result.value) + self.assertEqual(result.flag, "test_flag") + self.assertEqual(result.reason, "flag default") + + def test_check_flag_with_entitlement_online(self): + self.schematic.offline = False + mock_data = CheckFlagResponseData( + value=True, + company_id="comp_123", + entitlement=None, + error=None, + flag="test_flag", + flag_id="flag_123", + reason="rule_match", + rule_id="rule_123", + rule_type="override", + user_id="user_123", + ) + self.schematic.features.check_flag = MagicMock( + return_value=MagicMock(data=mock_data) + ) + result = self.schematic.check_flag_with_entitlement( + "test_flag", + company={"id": "company_id"}, + user={"id": "user_id"}, + ) + self.assertIsInstance(result, CheckFlagResponseData) + self.assertTrue(result.value) + self.assertEqual(result.company_id, "comp_123") + self.assertEqual(result.reason, "rule_match") + self.assertEqual(result.rule_id, "rule_123") + + def 
test_check_flag_with_options_default_value(self): + self.schematic.offline = True + options = CheckFlagOptions(default_value=True) + result = self.schematic.check_flag("missing_flag", options=options) + self.assertTrue(result) + + def test_check_flag_with_options_callable_default(self): + self.schematic.offline = True + options = CheckFlagOptions(default_value=lambda: True) + result = self.schematic.check_flag("missing_flag", options=options) + self.assertTrue(result) + + def test_check_flag_caches_full_response(self): + """Verify that cache stores the full response, not just a bool.""" + self.schematic.offline = False + mock_data = CheckFlagResponseData( + value=True, + company_id="comp_123", + entitlement=None, + error=None, + flag="test_flag", + flag_id="flag_123", + reason="rule_match", + rule_id="rule_123", + rule_type=None, + user_id=None, + ) + self.schematic.features.check_flag = MagicMock( + return_value=MagicMock(data=mock_data) + ) + + # First call populates cache + result1 = self.schematic.check_flag_with_entitlement("test_flag") + self.assertEqual(result1.company_id, "comp_123") + + # Second call should hit cache + result2 = self.schematic.check_flag_with_entitlement("test_flag") + self.assertEqual(result2.company_id, "comp_123") + + # API should only have been called once + self.schematic.features.check_flag.assert_called_once() + def test_identify(self): with patch.object(self.schematic.event_buffer, "push") as mock_push: self.schematic.identify( @@ -61,6 +148,261 @@ def test_track(self): ) mock_push.assert_called_once() + def test_track_with_quantity(self): + with patch.object(self.schematic.event_buffer, "push") as mock_push: + self.schematic.track( + event="api-call", + company={"id": "company_id"}, + quantity=5, + ) + mock_push.assert_called_once() + + def test_check_flag_with_no_cache(self): + """Verify that when cache_providers is empty, every call hits the API.""" + config = SchematicConfig( + event_buffer_period=1, + logger=MagicMock(), + 
httpx_client=MagicMock(spec=Client), + cache_providers=[], + ) + client = Schematic("api_key", config) + try: + mock_data = CheckFlagResponseData( + value=True, + flag="test_flag", + reason="match", + ) + client.features.check_flag = MagicMock( + return_value=MagicMock(data=mock_data) + ) + + result1 = client.check_flag("test_flag") + result2 = client.check_flag("test_flag") + self.assertTrue(result1) + self.assertTrue(result2) + self.assertEqual(client.features.check_flag.call_count, 2) + finally: + client.event_buffer.stop() + + def test_check_flag_with_cache_ttl_expiry(self): + """Verify cache expires after TTL.""" + short_ttl_cache = LocalCache(max_size=1000, ttl=50) # 50ms TTL + config = SchematicConfig( + event_buffer_period=1, + logger=MagicMock(), + httpx_client=MagicMock(spec=Client), + cache_providers=[short_ttl_cache], + ) + client = Schematic("api_key", config) + try: + mock_data = CheckFlagResponseData( + value=True, + flag="test_flag", + reason="match", + ) + client.features.check_flag = MagicMock( + return_value=MagicMock(data=mock_data) + ) + + # First call hits API and caches + self.assertTrue(client.check_flag("test_flag")) + # Second call should hit cache + self.assertTrue(client.check_flag("test_flag")) + self.assertEqual(client.features.check_flag.call_count, 1) + + # Wait for TTL to expire + time.sleep(0.1) + + # Third call should miss cache and hit API again + self.assertTrue(client.check_flag("test_flag")) + self.assertEqual(client.features.check_flag.call_count, 2) + finally: + client.event_buffer.stop() + + def test_check_flag_returns_default_on_api_error(self): + """Verify that API errors return the flag default value.""" + self.schematic.flag_defaults = {"test_flag": True} + self.schematic.flag_check_cache_providers = [] + self.schematic.features.check_flag = MagicMock( + side_effect=Exception("api error") + ) + result = self.schematic.check_flag("test_flag") + self.assertTrue(result) + + def 
test_check_flag_returns_false_on_error_no_default(self): + """Verify that API errors with no default return False.""" + self.schematic.flag_check_cache_providers = [] + self.schematic.features.check_flag = MagicMock( + side_effect=Exception("connection refused") + ) + result = self.schematic.check_flag("test_flag") + self.assertFalse(result) + + def test_check_flag_offline_no_default(self): + """Verify that offline mode with no default returns False.""" + self.schematic.offline = True + result = self.schematic.check_flag("test_flag") + self.assertFalse(result) + + def test_check_flag_with_company_context_only(self): + """Verify flag check passes company context correctly.""" + self.schematic.flag_check_cache_providers = [] + mock_data = CheckFlagResponseData( + value=True, + flag="test_flag", + reason="match", + ) + self.schematic.features.check_flag = MagicMock( + return_value=MagicMock(data=mock_data) + ) + result = self.schematic.check_flag( + "test_flag", + company={"company-id": "comp-123"}, + ) + self.assertTrue(result) + call_kwargs = self.schematic.features.check_flag.call_args + self.assertEqual(call_kwargs.kwargs["company"], {"company-id": "comp-123"}) + self.assertIsNone(call_kwargs.kwargs["user"]) + + def test_check_flag_with_user_context_only(self): + """Verify flag check passes user context correctly.""" + self.schematic.flag_check_cache_providers = [] + mock_data = CheckFlagResponseData( + value=True, + flag="test_flag", + reason="match", + ) + self.schematic.features.check_flag = MagicMock( + return_value=MagicMock(data=mock_data) + ) + result = self.schematic.check_flag( + "test_flag", + user={"user-id": "user-123"}, + ) + self.assertTrue(result) + call_kwargs = self.schematic.features.check_flag.call_args + self.assertIsNone(call_kwargs.kwargs["company"]) + self.assertEqual(call_kwargs.kwargs["user"], {"user-id": "user-123"}) + + def test_check_flag_with_both_contexts(self): + """Verify flag check passes both company and user context.""" + 
self.schematic.flag_check_cache_providers = [] + mock_data = CheckFlagResponseData( + value=True, + flag="test_flag", + reason="match", + ) + self.schematic.features.check_flag = MagicMock( + return_value=MagicMock(data=mock_data) + ) + result = self.schematic.check_flag( + "test_flag", + company={"company-id": "comp-123"}, + user={"user-id": "user-123"}, + ) + self.assertTrue(result) + call_kwargs = self.schematic.features.check_flag.call_args + self.assertEqual(call_kwargs.kwargs["company"], {"company-id": "comp-123"}) + self.assertEqual(call_kwargs.kwargs["user"], {"user-id": "user-123"}) + + def test_check_flag_with_entitlement_nil_entitlement(self): + """Verify handling of API response with no entitlement.""" + self.schematic.flag_check_cache_providers = [] + mock_data = CheckFlagResponseData( + value=False, + flag="test_flag", + reason="no matching rules", + entitlement=None, + rule_type=None, + ) + self.schematic.features.check_flag = MagicMock( + return_value=MagicMock(data=mock_data) + ) + result = self.schematic.check_flag_with_entitlement("test_flag") + self.assertIsInstance(result, CheckFlagResponseData) + self.assertFalse(result.value) + self.assertIsNone(result.entitlement) + self.assertIsNone(result.rule_type) + + def test_check_flag_with_entitlement_cache_preserves_entitlement(self): + """Verify cache hit preserves full entitlement data.""" + entitlement = FeatureEntitlement( + feature_id="feat-123", + feature_key="test-feature", + value_type="numeric", + allocation=100, + usage=50, + ) + mock_data = CheckFlagResponseData( + value=True, + flag="test_flag", + reason="entitlement matched", + company_id="comp-123", + flag_id="flag-456", + rule_id="rule-789", + rule_type="plan_entitlement", + user_id="user-321", + entitlement=entitlement, + ) + self.schematic.features.check_flag = MagicMock( + return_value=MagicMock(data=mock_data) + ) + + # First call hits API + result1 = self.schematic.check_flag_with_entitlement("test_flag") + 
self.assertTrue(result1.value) + self.assertIsNotNone(result1.entitlement) + self.assertEqual(result1.entitlement.feature_id, "feat-123") + + # Second call should hit cache and preserve entitlement + result2 = self.schematic.check_flag_with_entitlement("test_flag") + self.assertTrue(result2.value) + self.assertEqual(result2.reason, "entitlement matched") + self.assertEqual(result2.company_id, "comp-123") + self.assertEqual(result2.flag_id, "flag-456") + self.assertEqual(result2.rule_id, "rule-789") + self.assertEqual(result2.rule_type, "plan_entitlement") + self.assertEqual(result2.user_id, "user-321") + self.assertIsNotNone(result2.entitlement) + self.assertEqual(result2.entitlement.feature_id, "feat-123") + self.assertEqual(result2.entitlement.feature_key, "test-feature") + self.assertEqual(result2.entitlement.allocation, 100) + + # API should only have been called once + self.schematic.features.check_flag.assert_called_once() + + def test_check_flag_with_entitlement_reason_strings(self): + """Corresponds to Go TestCheckFlagWithEntitlement_ReasonStrings. + + Verify that reason, rule_type, and other string fields are preserved. 
+ """ + self.schematic.flag_check_cache_providers = [] + mock_data = CheckFlagResponseData( + value=True, + flag="test_flag", + reason="match", + company_id="comp-123", + flag_id="flag-456", + rule_id="rule-789", + rule_type="override", + entitlement=None, + ) + self.schematic.features.check_flag = MagicMock( + return_value=MagicMock(data=mock_data) + ) + result = self.schematic.check_flag_with_entitlement( + "test_flag", + company={"company-id": "comp-123"}, + ) + self.assertIsInstance(result, CheckFlagResponseData) + self.assertTrue(result.value) + self.assertEqual(result.reason, "match") + self.assertEqual(result.company_id, "comp-123") + self.assertEqual(result.flag_id, "flag-456") + self.assertEqual(result.rule_id, "rule-789") + self.assertEqual(result.rule_type, "override") + self.assertIsNone(result.entitlement) + def tearDown(self): self.schematic.event_buffer.stop() @@ -102,6 +444,50 @@ async def test_check_flag_online(self): ) assert result + async def test_check_flag_with_entitlement_offline(self): + self.async_schematic.offline = True + self.async_schematic.flag_defaults = {"test_flag": True} + result = await self.async_schematic.check_flag_with_entitlement( + "test_flag", + company={"id": "company_id"}, + ) + assert isinstance(result, CheckFlagResponseData) + assert result.value is True + assert result.flag == "test_flag" + assert result.reason == "flag default" + + async def test_check_flag_with_entitlement_online(self): + self.async_schematic.offline = False + mock_data = CheckFlagResponseData( + value=True, + company_id="comp_123", + entitlement=None, + error=None, + flag="test_flag", + flag_id="flag_123", + reason="rule_match", + rule_id="rule_123", + rule_type="override", + user_id="user_123", + ) + self.async_schematic.features.check_flag = AsyncMock( + return_value=MagicMock(data=mock_data) + ) + result = await self.async_schematic.check_flag_with_entitlement( + "test_flag", + company={"id": "company_id"}, + ) + assert isinstance(result, 
CheckFlagResponseData) + assert result.value is True + assert result.company_id == "comp_123" + assert result.reason == "rule_match" + + async def test_check_flag_with_options(self): + self.async_schematic.offline = True + options = CheckFlagOptions(default_value=True) + result = await self.async_schematic.check_flag("missing_flag", options=options) + assert result is True + async def test_identify(self): with patch.object(self.async_schematic.event_buffer, "push") as mock_push: await self.async_schematic.identify( @@ -119,6 +505,251 @@ async def test_track(self): ) mock_push.assert_called_once() + async def test_check_flag_with_no_cache(self): + """Verify that when cache_providers is empty, every call hits the API.""" + config = AsyncSchematicConfig( + event_buffer_period=1, + logger=MagicMock(), + httpx_client=MagicMock(spec=AsyncClient), + cache_providers=[], + ) + client = AsyncSchematic("test_key", config) + try: + mock_data = CheckFlagResponseData( + value=True, + flag="test_flag", + reason="match", + ) + client.features.check_flag = AsyncMock( + return_value=MagicMock(data=mock_data) + ) + + result1 = await client.check_flag("test_flag") + result2 = await client.check_flag("test_flag") + assert result1 is True + assert result2 is True + assert client.features.check_flag.call_count == 2 + finally: + await client.event_buffer.stop() + + async def test_check_flag_returns_default_on_api_error(self): + """Verify that API errors return the flag default value.""" + self.async_schematic.flag_defaults = {"test_flag": True} + self.async_schematic.flag_check_cache_providers = [] + self.async_schematic.features.check_flag = AsyncMock( + side_effect=Exception("api error") + ) + result = await self.async_schematic.check_flag("test_flag") + assert result is True + + async def test_check_flag_returns_false_on_error_no_default(self): + """Verify that API errors with no default return False.""" + self.async_schematic.flag_check_cache_providers = [] + 
self.async_schematic.features.check_flag = AsyncMock( + side_effect=Exception("connection refused") + ) + result = await self.async_schematic.check_flag("test_flag") + assert result is False + + async def test_check_flag_offline_no_default(self): + """Verify that offline mode with no default returns False.""" + self.async_schematic.offline = True + result = await self.async_schematic.check_flag("test_flag") + assert result is False + + async def test_check_flag_with_company_context_only(self): + """Verify flag check passes company context correctly.""" + self.async_schematic.flag_check_cache_providers = [] + mock_data = CheckFlagResponseData( + value=True, + flag="test_flag", + reason="match", + ) + self.async_schematic.features.check_flag = AsyncMock( + return_value=MagicMock(data=mock_data) + ) + result = await self.async_schematic.check_flag( + "test_flag", + company={"company-id": "comp-123"}, + ) + assert result is True + call_kwargs = self.async_schematic.features.check_flag.call_args + assert call_kwargs.kwargs["company"] == {"company-id": "comp-123"} + assert call_kwargs.kwargs["user"] is None + + async def test_check_flag_with_user_context_only(self): + """Verify flag check passes user context correctly.""" + self.async_schematic.flag_check_cache_providers = [] + mock_data = CheckFlagResponseData( + value=True, + flag="test_flag", + reason="match", + ) + self.async_schematic.features.check_flag = AsyncMock( + return_value=MagicMock(data=mock_data) + ) + result = await self.async_schematic.check_flag( + "test_flag", + user={"user-id": "user-123"}, + ) + assert result is True + call_kwargs = self.async_schematic.features.check_flag.call_args + assert call_kwargs.kwargs["company"] is None + assert call_kwargs.kwargs["user"] == {"user-id": "user-123"} + + async def test_check_flag_with_both_contexts(self): + """Verify flag check passes both company and user context.""" + self.async_schematic.flag_check_cache_providers = [] + mock_data = CheckFlagResponseData( 
+ value=True, + flag="test_flag", + reason="match", + ) + self.async_schematic.features.check_flag = AsyncMock( + return_value=MagicMock(data=mock_data) + ) + result = await self.async_schematic.check_flag( + "test_flag", + company={"company-id": "comp-123"}, + user={"user-id": "user-123"}, + ) + assert result is True + call_kwargs = self.async_schematic.features.check_flag.call_args + assert call_kwargs.kwargs["company"] == {"company-id": "comp-123"} + assert call_kwargs.kwargs["user"] == {"user-id": "user-123"} + + async def test_check_flag_with_entitlement_nil_entitlement(self): + """Verify handling of API response with no entitlement.""" + self.async_schematic.flag_check_cache_providers = [] + mock_data = CheckFlagResponseData( + value=False, + flag="test_flag", + reason="no matching rules", + entitlement=None, + rule_type=None, + ) + self.async_schematic.features.check_flag = AsyncMock( + return_value=MagicMock(data=mock_data) + ) + result = await self.async_schematic.check_flag_with_entitlement("test_flag") + assert isinstance(result, CheckFlagResponseData) + assert result.value is False + assert result.entitlement is None + assert result.rule_type is None + + async def test_check_flag_with_entitlement_cache_preserves_entitlement(self): + """Verify cache hit preserves full entitlement data.""" + entitlement = FeatureEntitlement( + feature_id="feat-123", + feature_key="test-feature", + value_type="numeric", + allocation=100, + usage=50, + ) + mock_data = CheckFlagResponseData( + value=True, + flag="test_flag", + reason="entitlement matched", + company_id="comp-123", + flag_id="flag-456", + rule_id="rule-789", + rule_type="plan_entitlement", + user_id="user-321", + entitlement=entitlement, + ) + self.async_schematic.features.check_flag = AsyncMock( + return_value=MagicMock(data=mock_data) + ) + + # First call hits API + result1 = await self.async_schematic.check_flag_with_entitlement("test_flag") + assert result1.value is True + assert result1.entitlement is 
not None + assert result1.entitlement.feature_id == "feat-123" + + # Second call should hit cache and preserve entitlement + result2 = await self.async_schematic.check_flag_with_entitlement("test_flag") + assert result2.value is True + assert result2.reason == "entitlement matched" + assert result2.company_id == "comp-123" + assert result2.flag_id == "flag-456" + assert result2.rule_id == "rule-789" + assert result2.rule_type == "plan_entitlement" + assert result2.user_id == "user-321" + assert result2.entitlement is not None + assert result2.entitlement.feature_id == "feat-123" + assert result2.entitlement.feature_key == "test-feature" + assert result2.entitlement.allocation == 100 + + # API should only have been called once + self.async_schematic.features.check_flag.assert_called_once() + + async def test_check_flag_with_entitlement_reason_strings(self): + """Corresponds to Go TestCheckFlagWithEntitlement_ReasonStrings (async).""" + self.async_schematic.flag_check_cache_providers = [] + mock_data = CheckFlagResponseData( + value=True, + flag="test_flag", + reason="match", + company_id="comp-123", + flag_id="flag-456", + rule_id="rule-789", + rule_type="override", + entitlement=None, + ) + self.async_schematic.features.check_flag = AsyncMock( + return_value=MagicMock(data=mock_data) + ) + result = await self.async_schematic.check_flag_with_entitlement( + "test_flag", + company={"company-id": "comp-123"}, + ) + assert isinstance(result, CheckFlagResponseData) + assert result.value is True + assert result.reason == "match" + assert result.company_id == "comp-123" + assert result.flag_id == "flag-456" + assert result.rule_id == "rule-789" + assert result.rule_type == "override" + assert result.entitlement is None + + async def test_check_flag_datastream_fallback_to_api(self): + """Corresponds to Go TestCheckFlagDatastreamFallbackToAPI. + + When datastream is configured but fails, should fall back to API. 
+ """ + config = AsyncSchematicConfig( + logger=MagicMock(), + httpx_client=MagicMock(spec=AsyncClient), + event_buffer_period=1, + use_datastream=True, + ) + client = AsyncSchematic("test_key", config) + try: + # Mock the datastream client to raise an error + mock_ds = MagicMock() + mock_ds.check_flag = AsyncMock(side_effect=Exception("datastream failed")) + client._datastream_client = mock_ds + + # Mock the API to return a valid response + mock_data = CheckFlagResponseData( + value=True, + flag="test_flag", + reason="match", + ) + client.features.check_flag = AsyncMock( + return_value=MagicMock(data=mock_data) + ) + client.flag_check_cache_providers = [] + + result = await client.check_flag("test_flag", company={"id": "test-company"}) + assert result is True + + # API should have been called as fallback + client.features.check_flag.assert_called_once() + finally: + await client.event_buffer.stop() + if __name__ == "__main__": unittest.main() diff --git a/tests/custom/test_event_buffer.py b/tests/custom/test_event_buffer.py index bd71ff2..6dffd1a 100644 --- a/tests/custom/test_event_buffer.py +++ b/tests/custom/test_event_buffer.py @@ -1,3 +1,5 @@ +import threading +import time import unittest from unittest.mock import MagicMock, patch, call @@ -52,6 +54,81 @@ def test_stop(self): self.event_buffer.stop() self.assertTrue(self.event_buffer.shutdown.is_set()) + def test_shutdown_flushes_remaining(self): + """Corresponds to Go TestEventBuffer_ShutdownFlushesRemaining. + + Verify that stop() flushes buffered events even if batch isn't full. 
+ """ + mock_api = MagicMock() + mock_logger = MagicMock() + buffer = EventBuffer( + events_api=mock_api, + logger=mock_logger, + period=10, # Long period so periodic flush won't trigger + max_events=100, # Large batch so auto-flush won't trigger + ) + + # Push several events (fewer than max_events so no auto-flush) + for i in range(5): + event = MagicMock(spec=CreateEventRequestBody) + buffer.push(event) + + # No flush should have happened yet + mock_api.create_event_batch.assert_not_called() + + # Stop the buffer, which should flush remaining events + buffer.stop() + + # Verify all events were flushed + mock_api.create_event_batch.assert_called_once() + flushed_events = mock_api.create_event_batch.call_args.kwargs["events"] + self.assertEqual(len(flushed_events), 5) + + def test_concurrent_push(self): + """Corresponds to Go TestEventBuffer_ConcurrentPush. + + Verify no events are lost when pushing from multiple threads. + """ + mock_api = MagicMock() + mock_logger = MagicMock() + buffer = EventBuffer( + events_api=mock_api, + logger=mock_logger, + period=10, # Long period to avoid periodic flush during test + max_events=1000, # Large batch to avoid auto-flush + ) + + num_threads = 10 + events_per_thread = 20 + total_expected = num_threads * events_per_thread + errors = [] + + def worker(): + try: + for _ in range(events_per_thread): + event = MagicMock(spec=CreateEventRequestBody) + buffer.push(event) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=worker) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(len(errors), 0, f"Concurrent push errors: {errors}") + + # Stop to flush remaining events + buffer.stop() + + # Count total events sent + total_sent = sum( + len(c.kwargs["events"]) + for c in mock_api.create_event_batch.call_args_list + ) + self.assertEqual(total_sent, total_expected) + @pytest.mark.asyncio class TestAsyncEventBuffer: @@ -157,6 +234,35 @@ async def 
test_stop(self): assert buffer.shutdown_event.is_set() assert buffer.stopped is True + async def test_shutdown_flushes_remaining(self): + """Corresponds to Go TestEventBuffer_ShutdownFlushesRemaining (async).""" + mock_api = MagicMock() + mock_logger = MagicMock() + task_mock = MagicMock() + + with patch('asyncio.create_task', return_value=task_mock): + buffer = AsyncEventBuffer( + events_api=mock_api, + logger=mock_logger, + period=10, + max_events=100, + max_retries=0, + ) + + # Push events (fewer than max_events) + for _ in range(5): + event = MagicMock(spec=CreateEventRequestBody) + await buffer.push(event) + + mock_api.create_event_batch.assert_not_called() + + # Stop should flush remaining events + await buffer.stop() + + mock_api.create_event_batch.assert_called_once() + flushed = mock_api.create_event_batch.call_args.kwargs["events"] + assert len(flushed) == 5 + if __name__ == "__main__": unittest.main() diff --git a/tests/datastream/__init__.py b/tests/datastream/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/datastream/test_cache.py b/tests/datastream/test_cache.py new file mode 100644 index 0000000..f95f54a --- /dev/null +++ b/tests/datastream/test_cache.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import asyncio +import time + +import pytest + +from schematic.cache import AsyncCacheProvider as CacheProvider, AsyncLocalCache as LocalCache + + +class TestLocalCacheGet: + async def test_returns_none_for_missing_key(self) -> None: + cache: LocalCache[str] = LocalCache(ttl=5000) + assert await cache.get("nonexistent") is None + + async def test_returns_stored_value(self) -> None: + cache: LocalCache[str] = LocalCache(ttl=5000) + await cache.set("key1", "value1") + assert await cache.get("key1") == "value1" + + async def test_returns_none_for_expired_item(self) -> None: + cache: LocalCache[str] = LocalCache(ttl=1) # 1ms TTL + await cache.set("key1", "value1") + await asyncio.sleep(0.01) # Wait for expiration + assert 
await cache.get("key1") is None + + +class TestLocalCacheSet: + async def test_overwrites_existing_value(self) -> None: + cache: LocalCache[str] = LocalCache(ttl=5000) + await cache.set("key1", "value1") + await cache.set("key1", "value2") + assert await cache.get("key1") == "value2" + + async def test_respects_max_items_with_lru_eviction(self) -> None: + cache: LocalCache[str] = LocalCache(max_items=2, ttl=5000) + await cache.set("a", "1") + await cache.set("b", "2") + # Access 'a' to make it recently used + await cache.get("a") + # Adding 'c' should evict 'b' (least recently used) + await cache.set("c", "3") + assert await cache.get("a") == "1" + assert await cache.get("b") is None + assert await cache.get("c") == "3" + + async def test_disabled_cache_when_max_items_zero(self) -> None: + cache: LocalCache[str] = LocalCache(max_items=0, ttl=5000) + await cache.set("key1", "value1") + assert await cache.get("key1") is None + + async def test_ttl_override(self) -> None: + cache: LocalCache[str] = LocalCache(ttl=5000) + await cache.set("key1", "value1", ttl=1) # 1ms override + await asyncio.sleep(0.01) + assert await cache.get("key1") is None + + +class TestLocalCacheDelete: + async def test_deletes_existing_key(self) -> None: + cache: LocalCache[str] = LocalCache(ttl=5000) + await cache.set("key1", "value1") + await cache.delete("key1") + assert await cache.get("key1") is None + + async def test_no_error_deleting_nonexistent_key(self) -> None: + cache: LocalCache[str] = LocalCache(ttl=5000) + await cache.delete("nonexistent") # Should not raise + + +class TestLocalCacheDeleteMissing: + async def test_removes_keys_not_in_keep_list(self) -> None: + cache: LocalCache[str] = LocalCache(ttl=5000) + await cache.set("a", "1") + await cache.set("b", "2") + await cache.set("c", "3") + await cache.delete_missing(["a", "c"]) + assert await cache.get("a") == "1" + assert await cache.get("b") is None + assert await cache.get("c") == "3" + + +class TestCacheProviderInterface: + 
async def test_raises_not_implemented(self) -> None: + provider: CacheProvider[str] = CacheProvider() + with pytest.raises(NotImplementedError): + await provider.get("key") + with pytest.raises(NotImplementedError): + await provider.set("key", "value") + with pytest.raises(NotImplementedError): + await provider.delete("key") + with pytest.raises(NotImplementedError): + await provider.delete_missing(["key"]) diff --git a/tests/datastream/test_datastream_client.py b/tests/datastream/test_datastream_client.py new file mode 100644 index 0000000..9f96c8a --- /dev/null +++ b/tests/datastream/test_datastream_client.py @@ -0,0 +1,660 @@ +from __future__ import annotations + +import asyncio +import logging +from typing import Any, Dict, List, Optional +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from schematic.cache import AsyncCacheProvider as CacheProvider, AsyncLocalCache as LocalCache +from schematic.datastream.datastream_client import DataStreamClient, DataStreamClientOptions +from schematic.datastream.types import DataStreamResp, EntityType, MessageType +from schematic.types import CheckFlagRequestBody, RulesengineCheckFlagResult + + +class MockCacheProvider(CacheProvider[Any]): + """Simple in-memory cache for testing.""" + + def __init__(self) -> None: + self._store: Dict[str, Any] = {} + + async def get(self, key: str) -> Optional[Any]: + return self._store.get(key) + + async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None: + self._store[key] = value + + async def delete(self, key: str) -> None: + self._store.pop(key, None) + + async def delete_missing(self, keys_to_keep: List[str], *, scan_pattern: Optional[str] = None) -> None: + keep = set(keys_to_keep) + to_delete = [k for k in self._store if k not in keep] + for k in to_delete: + del self._store[k] + + +@pytest.fixture +def logger() -> logging.Logger: + return logging.getLogger("test_datastream") + + +class TestDataStreamClientInit: + def 
test_replicator_mode_requires_cache_providers(self, logger: logging.Logger) -> None: + with pytest.raises(ValueError, match="Replicator mode requires"): + DataStreamClient(DataStreamClientOptions( + api_key="test-key", + logger=logger, + replicator_mode=True, + )) + + def test_replicator_mode_accepts_custom_cache(self, logger: logging.Logger) -> None: + cache = MockCacheProvider() + client = DataStreamClient(DataStreamClientOptions( + api_key="test-key", + logger=logger, + replicator_mode=True, + company_cache=cache, + company_lookup_cache=cache, + user_cache=cache, + user_lookup_cache=cache, + flag_cache=cache, + )) + assert client.is_replicator_mode() + assert not client.is_connected() + + def test_normal_mode_defaults(self, logger: logging.Logger) -> None: + client = DataStreamClient(DataStreamClientOptions( + api_key="test-key", + base_url="https://api.schematichq.com", + logger=logger, + )) + assert not client.is_replicator_mode() + assert not client.is_connected() + + +class TestDataStreamClientReplicatorMode: + @pytest.fixture + def client(self, logger: logging.Logger) -> DataStreamClient: + cache = MockCacheProvider() + return DataStreamClient(DataStreamClientOptions( + api_key="test-key", + logger=logger, + replicator_mode=True, + company_cache=cache, + company_lookup_cache=cache, + user_cache=cache, + user_lookup_cache=cache, + flag_cache=cache, + )) + + async def test_get_company_raises_when_not_cached(self, client: DataStreamClient) -> None: + with pytest.raises(RuntimeError, match="not found in cache"): + await client.get_company({"id": "co_123"}) + + async def test_get_user_raises_when_not_cached(self, client: DataStreamClient) -> None: + with pytest.raises(RuntimeError, match="not found in cache"): + await client.get_user({"id": "usr_123"}) + + +class TestDataStreamClientMessageHandling: + @pytest.fixture + def client_with_cache(self, logger: logging.Logger) -> tuple[DataStreamClient, MockCacheProvider, MockCacheProvider, MockCacheProvider]: + 
company_cache = MockCacheProvider() + user_cache = MockCacheProvider() + flag_cache = MockCacheProvider() + lookup_cache = MockCacheProvider() + client = DataStreamClient(DataStreamClientOptions( + api_key="test-key", + logger=logger, + replicator_mode=True, + company_cache=company_cache, + company_lookup_cache=lookup_cache, + user_cache=user_cache, + user_lookup_cache=lookup_cache, + flag_cache=flag_cache, + )) + return client, company_cache, user_cache, flag_cache + + async def test_handle_company_message_caches_data( + self, + client_with_cache: tuple[DataStreamClient, MockCacheProvider, MockCacheProvider, MockCacheProvider], + ) -> None: + client, company_cache, _, _ = client_with_cache + msg = DataStreamResp( + data={ + "id": "co_123", + "keys": {"slug": "acme"}, + "account_id": "acc_1", + "environment_id": "env_1", + "billing_product_ids": [], + "credit_balances": {}, + "metrics": [], + "plan_ids": [], + "plan_version_ids": [], + "rules": [], + "traits": [], + }, + entity_type=EntityType.COMPANY.value, + message_type=MessageType.FULL.value, + ) + await client._handle_message(msg) + + # Company should be retrievable from cache + company = await client._get_company_from_cache({"slug": "acme"}) + assert company is not None + assert company.id == "co_123" + + async def test_handle_user_message_caches_data( + self, + client_with_cache: tuple[DataStreamClient, MockCacheProvider, MockCacheProvider, MockCacheProvider], + ) -> None: + client, _, user_cache, _ = client_with_cache + msg = DataStreamResp( + data={ + "id": "usr_456", + "keys": {"email": "test@example.com"}, + "account_id": "acc_1", + "environment_id": "env_1", + "rules": [], + "traits": [], + }, + entity_type=EntityType.USER.value, + message_type=MessageType.FULL.value, + ) + await client._handle_message(msg) + + user = await client._get_user_from_cache({"email": "test@example.com"}) + assert user is not None + assert user.id == "usr_456" + + async def test_handle_flags_message_caches_all_flags( + self, + 
client_with_cache: tuple[DataStreamClient, MockCacheProvider, MockCacheProvider, MockCacheProvider], + ) -> None: + client, _, _, flag_cache = client_with_cache + flags = [ + {"key": "flag-a", "id": "f1", "default_value": True, "rules": [], "account_id": "acc_1", "environment_id": "env_1"}, + {"key": "flag-b", "id": "f2", "default_value": False, "rules": [], "account_id": "acc_1", "environment_id": "env_1"}, + ] + msg = DataStreamResp( + data=flags, + entity_type=EntityType.FLAGS.value, + message_type=MessageType.FULL.value, + ) + await client._handle_message(msg) + + flag_a = await client.get_flag("flag-a") + flag_b = await client.get_flag("flag-b") + assert flag_a is not None + assert flag_a.key == "flag-a" + assert flag_b is not None + assert flag_b.key == "flag-b" + + async def test_handle_flag_delete( + self, + client_with_cache: tuple[DataStreamClient, MockCacheProvider, MockCacheProvider, MockCacheProvider], + ) -> None: + client, _, _, flag_cache = client_with_cache + + # First add a flag + msg = DataStreamResp( + data={"key": "flag-x", "id": "fx", "default_value": True, "rules": [], "account_id": "acc_1", "environment_id": "env_1"}, + entity_type=EntityType.FLAG.value, + message_type=MessageType.FULL.value, + ) + await client._handle_message(msg) + assert await client.get_flag("flag-x") is not None + + # Then delete it + msg = DataStreamResp( + data={"key": "flag-x", "id": "fx", "default_value": False, "rules": [], "account_id": "acc_1", "environment_id": "env_1"}, + entity_type=EntityType.FLAG.value, + message_type=MessageType.DELETE.value, + ) + await client._handle_message(msg) + assert await client.get_flag("flag-x") is None + + async def test_handle_company_delete( + self, + client_with_cache: tuple[DataStreamClient, MockCacheProvider, MockCacheProvider, MockCacheProvider], + ) -> None: + client, company_cache, _, _ = client_with_cache + + # Add company + msg = DataStreamResp( + data={ + "id": "co_del", + "keys": {"slug": "delete-me"}, + "account_id": 
"acc_1", + "environment_id": "env_1", + "billing_product_ids": [], + "credit_balances": {}, + "metrics": [], + "plan_ids": [], + "plan_version_ids": [], + "rules": [], + "traits": [], + }, + entity_type=EntityType.COMPANY.value, + message_type=MessageType.FULL.value, + ) + await client._handle_message(msg) + assert await client._get_company_from_cache({"slug": "delete-me"}) is not None + + # Delete company + msg = DataStreamResp( + data={ + "id": "co_del", + "keys": {"slug": "delete-me"}, + "account_id": "acc_1", + "environment_id": "env_1", + "billing_product_ids": [], + "credit_balances": {}, + "metrics": [], + "plan_ids": [], + "plan_version_ids": [], + "rules": [], + "traits": [], + }, + entity_type=EntityType.COMPANY.value, + message_type=MessageType.DELETE.value, + ) + await client._handle_message(msg) + assert await client._get_company_from_cache({"slug": "delete-me"}) is None + + async def test_handle_error_message( + self, + client_with_cache: tuple[DataStreamClient, MockCacheProvider, MockCacheProvider, MockCacheProvider], + ) -> None: + client, _, _, _ = client_with_cache + msg = DataStreamResp( + data={"error": "test error", "keys": {"id": "co_err"}, "entity_type": EntityType.COMPANY.value}, + entity_type=EntityType.COMPANY.value, + message_type=MessageType.ERROR.value, + ) + # Should not raise + await client._handle_message(msg) + + +class TestDataStreamClientClose: + async def test_close_cleans_up(self, logger: logging.Logger) -> None: + cache = MockCacheProvider() + client = DataStreamClient(DataStreamClientOptions( + api_key="test-key", + logger=logger, + replicator_mode=True, + company_cache=cache, + company_lookup_cache=cache, + user_cache=cache, + user_lookup_cache=cache, + flag_cache=cache, + )) + await client.close() # Should not raise + + +class TestDataStreamClientCacheKeys: + def test_flag_cache_key_uses_version(self, logger: logging.Logger) -> None: + client = DataStreamClient(DataStreamClientOptions( + api_key="test-key", + 
base_url="https://api.schematichq.com", + logger=logger, + )) + key = client._flag_cache_key("Premium-Feature") + # Should be lowercased and include version key + assert "premium-feature" in key + assert key.startswith("flags:") + + def test_resource_key_to_cache_key_lowercases(self, logger: logging.Logger) -> None: + client = DataStreamClient(DataStreamClientOptions( + api_key="test-key", + base_url="https://api.schematichq.com", + logger=logger, + )) + key = client._resource_key_to_cache_key("company", "Slug", "AcmeCorp") + assert "slug" in key + assert "acmecorp" in key + + +class TestDataStreamClientFlagEvaluation: + async def test_evaluate_flag_returns_default_when_engine_unavailable(self, logger: logging.Logger) -> None: + from schematic.types import RulesengineFlag + + cache = MockCacheProvider() + client = DataStreamClient(DataStreamClientOptions( + api_key="test-key", + logger=logger, + replicator_mode=True, + company_cache=cache, + company_lookup_cache=cache, + user_cache=cache, + user_lookup_cache=cache, + flag_cache=cache, + )) + flag = RulesengineFlag( + id="f1", key="test", account_id="a", environment_id="e", + default_value=True, rules=[], + ) + result = client._evaluate_flag(flag, None, None) + assert isinstance(result, RulesengineCheckFlagResult) + assert result.value is True + assert result.reason == "RULES_ENGINE_UNAVAILABLE" + assert result.flag_key == "test" + + async def test_check_flag_raises_when_flag_not_found(self, logger: logging.Logger) -> None: + cache = MockCacheProvider() + client = DataStreamClient(DataStreamClientOptions( + api_key="test-key", + logger=logger, + replicator_mode=True, + company_cache=cache, + company_lookup_cache=cache, + user_cache=cache, + user_lookup_cache=cache, + flag_cache=cache, + )) + with pytest.raises(RuntimeError, match="Flag not found"): + await client.check_flag(CheckFlagRequestBody(), "nonexistent-flag") + + async def test_flag_evaluation_with_cached_company(self, logger: logging.Logger) -> None: + 
"""Spec test #6: Flag evaluation with cached company.""" + from schematic.types import RulesengineFlag + + cache = MockCacheProvider() + client = DataStreamClient(DataStreamClientOptions( + api_key="test-key", + logger=logger, + replicator_mode=True, + company_cache=cache, + company_lookup_cache=cache, + user_cache=cache, + user_lookup_cache=cache, + flag_cache=cache, + )) + + # Cache a company via full message + await client._handle_message(DataStreamResp( + data={ + "id": "co_eval", + "keys": {"slug": "eval-co"}, + "account_id": "acc_1", + "environment_id": "env_1", + "billing_product_ids": [], + "credit_balances": {}, + "metrics": [], + "plan_ids": [], + "plan_version_ids": [], + "rules": [], + "traits": [], + }, + entity_type=EntityType.COMPANY.value, + message_type=MessageType.FULL.value, + )) + + # Cache a flag + await client._handle_message(DataStreamResp( + data={"key": "co-flag", "id": "f1", "default_value": True, "rules": [], "account_id": "acc_1", "environment_id": "env_1"}, + entity_type=EntityType.FLAG.value, + message_type=MessageType.FULL.value, + )) + + result = await client.check_flag( + CheckFlagRequestBody(company={"slug": "eval-co"}), + "co-flag", + ) + assert isinstance(result, RulesengineCheckFlagResult) + assert result.company_id == "co_eval" + assert result.flag_key == "co-flag" + + async def test_flag_evaluation_with_cached_user(self, logger: logging.Logger) -> None: + """Spec test #7: Flag evaluation with cached user.""" + from schematic.types import RulesengineFlag + + cache = MockCacheProvider() + client = DataStreamClient(DataStreamClientOptions( + api_key="test-key", + logger=logger, + replicator_mode=True, + company_cache=cache, + company_lookup_cache=cache, + user_cache=cache, + user_lookup_cache=cache, + flag_cache=cache, + )) + + # Cache a user + await client._handle_message(DataStreamResp( + data={ + "id": "usr_eval", + "keys": {"email": "eval@test.com"}, + "account_id": "acc_1", + "environment_id": "env_1", + "rules": [], + 
"traits": [], + }, + entity_type=EntityType.USER.value, + message_type=MessageType.FULL.value, + )) + + # Cache a flag + await client._handle_message(DataStreamResp( + data={"key": "usr-flag", "id": "f2", "default_value": False, "rules": [], "account_id": "acc_1", "environment_id": "env_1"}, + entity_type=EntityType.FLAG.value, + message_type=MessageType.FULL.value, + )) + + result = await client.check_flag( + CheckFlagRequestBody(user={"email": "eval@test.com"}), + "usr-flag", + ) + assert isinstance(result, RulesengineCheckFlagResult) + assert result.user_id == "usr_eval" + assert result.flag_key == "usr-flag" + + +class TestDataStreamClientPartialMerge: + """Spec test #4: Partial entity message merges into cache.""" + + @pytest.fixture + def client_with_cache(self, logger: logging.Logger) -> tuple[DataStreamClient, MockCacheProvider]: + cache = MockCacheProvider() + client = DataStreamClient(DataStreamClientOptions( + api_key="test-key", + logger=logger, + replicator_mode=True, + company_cache=cache, + company_lookup_cache=cache, + user_cache=cache, + user_lookup_cache=cache, + flag_cache=cache, + )) + return client, cache + + async def test_partial_company_merges_keys( + self, + client_with_cache: tuple[DataStreamClient, MockCacheProvider], + ) -> None: + client, _ = client_with_cache + + # Add full company + await client._handle_message(DataStreamResp( + data={ + "id": "co_partial", + "keys": {"slug": "original"}, + "account_id": "acc_1", + "environment_id": "env_1", + "billing_product_ids": [], + "credit_balances": {}, + "metrics": [], + "plan_ids": [], + "plan_version_ids": [], + "rules": [], + "traits": [], + }, + entity_type=EntityType.COMPANY.value, + message_type=MessageType.FULL.value, + )) + + # Partial update adds a new key + await client._handle_message(DataStreamResp( + data={"id": "co_partial", "keys": {"domain": "example.com"}}, + entity_type=EntityType.COMPANY.value, + message_type=MessageType.PARTIAL.value, + )) + + company = await 
client._get_company_from_cache({"slug": "original"}) + assert company is not None + assert company.keys == {"slug": "original", "domain": "example.com"} + + async def test_partial_user_merges_keys( + self, + client_with_cache: tuple[DataStreamClient, MockCacheProvider], + ) -> None: + client, _ = client_with_cache + + # Add full user + await client._handle_message(DataStreamResp( + data={ + "id": "usr_partial", + "keys": {"email": "orig@test.com"}, + "account_id": "acc_1", + "environment_id": "env_1", + "rules": [], + "traits": [], + }, + entity_type=EntityType.USER.value, + message_type=MessageType.FULL.value, + )) + + # Partial update adds a new key + await client._handle_message(DataStreamResp( + data={"id": "usr_partial", "keys": {"slack_id": "U123"}}, + entity_type=EntityType.USER.value, + message_type=MessageType.PARTIAL.value, + )) + + user = await client._get_user_from_cache({"email": "orig@test.com"}) + assert user is not None + assert user.keys == {"email": "orig@test.com", "slack_id": "U123"} + + +class TestDataStreamClientDeepCopy: + """Spec test #12: Deep copy prevents mutation of cached entities.""" + + async def test_cached_company_mutation_does_not_affect_cache(self, logger: logging.Logger) -> None: + cache = MockCacheProvider() + client = DataStreamClient(DataStreamClientOptions( + api_key="test-key", + logger=logger, + replicator_mode=True, + company_cache=cache, + company_lookup_cache=cache, + user_cache=cache, + user_lookup_cache=cache, + flag_cache=cache, + )) + + await client._handle_message(DataStreamResp( + data={ + "id": "co_mut", + "keys": {"slug": "mutable"}, + "account_id": "acc_1", + "environment_id": "env_1", + "billing_product_ids": ["bp-1"], + "credit_balances": {}, + "metrics": [], + "plan_ids": [], + "plan_version_ids": [], + "rules": [], + "traits": [], + }, + entity_type=EntityType.COMPANY.value, + message_type=MessageType.FULL.value, + )) + + # Retrieve and mutate + company = await client._get_company_from_cache({"slug": 
"mutable"}) + assert company is not None + original_ids = list(company.billing_product_ids) + company.billing_product_ids.append("bp-INJECTED") + + # Re-retrieve — cache should be unaffected + fresh = await client._get_company_from_cache({"slug": "mutable"}) + assert fresh is not None + assert fresh.billing_product_ids == original_ids + + async def test_cached_user_mutation_does_not_affect_cache(self, logger: logging.Logger) -> None: + cache = MockCacheProvider() + client = DataStreamClient(DataStreamClientOptions( + api_key="test-key", + logger=logger, + replicator_mode=True, + company_cache=cache, + company_lookup_cache=cache, + user_cache=cache, + user_lookup_cache=cache, + flag_cache=cache, + )) + + await client._handle_message(DataStreamResp( + data={ + "id": "usr_mut", + "keys": {"email": "mut@test.com"}, + "account_id": "acc_1", + "environment_id": "env_1", + "rules": [], + "traits": [], + }, + entity_type=EntityType.USER.value, + message_type=MessageType.FULL.value, + )) + + user = await client._get_user_from_cache({"email": "mut@test.com"}) + assert user is not None + user.keys["injected"] = "bad" + + fresh = await client._get_user_from_cache({"email": "mut@test.com"}) + assert fresh is not None + assert "injected" not in fresh.keys + + +class TestDataStreamClientMissingEntityTimeout: + """Spec test #8: Missing company triggers fetch/wait (times out without WS).""" + + async def test_get_company_times_out_without_connection(self, logger: logging.Logger) -> None: + client = DataStreamClient(DataStreamClientOptions( + api_key="test-key", + base_url="https://api.schematichq.com", + logger=logger, + )) + with pytest.raises(RuntimeError, match="not connected"): + await client.get_company({"slug": "missing"}) + + async def test_get_user_times_out_without_connection(self, logger: logging.Logger) -> None: + client = DataStreamClient(DataStreamClientOptions( + api_key="test-key", + base_url="https://api.schematichq.com", + logger=logger, + )) + with 
pytest.raises(RuntimeError, match="not connected"): + await client.get_user({"email": "missing@test.com"}) + + +class TestDataStreamClientDefaultReplicatorHealthUrl: + """Verify replicator_health_url defaults to spec canonical value.""" + + def test_default_health_url(self, logger: logging.Logger) -> None: + opts = DataStreamClientOptions( + api_key="test-key", + logger=logger, + ) + assert opts.replicator_health_url == "http://localhost:8090/ready" + + def test_custom_health_url_overrides_default(self, logger: logging.Logger) -> None: + opts = DataStreamClientOptions( + api_key="test-key", + logger=logger, + replicator_health_url="http://custom:9090/health", + ) + assert opts.replicator_health_url == "http://custom:9090/health" diff --git a/tests/datastream/test_merge.py b/tests/datastream/test_merge.py new file mode 100644 index 0000000..6f263a8 --- /dev/null +++ b/tests/datastream/test_merge.py @@ -0,0 +1,547 @@ +from __future__ import annotations + +import datetime as dt + +import pytest + +from schematic.datastream.merge import ( + deep_copy_company, + deep_copy_user, + extract_id, + partial_company, + partial_user, +) +from schematic.types import ( + RulesengineCompany, + RulesengineCompanyMetric, + RulesengineFeatureEntitlement, + RulesengineRule, + RulesengineSubscription, + RulesengineTrait, + RulesengineTraitDefinition, + RulesengineUser, +) + + +def _make_trait(value: str, definition_id: str | None = None) -> RulesengineTrait: + td = None + if definition_id is not None: + td = RulesengineTraitDefinition( + id=definition_id, + comparable_type="string", + entity_type="company", + ) + return RulesengineTrait(value=value, trait_definition=td) + + +def _make_rule(rule_id: str) -> RulesengineRule: + return RulesengineRule( + id=rule_id, + name=rule_id, + priority=1, + value=True, + rule_type="override", + account_id="acc-1", + environment_id="env-1", + condition_groups=[], + conditions=[], + ) + + +def _make_entitlement(feature_id: str, feature_key: str) -> 
RulesengineFeatureEntitlement: + return RulesengineFeatureEntitlement( + feature_id=feature_id, + feature_key=feature_key, + value_type="boolean", + ) + + +def _make_metric( + event_subtype: str, + period: str = "all_time", + month_reset: str = "first_of_month", + value: int = 0, +) -> RulesengineCompanyMetric: + return RulesengineCompanyMetric( + account_id="acc-1", + company_id="co-1", + created_at=dt.datetime(2026, 1, 1, tzinfo=dt.timezone.utc), + environment_id="env-1", + event_subtype=event_subtype, + month_reset=month_reset, + period=period, + value=value, + ) + + +def base_company() -> RulesengineCompany: + return RulesengineCompany( + id="co-1", + account_id="acc-1", + environment_id="env-1", + base_plan_id="plan-1", + billing_product_ids=["bp-1"], + credit_balances={"credit-1": 100.0}, + keys={"domain": "example.com"}, + plan_ids=["plan-1"], + plan_version_ids=["pv-1"], + traits=[_make_trait("Enterprise", "plan")], + entitlements=[_make_entitlement("feat-1", "feature-one")], + metrics=[], + rules=[], + ) + + +def base_user() -> RulesengineUser: + return RulesengineUser( + id="user-1", + account_id="acc-1", + environment_id="env-1", + keys={"email": "user@example.com"}, + traits=[_make_trait("Premium", "tier")], + rules=[], + ) + + +# ------------------------------------------------------------------ +# partial_company tests +# ------------------------------------------------------------------ + + +class TestPartialCompanyOnlyTraits: + def test_replaces_traits_preserves_other_fields(self) -> None: + existing = base_company() + partial = { + "id": "co-1", + "traits": [{"value": "Startup", "trait_definition": {"id": "plan", "comparable_type": "string", "entity_type": "company"}}], + } + + merged = partial_company(existing, partial) + + assert len(merged.traits) == 1 + assert merged.traits[0].value == "Startup" + + assert merged.account_id == "acc-1" + assert merged.environment_id == "env-1" + assert merged.keys == {"domain": "example.com"} + assert 
merged.billing_product_ids == ["bp-1"] + assert merged.base_plan_id == "plan-1" + + +class TestPartialCompanyMergesKeys: + def test_new_key_added_existing_preserved(self) -> None: + existing = base_company() + partial = {"id": "co-1", "keys": {"slug": "new-slug"}} + + merged = partial_company(existing, partial) + + assert merged.keys == {"domain": "example.com", "slug": "new-slug"} + assert len(merged.traits) == 1 + + +class TestPartialCompanyMergesCreditBalances: + def test_new_balance_added_existing_preserved(self) -> None: + existing = base_company() + partial = {"id": "co-1", "credit_balances": {"credit-2": 200.0}} + + merged = partial_company(existing, partial) + + assert merged.credit_balances == {"credit-1": 100.0, "credit-2": 200.0} + + def test_overwrites_existing_balance(self) -> None: + existing = base_company() + partial = {"id": "co-1", "credit_balances": {"credit-1": 50.0}} + + merged = partial_company(existing, partial) + + assert merged.credit_balances == {"credit-1": 50.0} + + +class TestPartialCompanyUpsertsMetrics: + def test_updates_matching_appends_new(self) -> None: + existing = base_company().model_copy( + update={ + "metrics": [ + _make_metric("event-a", "all_time", "first_of_month", 10), + _make_metric("event-b", "current_month", "first_of_month", 5), + ], + } + ) + partial = { + "id": "co-1", + "metrics": [ + {"event_subtype": "event-a", "period": "all_time", "month_reset": "first_of_month", "value": 42, + "account_id": "acc-1", "company_id": "co-1", "environment_id": "env-1", + "created_at": "2026-01-01T00:00:00Z"}, + {"event_subtype": "event-c", "period": "current_day", "month_reset": "billing_cycle", "value": 1, + "account_id": "acc-1", "company_id": "co-1", "environment_id": "env-1", + "created_at": "2026-01-01T00:00:00Z"}, + ], + } + + merged = partial_company(existing, partial) + + assert len(merged.metrics) == 3 + # event-a updated in place + assert merged.metrics[0].event_subtype == "event-a" + assert merged.metrics[0].value == 42 
+ # event-b unchanged + assert merged.metrics[1].event_subtype == "event-b" + assert merged.metrics[1].value == 5 + # event-c appended + assert merged.metrics[2].event_subtype == "event-c" + assert merged.metrics[2].value == 1 + + # Original not mutated + assert existing.metrics[0].value == 10 + + +class TestPartialCompanyEmptyEntitlements: + def test_clears_entitlements(self) -> None: + existing = base_company() + partial = {"id": "co-1", "entitlements": []} + + merged = partial_company(existing, partial) + + assert merged.entitlements == [] + assert merged.account_id == "acc-1" + + +class TestPartialCompanyNullBasePlanID: + def test_sets_base_plan_to_none(self) -> None: + existing = base_company() + partial = {"id": "co-1", "base_plan_id": None} + + merged = partial_company(existing, partial) + + assert merged.base_plan_id is None + assert merged.billing_product_ids == ["bp-1"] + + +class TestPartialCompanyMissingID: + def test_raises_value_error(self) -> None: + existing = base_company() + partial: dict[str, list[str]] = {"traits": []} + + with pytest.raises(ValueError, match="missing required field: id"): + partial_company(existing, partial) + + +class TestPartialCompanyDoesNotMutateOriginal: + def test_original_unchanged(self) -> None: + existing = base_company() + orig_keys = dict(existing.keys) + + partial = {"id": "co-1", "keys": {"slug": "new-slug"}, "traits": []} + + merged = partial_company(existing, partial) + + assert existing.keys == orig_keys + assert len(existing.traits) == 1 + assert merged.keys == {"domain": "example.com", "slug": "new-slug"} + assert merged.traits == [] + + +class TestPartialCompanyRules: + def test_replaces_rules(self) -> None: + existing = base_company().model_copy(update={"rules": [_make_rule("rule-old")]}) + partial = { + "id": "co-1", + "rules": [{"id": "rule-new", "name": "rule-new", "priority": 1, "value": True, + "rule_type": "override", "account_id": "acc-1", "environment_id": "env-1", + "condition_groups": [], 
"conditions": []}], + } + + merged = partial_company(existing, partial) + + assert len(merged.rules) == 1 + assert merged.rules[0].id == "rule-new" + assert existing.rules[0].id == "rule-old" + + +class TestPartialCompanyFullEntity: + def test_full_entity_partial_message(self) -> None: + existing = base_company().model_copy( + update={ + "metrics": [_make_metric("event-a", "all_time", "first_of_month", 10)], + "rules": [_make_rule("rule-1")], + } + ) + + partial = { + "id": "co-1", + "account_id": "acc-2", + "environment_id": "env-2", + "base_plan_id": "plan-99", + "billing_product_ids": ["bp-10", "bp-20"], + "credit_balances": {"credit-1": 999.0, "credit-new": 50.0}, + "entitlements": [ + {"feature_id": "feat-new", "feature_key": "feature-new", "value_type": "boolean"}, + {"feature_id": "feat-2", "feature_key": "feature-two", "value_type": "boolean"}, + ], + "keys": {"domain": "new.com", "slug": "new-slug"}, + "metrics": [ + {"event_subtype": "event-a", "period": "all_time", "month_reset": "first_of_month", "value": 42, + "account_id": "acc-1", "company_id": "co-1", "environment_id": "env-1", + "created_at": "2026-01-01T00:00:00Z"}, + {"event_subtype": "event-new", "period": "current_day", "month_reset": "billing_cycle", "value": 7, + "account_id": "acc-1", "company_id": "co-1", "environment_id": "env-1", + "created_at": "2026-01-01T00:00:00Z"}, + ], + "plan_ids": ["plan-99", "plan-100"], + "plan_version_ids": ["pv-99"], + "rules": [ + {"id": "rule-new-1", "name": "r1", "priority": 1, "value": True, "rule_type": "override", + "account_id": "acc-1", "environment_id": "env-1", "condition_groups": [], "conditions": []}, + {"id": "rule-new-2", "name": "r2", "priority": 2, "value": False, "rule_type": "override", + "account_id": "acc-1", "environment_id": "env-1", "condition_groups": [], "conditions": []}, + ], + "subscription": {"id": "sub-new", "period_start": "2026-01-01T00:00:00Z", "period_end": "2026-02-01T00:00:00Z"}, + "traits": [ + {"value": "Startup", 
"trait_definition": {"id": "tier", "comparable_type": "string", "entity_type": "company"}}, + {"value": "Annual", "trait_definition": {"id": "billing", "comparable_type": "string", "entity_type": "company"}}, + ], + } + + merged = partial_company(existing, partial) + + assert merged.id == "co-1" + assert merged.account_id == "acc-2" + assert merged.environment_id == "env-2" + assert merged.base_plan_id == "plan-99" + assert merged.billing_product_ids == ["bp-10", "bp-20"] + + # Credit balances merge: existing credit-1 overwritten, credit-new added + assert merged.credit_balances == {"credit-1": 999.0, "credit-new": 50.0} + + assert merged.entitlements is not None + assert len(merged.entitlements) == 2 + assert merged.entitlements[0].feature_id == "feat-new" + assert merged.entitlements[1].feature_id == "feat-2" + + # Keys merge: domain overwritten, slug added + assert merged.keys == {"domain": "new.com", "slug": "new-slug"} + + # Metrics upsert: event-a updated, event-new appended + assert len(merged.metrics) == 2 + assert merged.metrics[0].event_subtype == "event-a" + assert merged.metrics[0].value == 42 + assert merged.metrics[1].event_subtype == "event-new" + assert merged.metrics[1].value == 7 + + assert merged.plan_ids == ["plan-99", "plan-100"] + assert merged.plan_version_ids == ["pv-99"] + + assert len(merged.rules) == 2 + assert merged.rules[0].id == "rule-new-1" + assert merged.rules[1].id == "rule-new-2" + + assert merged.subscription is not None + assert merged.subscription.id == "sub-new" + + assert len(merged.traits) == 2 + assert merged.traits[0].value == "Startup" + assert merged.traits[1].value == "Annual" + + # Original not mutated + assert existing.account_id == "acc-1" + assert existing.base_plan_id == "plan-1" + assert existing.keys == {"domain": "example.com"} + assert existing.metrics[0].value == 10 + + +# ------------------------------------------------------------------ +# partial_user tests +# 
------------------------------------------------------------------ + + +class TestPartialUserOnlyTraits: + def test_replaces_traits_preserves_keys(self) -> None: + existing = base_user() + partial = { + "id": "user-1", + "traits": [{"value": "Free", "trait_definition": {"id": "tier", "comparable_type": "string", "entity_type": "user"}}], + } + + merged = partial_user(existing, partial) + + assert len(merged.traits) == 1 + assert merged.traits[0].value == "Free" + assert merged.keys == {"email": "user@example.com"} + + +class TestPartialUserMergesKeys: + def test_new_key_added_existing_preserved(self) -> None: + existing = base_user() + partial = {"id": "user-1", "keys": {"slack_id": "U123"}} + + merged = partial_user(existing, partial) + + assert merged.keys == {"email": "user@example.com", "slack_id": "U123"} + assert len(merged.traits) == 1 + + +class TestPartialUserMissingID: + def test_raises_value_error(self) -> None: + existing = base_user() + partial = {"keys": {"email": "new@example.com"}} + + with pytest.raises(ValueError, match="missing required field: id"): + partial_user(existing, partial) + + +class TestPartialUserDoesNotMutateOriginal: + def test_original_unchanged(self) -> None: + existing = base_user() + orig_keys = dict(existing.keys) + + partial = {"id": "user-1", "keys": {"slug": "new"}, "traits": []} + + merged = partial_user(existing, partial) + + assert existing.keys == orig_keys + assert len(existing.traits) == 1 + assert merged.keys == {"email": "user@example.com", "slug": "new"} + assert merged.traits == [] + + +class TestPartialUserFullEntity: + def test_full_entity_partial_message(self) -> None: + existing = base_user().model_copy(update={"rules": [_make_rule("rule-1")]}) + + partial = { + "id": "user-1", + "account_id": "acc-2", + "environment_id": "env-2", + "keys": {"email": "new@example.com", "slack_id": "U999"}, + "traits": [ + {"value": "Free", "trait_definition": {"id": "tier", "comparable_type": "string", "entity_type": "user"}}, 
+ {"value": "Monthly", "trait_definition": {"id": "billing", "comparable_type": "string", "entity_type": "user"}}, + ], + "rules": [ + {"id": "rule-new-1", "name": "r1", "priority": 1, "value": True, "rule_type": "override", + "account_id": "acc-1", "environment_id": "env-1", "condition_groups": [], "conditions": []}, + {"id": "rule-new-2", "name": "r2", "priority": 2, "value": False, "rule_type": "override", + "account_id": "acc-1", "environment_id": "env-1", "condition_groups": [], "conditions": []}, + ], + } + + merged = partial_user(existing, partial) + + assert merged.id == "user-1" + assert merged.account_id == "acc-2" + assert merged.environment_id == "env-2" + + # Keys merge: email overwritten, slack_id added + assert merged.keys == {"email": "new@example.com", "slack_id": "U999"} + + assert len(merged.traits) == 2 + assert merged.traits[0].value == "Free" + assert merged.traits[1].value == "Monthly" + + assert len(merged.rules) == 2 + assert merged.rules[0].id == "rule-new-1" + assert merged.rules[1].id == "rule-new-2" + + # Original not mutated + assert existing.account_id == "acc-1" + assert existing.keys == {"email": "user@example.com"} + assert len(existing.traits) == 1 + assert existing.rules[0].id == "rule-1" + + +# ------------------------------------------------------------------ +# extract_id tests +# ------------------------------------------------------------------ + + +class TestExtractID: + def test_from_dict(self) -> None: + assert extract_id({"id": "co-1", "traits": []}) == "co-1" + + def test_from_model(self) -> None: + user = base_user() + assert extract_id(user) == "user-1" + + def test_missing_returns_none(self) -> None: + assert extract_id({"traits": []}) is None + + def test_none_returns_none(self) -> None: + assert extract_id(None) is None + + +# ------------------------------------------------------------------ +# deep_copy tests +# ------------------------------------------------------------------ + + +class TestDeepCopyCompany: + 
def test_none_returns_none(self) -> None: + assert deep_copy_company(None) is None + + def test_full_copy_is_independent(self) -> None: + orig = base_company().model_copy( + update={ + "subscription": RulesengineSubscription( + id="sub-1", + period_start=dt.datetime(2026, 1, 1, tzinfo=dt.timezone.utc), + period_end=dt.datetime(2026, 2, 1, tzinfo=dt.timezone.utc), + ), + "metrics": [ + _make_metric("event-1", value=42), + ], + } + ) + + cp = deep_copy_company(orig) + assert cp is not None + + assert cp.id == orig.id + assert cp.account_id == orig.account_id + assert cp.environment_id == orig.environment_id + assert cp.base_plan_id == orig.base_plan_id + assert cp.keys == orig.keys + assert cp.credit_balances == orig.credit_balances + assert cp.subscription is not None + assert cp.subscription.id == "sub-1" + assert cp.metrics[0].value == 42 + + # Verify it's a separate object + assert cp is not orig + + +class TestDeepCopyUser: + def test_empty_fields(self) -> None: + orig = RulesengineUser( + id="u1", + account_id="acc-1", + environment_id="env-1", + keys={}, + traits=[], + rules=[], + ) + cp = deep_copy_user(orig) + assert cp is not None + + assert cp.id == "u1" + assert cp.keys == {} + assert cp.traits == [] + assert cp.rules == [] + + def test_full_copy_is_independent(self) -> None: + orig = base_user().model_copy(update={"rules": [_make_rule("r1")]}) + + cp = deep_copy_user(orig) + assert cp is not None + + assert cp.id == orig.id + assert cp.account_id == orig.account_id + assert cp.keys == orig.keys + assert cp.traits[0].value == "Premium" + assert cp.rules[0].id == "r1" + + # Verify it's a separate object + assert cp is not orig + + def test_none_returns_none(self) -> None: + assert deep_copy_user(None) is None diff --git a/tests/datastream/test_rules_engine.py b/tests/datastream/test_rules_engine.py new file mode 100644 index 0000000..ce342a4 --- /dev/null +++ b/tests/datastream/test_rules_engine.py @@ -0,0 +1,137 @@ +from __future__ import annotations + 
+import pytest + +from schematic.datastream.rules_engine import RulesEngineClient +from schematic.types import RulesengineCheckFlagResult, RulesengineFlag, RulesengineRule + +# Skip all tests if wasmtime is not installed +wasmtime = pytest.importorskip("wasmtime", reason="wasmtime not installed") + + +def _make_flag(**overrides: object) -> RulesengineFlag: + """Build a minimal valid flag for the WASM rules engine.""" + defaults = dict( + id="flag1", + key="test-flag", + account_id="acc_1", + environment_id="env_1", + default_value=False, + rules=[], + ) + defaults.update(overrides) + return RulesengineFlag(**defaults) # type: ignore[arg-type] + + + +class TestRulesEngineClientInit: + async def test_initialize_loads_wasm(self) -> None: + engine = RulesEngineClient() + assert not engine.is_initialized() + await engine.initialize() + assert engine.is_initialized() + + async def test_initialize_is_idempotent(self) -> None: + engine = RulesEngineClient() + await engine.initialize() + await engine.initialize() # Should not raise + assert engine.is_initialized() + + async def test_get_version_key(self) -> None: + engine = RulesEngineClient() + await engine.initialize() + version = engine.get_version_key() + assert isinstance(version, str) + assert len(version) == 8 # 8-char hex string + + def test_check_flag_before_init_raises(self) -> None: + engine = RulesEngineClient() + with pytest.raises(RuntimeError, match="not initialized"): + engine.check_flag(_make_flag()) + + +class TestRulesEngineCheckFlag: + @pytest.fixture + async def engine(self) -> RulesEngineClient: + e = RulesEngineClient() + await e.initialize() + return e + + async def test_evaluates_flag_with_default_value(self, engine: RulesEngineClient) -> None: + flag = _make_flag(default_value=True) + result = engine.check_flag(flag) + assert isinstance(result, RulesengineCheckFlagResult) + assert result.value is True + assert result.flag_key == "test-flag" + assert result.reason != "" + + async def 
test_evaluates_flag_with_company_context(self, engine: RulesEngineClient) -> None: + from schematic.types import RulesengineCompany + + flag = _make_flag( + default_value=False, + rules=[ + RulesengineRule( + id="rule1", + account_id="acc_1", + environment_id="env_1", + name="Global Override", + rule_type="global_override", + value=True, + priority=1, + conditions=[], + condition_groups=[], + ) + ], + ) + company = RulesengineCompany( + id="co_123", + account_id="acc_1", + environment_id="env_1", + keys={"id": "co_123"}, + traits=[], + metrics=[], + rules=[], + entitlements=[], + billing_product_ids=[], + credit_balances={}, + plan_ids=[], + plan_version_ids=[], + ) + result = engine.check_flag(flag, company) + assert isinstance(result, RulesengineCheckFlagResult) + assert result.value is True + assert result.flag_key == "test-flag" + assert result.rule_id == "rule1" + assert result.rule_type is not None + assert result.reason != "" + + async def test_evaluates_flag_with_user_context(self, engine: RulesEngineClient) -> None: + from schematic.types import RulesengineUser + + flag = _make_flag(id="flag2", key="user-flag", default_value=True) + user = RulesengineUser( + id="usr_456", + account_id="acc_1", + environment_id="env_1", + keys={"id": "usr_456"}, + traits=[], + rules=[], + ) + result = engine.check_flag(flag, None, user) + assert isinstance(result, RulesengineCheckFlagResult) + assert result.value is True + assert result.flag_key == "user-flag" + + async def test_returns_default_for_empty_rules(self, engine: RulesEngineClient) -> None: + flag = _make_flag(id="flag3", key="empty-rules", default_value=False) + result = engine.check_flag(flag) + assert result.value is False + assert result.flag_key == "empty-rules" + + +class TestRulesEngineFileNotFound: + async def test_missing_wasm_raises(self) -> None: + engine = RulesEngineClient(wasm_path="/nonexistent/rulesengine.wasm") + with pytest.raises(FileNotFoundError): + await engine.initialize() diff --git 
a/tests/datastream/test_websocket_client.py b/tests/datastream/test_websocket_client.py new file mode 100644 index 0000000..18be616 --- /dev/null +++ b/tests/datastream/test_websocket_client.py @@ -0,0 +1,479 @@ +from __future__ import annotations + +import asyncio +import json +import logging +from contextlib import asynccontextmanager +from typing import AsyncIterator, List, Optional +from unittest.mock import patch + +import pytest + +from schematic.datastream.types import DataStreamBaseReq, DataStreamReq, DataStreamResp, EntityType +from schematic.datastream.websocket_client import ( + ClientOptions, + DatastreamWSClient, + convert_api_url_to_websocket_url, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Mock helpers +# --------------------------------------------------------------------------- + + +class MockWebSocket: + """Fake WebSocket that yields a fixed list of messages then stops. + + Set ``block_on_empty=True`` to keep the connection open after all messages + are consumed — useful for tests that need to inspect state while connected. + Call ``.close()`` to unblock. 
+ """ + + def __init__(self, messages: Optional[List[str]] = None, block_on_empty: bool = False) -> None: + self._messages = messages or [] + self._index = 0 + self._block_on_empty = block_on_empty + self._close_event: asyncio.Event = asyncio.Event() + self.sent: List[str] = [] + self.closed = False + + async def send(self, message: str | bytes) -> None: + self.sent.append(message.decode() if isinstance(message, bytes) else message) + + async def close(self) -> None: + self.closed = True + self._close_event.set() + + def __aiter__(self) -> AsyncIterator[str]: + return self + + async def __anext__(self) -> str: + if self._index < len(self._messages): + msg = self._messages[self._index] + self._index += 1 + return msg + if self._block_on_empty: + await self._close_event.wait() + raise StopAsyncIteration + + +class _AlwaysFailConnect: + """Mimics websockets.connect but always raises on __aenter__.""" + + def __call__(self, *args, **kwargs) -> _AlwaysFailConnect: + return self + + async def __aenter__(self) -> None: + raise ConnectionError("connection refused") + + async def __aexit__(self, *args) -> None: + pass + + +def make_connect(ws: MockWebSocket): + """Return a mock for websockets.connect that yields the given MockWebSocket.""" + + @asynccontextmanager + async def _connect(*args, **kwargs): + yield ws + + return _connect + + +def make_client( + messages: Optional[List[str]] = None, + ws: Optional[MockWebSocket] = None, + **kwargs, +) -> tuple[DatastreamWSClient, MockWebSocket, List[DataStreamResp]]: + """Build a DatastreamWSClient backed by a MockWebSocket. + + Extra kwargs are forwarded to ClientOptions, allowing callbacks and other + options to be set per-test. 
+ """ + if ws is None: + ws = MockWebSocket(messages or []) + + received: List[DataStreamResp] = [] + + async def handler(msg: DataStreamResp) -> None: + received.append(msg) + + client = DatastreamWSClient( + ClientOptions( + url="wss://test.example.com/datastream", + api_key="test-key", + message_handler=handler, + logger=logger, + min_reconnect_delay=0.0, + max_reconnect_delay=0.0, + **kwargs, + ) + ) + return client, ws, received + + +async def wait_until(condition, timeout: float = 2.0, poll: float = 0.01) -> None: + """Poll until condition() is True or timeout is reached.""" + + async def _poll(): + while not condition(): + await asyncio.sleep(poll) + + await asyncio.wait_for(_poll(), timeout=timeout) + + +@asynccontextmanager +async def run_client(client: DatastreamWSClient): + """Context manager that ensures client.close() is always called.""" + client.start() + try: + yield client + finally: + await client.close() + + +# --------------------------------------------------------------------------- +# convert_api_url_to_websocket_url +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "api_url,expected", + [ + ("https://api.schematichq.com", "wss://datastream.schematichq.com/datastream"), + ("https://api.staging.schematichq.com", "wss://datastream.staging.schematichq.com/datastream"), + ("https://custom.example.com", "wss://custom.example.com/datastream"), + ("http://localhost:8080", "ws://localhost:8080/datastream"), + ("http://localhost", "ws://localhost/datastream"), + ], +) +def test_convert_api_url(api_url: str, expected: str) -> None: + assert convert_api_url_to_websocket_url(api_url) == expected + + +def test_convert_api_url_unsupported_scheme() -> None: + with pytest.raises(ValueError, match="Unsupported scheme"): + convert_api_url_to_websocket_url("ftp://example.com") + + +# --------------------------------------------------------------------------- +# Constructor validation +# 
# ---------------------------------------------------------------------------


def test_init_missing_url() -> None:
    """Constructing the client with an empty URL must raise ValueError."""
    async def handler(msg): pass

    with pytest.raises(ValueError, match="url is required"):
        DatastreamWSClient(ClientOptions(url="", api_key="key", message_handler=handler, logger=logger))


def test_init_missing_api_key() -> None:
    """Constructing the client with an empty API key must raise ValueError."""
    async def handler(msg): pass

    with pytest.raises(ValueError, match="api_key is required"):
        DatastreamWSClient(ClientOptions(url="wss://example.com", api_key="", message_handler=handler, logger=logger))


def test_init_missing_message_handler() -> None:
    """Constructing the client with no message handler must raise ValueError."""
    with pytest.raises(ValueError, match="message_handler is required"):
        DatastreamWSClient(
            ClientOptions(url="wss://example.com", api_key="key", message_handler=None, logger=logger)  # type: ignore[arg-type]
        )


def test_init_converts_http_url() -> None:
    """An https:// API URL is rewritten to the wss:// datastream endpoint.

    NOTE(review): the assertion also expects the host to change from
    api.schematichq.com to datastream.schematichq.com — presumably the client
    maps the API host to the datastream host; confirm against the client impl.
    """
    async def handler(msg): pass

    client = DatastreamWSClient(
        ClientOptions(url="https://api.schematichq.com", api_key="key", message_handler=handler, logger=logger)
    )
    assert client._url == "wss://datastream.schematichq.com/datastream"


# ---------------------------------------------------------------------------
# send_message
# ---------------------------------------------------------------------------


async def test_send_message_raises_when_not_connected() -> None:
    """send_message on a client that was never started raises RuntimeError."""
    client, _, _ = make_client()
    req = DataStreamBaseReq(data=DataStreamReq(entity_type=EntityType.COMPANY))

    with pytest.raises(RuntimeError, match="not available"):
        await client.send_message(req)


async def test_send_message_sends_json_when_connected() -> None:
    """Once connected, send_message serializes the request as JSON onto the socket."""
    # block_on_empty keeps the mock socket open so the connection stays alive
    # while we send.
    ws = MockWebSocket(block_on_empty=True)
    connected = asyncio.Event()
    client, ws, _ = make_client(ws=ws, on_connected=lambda: connected.set())

    with patch("schematic.datastream.websocket_client.websockets.connect", make_connect(ws)):
        async with run_client(client):
            # Wait for the on_connected callback before sending.
            await asyncio.wait_for(connected.wait(), timeout=2.0)

            req = DataStreamBaseReq(data=DataStreamReq(entity_type=EntityType.COMPANY))
            await client.send_message(req)

    # Exactly one frame, with the expected wire shape.
    assert len(ws.sent) == 1
    assert json.loads(ws.sent[0]) == {"data": {"action": "start", "entity_type": "company"}}


# ---------------------------------------------------------------------------
# Message handling
# ---------------------------------------------------------------------------


async def test_string_message_delivered_to_handler() -> None:
    """A text (str) frame is parsed and delivered to the message handler."""
    msg = json.dumps({"entity_type": "rulesengine.Company", "message_type": "full", "data": {"id": "123"}})
    client, ws, received = make_client(messages=[msg])

    with patch("schematic.datastream.websocket_client.websockets.connect", make_connect(ws)):
        async with run_client(client):
            await wait_until(lambda: len(received) == 1)

    assert received[0].entity_type == "rulesengine.Company"
    assert received[0].message_type == "full"
    assert received[0].data == {"id": "123"}


async def test_bytes_message_delivered_to_handler() -> None:
    """A binary (bytes) frame is decoded, parsed, and delivered like text."""
    payload = json.dumps({"entity_type": "rulesengine.Flag", "message_type": "full", "data": None})
    ws = MockWebSocket(messages=[payload.encode()])  # type: ignore[list-item]
    client, ws, received = make_client(ws=ws)

    with patch("schematic.datastream.websocket_client.websockets.connect", make_connect(ws)):
        async with run_client(client):
            await wait_until(lambda: len(received) == 1)

    assert received[0].entity_type == "rulesengine.Flag"


async def test_invalid_json_calls_on_error() -> None:
    """A frame that is not valid JSON is reported through on_error, not raised."""
    errors: List[Exception] = []
    client, ws, _ = make_client(messages=["not-valid-json"], on_error=errors.append)

    with patch("schematic.datastream.websocket_client.websockets.connect", make_connect(ws)):
        async with run_client(client):
            await wait_until(lambda: len(errors) > 0)

    assert "Failed to parse" in str(errors[0])


async def test_message_handler_exception_calls_on_error() -> None:
    """An exception escaping the user message handler is routed to on_error."""
    errors: List[Exception] = []

    async def bad_handler(msg: DataStreamResp) -> None:
        raise ValueError("handler blew up")

    msg = json.dumps({"entity_type": "rulesengine.Company", "message_type": "full", "data": None})
    ws = MockWebSocket(messages=[msg])
    # Built directly (not via make_client) so we control reconnect delays and
    # install the failing handler.
    client = DatastreamWSClient(
        ClientOptions(
            url="wss://test.example.com/datastream",
            api_key="key",
            message_handler=bad_handler,
            logger=logger,
            min_reconnect_delay=0.0,
            max_reconnect_delay=0.0,
            on_error=errors.append,
        )
    )

    with patch("schematic.datastream.websocket_client.websockets.connect", make_connect(ws)):
        async with run_client(client):
            await wait_until(lambda: len(errors) > 0)

    assert "Message handler error" in str(errors[0])


# ---------------------------------------------------------------------------
# State callbacks
# ---------------------------------------------------------------------------


async def test_on_connected_and_on_ready_fired() -> None:
    """on_connected and on_ready each fire exactly once for a clean connect."""
    connected_calls: List[bool] = []
    ready_calls: List[bool] = []

    ws = MockWebSocket(block_on_empty=True)
    client, ws, _ = make_client(
        ws=ws,
        on_connected=lambda: connected_calls.append(True),
        on_ready=lambda: ready_calls.append(True),
    )

    with patch("schematic.datastream.websocket_client.websockets.connect", make_connect(ws)):
        async with run_client(client):
            await wait_until(client.is_ready)

            # Checked while the client is still running: is_connected/is_ready
            # would be False after the context tears the client down.
            assert connected_calls == [True]
            assert ready_calls == [True]
            assert client.is_connected()
            assert client.is_ready()


async def test_on_disconnected_and_on_not_ready_fired_on_close() -> None:
    """Stopping the client fires on_disconnected / on_not_ready and clears state."""
    disconnected_calls: List[bool] = []
    not_ready_calls: List[bool] = []

    ws = MockWebSocket(block_on_empty=True)
    client, ws, _ = make_client(
        ws=ws,
        on_disconnected=lambda: disconnected_calls.append(True),
        on_not_ready=lambda: not_ready_calls.append(True),
    )

    with patch("schematic.datastream.websocket_client.websockets.connect", make_connect(ws)):
        async with run_client(client):
            await wait_until(client.is_connected)

    # Membership (not equality) checks: the callbacks may fire more than once
    # during shutdown — presumably intentional; confirm against client impl.
    assert True in disconnected_calls
    assert True in not_ready_calls
    assert not client.is_connected()
    assert not client.is_ready()


# ---------------------------------------------------------------------------
# connection_ready_handler
# ---------------------------------------------------------------------------


async def test_connection_ready_handler_called_before_ready() -> None:
    """The async connection_ready_handler runs before on_ready is invoked."""
    order: List[str] = []

    async def ready_handler() -> None:
        order.append("ready_handler")

    ws = MockWebSocket(block_on_empty=True)
    client, ws, _ = make_client(
        ws=ws,
        connection_ready_handler=ready_handler,
        on_ready=lambda: order.append("on_ready"),
    )

    with patch("schematic.datastream.websocket_client.websockets.connect", make_connect(ws)):
        async with run_client(client):
            await wait_until(client.is_ready)

    assert order == ["ready_handler", "on_ready"]


async def test_connection_ready_handler_failure_prevents_ready() -> None:
    """If connection_ready_handler raises, the client never becomes ready."""
    async def failing_handler() -> None:
        raise RuntimeError("setup failed")

    ws = MockWebSocket()
    client, ws, _ = make_client(
        ws=ws,
        connection_ready_handler=failing_handler,
        max_reconnect_attempts=1,
    )

    with patch("schematic.datastream.websocket_client.websockets.connect", make_connect(ws)):
        async with run_client(client):
            # Fixed sleep (no ready event will ever fire to wait on); gives the
            # failing handler time to run.
            await asyncio.sleep(0.1)

    assert not client.is_ready()


# ---------------------------------------------------------------------------
# Reconnection
# ---------------------------------------------------------------------------


async def test_reconnects_after_connection_error() -> None:
    """A failed first connect is retried; the second attempt delivers messages."""
    connect_calls: List[int] = []
    msg = json.dumps({"entity_type": "rulesengine.Company", "message_type": "full", "data": None})
    success_ws = MockWebSocket(messages=[msg])

    @asynccontextmanager
    async def mock_connect(*args, **kwargs):
        # First attempt fails; every later attempt yields the working socket.
        connect_calls.append(1)
        if len(connect_calls) == 1:
            raise ConnectionError("first attempt fails")
        yield success_ws

    received: List[DataStreamResp] = []

    async def handler(m: DataStreamResp) -> None:
        received.append(m)

    # Zero reconnect delays so the retry happens immediately in the test.
    client = DatastreamWSClient(
        ClientOptions(
            url="wss://test.example.com/datastream",
            api_key="key",
            message_handler=handler,
            logger=logger,
            min_reconnect_delay=0.0,
            max_reconnect_delay=0.0,
            max_reconnect_attempts=5,
        )
    )

    with patch("schematic.datastream.websocket_client.websockets.connect", mock_connect):
        async with run_client(client):
            await wait_until(lambda: len(received) == 1)

    assert len(connect_calls) >= 2
    assert len(received) == 1


async def test_stops_at_max_reconnect_attempts() -> None:
    """After max_reconnect_attempts consecutive failures, on_error is notified."""
    errors: List[Exception] = []

    # Sync no-op handler is fine here: no message is ever delivered because
    # every connection attempt fails.
    client = DatastreamWSClient(
        ClientOptions(
            url="wss://test.example.com/datastream",
            api_key="key",
            message_handler=lambda _: None,  # type: ignore[arg-type,return-value]
            logger=logger,
            min_reconnect_delay=0.0,
            max_reconnect_delay=0.0,
            max_reconnect_attempts=3,
            on_error=errors.append,
        )
    )

    with patch("schematic.datastream.websocket_client.websockets.connect", _AlwaysFailConnect()):
        async with run_client(client):
            await wait_until(lambda: len(errors) > 0)

    assert "Max reconnection" in str(errors[0])


# ---------------------------------------------------------------------------
# Backoff delay
# ---------------------------------------------------------------------------


def test_backoff_delay_within_bounds() -> None:
    """Backoff delay is non-negative and bounded for all early attempt counts."""
    client, _, _ = make_client()

    for attempt in range(1, 8):
        delay = client._calculate_backoff_delay(attempt)
        assert delay >= 0
        # Upper bound assumes jitter is at most _min_reconnect_delay on top of
        # the capped delay — TODO confirm against _calculate_backoff_delay.
        assert delay <= client._max_reconnect_delay + client._min_reconnect_delay


def test_backoff_delay_does_not_exceed_max() -> None:
    """A very high attempt count stays capped at max delay plus jitter ceiling."""
    async def handler(msg): pass

    client = DatastreamWSClient(
        ClientOptions(
            url="wss://test.example.com/datastream",
            api_key="key",
            message_handler=handler,
            logger=logger,
            min_reconnect_delay=1.0,
            max_reconnect_delay=5.0,
        )
    )

    # At high attempt counts, delay should be capped at max + jitter ceiling
    delay = client._calculate_backoff_delay(20)
    assert delay <= 5.0 + 1.0