From 05290f0d5e5fde8754d91cba27c52e7368da7fb4 Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Fri, 6 Feb 2026 20:08:48 +0100 Subject: [PATCH 1/8] Make filesystem dependency everywhere --- api/core/filesystem.py | 38 ++++++++---- api/crud/job.py | 30 +++++----- api/dependencies.py | 17 +++++- api/endpoints/auth.py | 10 +++- api/endpoints/files.py | 20 ++++--- api/endpoints/jobs.py | 12 +++- api/settings.py | 2 +- tests/conftest.py | 14 ++--- tests/integration/conftest.py | 94 +++++++++++++----------------- tests/unit/core/test_filesystem.py | 2 - 10 files changed, 129 insertions(+), 110 deletions(-) diff --git a/api/core/filesystem.py b/api/core/filesystem.py index d9160b4..5745a2c 100644 --- a/api/core/filesystem.py +++ b/api/core/filesystem.py @@ -5,7 +5,7 @@ import shutil import zipfile from pathlib import Path, PurePosixPath -from typing import Any, BinaryIO, Generator, cast +from typing import Any, BinaryIO, Callable, Generator import boto3 import humanize @@ -17,7 +17,7 @@ from mypy_boto3_s3 import S3Client from mypy_boto3_s3.type_defs import ObjectIdentifierTypeDef -from api import models, settings +from api import models from api.schemas.file import FileHTTPRequest, FileInfo, FileTypes @@ -387,29 +387,43 @@ def download_url( ) -def get_filesystem_with_root(root_path: str) -> FileSystem: +def get_filesystem_with_root( + root_path: str, + filesystem: str, + s3_region: str, + s3_bucket: str | None, +) -> FileSystem: """Get the filesystem to use.""" predef_dirs = [e.value for e in models.UploadFileTypes] + [ e.value for e in models.OutputEndpoints ] - if settings.filesystem == "s3": + if filesystem == "s3": + assert s3_bucket is not None, "S3 bucket must be provided for S3 filesystem" s3_client = boto3.client( "s3", - region_name=settings.s3_region, - endpoint_url=f"https://s3.{settings.s3_region}.amazonaws.com", + region_name=s3_region, + endpoint_url=f"https://s3.{s3_region}.amazonaws.com", config=Config(signature_version="v4", s3={"addressing_style": "path"}), 
) # this and config=... required to avoid DNS problems with new buckets s3_client.meta.events.unregister("before-sign.s3", fix_s3_host) - return S3Filesystem( - root_path, s3_client, cast(str, settings.s3_bucket), predef_dirs=predef_dirs - ) - elif settings.filesystem == "local": + return S3Filesystem(root_path, s3_client, s3_bucket, predef_dirs=predef_dirs) + elif filesystem == "local": return LocalFilesystem(root_path, predef_dirs=predef_dirs) else: raise ValueError("Invalid filesystem setting") -def get_user_filesystem(user_id: str) -> FileSystem: +def user_filesystem_getter( + user_data_root_path: str, + filesystem: str, + s3_region: str, + s3_bucket: str | None, +) -> Callable[[str], FileSystem]: """Get the filesystem to use for a user.""" - return get_filesystem_with_root(str(Path(settings.user_data_root_path) / user_id)) + return lambda user_id: get_filesystem_with_root( + str(Path(user_data_root_path) / user_id), + filesystem=filesystem, + s3_region=s3_region, + s3_bucket=s3_bucket, + ) diff --git a/api/crud/job.py b/api/crud/job.py index aed5a99..f83abc1 100644 --- a/api/crud/job.py +++ b/api/crud/job.py @@ -6,15 +6,15 @@ from sqlalchemy.orm import Session from api import models, settings -from api.core.filesystem import FileSystem, get_user_filesystem +from api.core.filesystem import FileSystem from api.schemas import job as schemas def enqueue_job( - job: models.Job, enqueueing_func: Callable[[schemas.QueueJob], None] + job: models.Job, + filesystem: FileSystem, + enqueueing_func: Callable[[schemas.QueueJob], None], ) -> None: - user_fs = get_user_filesystem(user_id=job.user_id) - app = job.application job_config = settings.application_config.config[app["application"]][app["version"]][ app["entrypoint"] @@ -47,14 +47,14 @@ def prepare_files(root_in: str, root_out: str, fs: FileSystem) -> dict[str, str] f"artifact/{artifact_id}" for artifact_id in job.attributes["files_down"]["artifact_ids"] ] - _validate_files(user_fs, [config_path] + data_paths + 
artifact_paths) + _validate_files(filesystem, [config_path] + data_paths + artifact_paths) roots_down = handler_config["files_down"] - files_down = prepare_files(config_path, roots_down["config_id"], user_fs) + files_down = prepare_files(config_path, roots_down["config_id"], filesystem) for data_path in data_paths: - files_down.update(prepare_files(data_path, roots_down["data_ids"], user_fs)) + files_down.update(prepare_files(data_path, roots_down["data_ids"], filesystem)) for artifact_path in artifact_paths: files_down.update( - prepare_files(artifact_path, roots_down["artifact_ids"], user_fs) + prepare_files(artifact_path, roots_down["artifact_ids"], filesystem) ) app_specs = schemas.AppSpecs( @@ -76,9 +76,9 @@ def prepare_files(root_in: str, root_out: str, fs: FileSystem) -> dict[str, str] ) paths_upload = { - "output": user_fs.full_path_uri(job.paths_out["output"]), - "log": user_fs.full_path_uri(job.paths_out["log"]), - "artifact": user_fs.full_path_uri(job.paths_out["artifact"]), + "output": filesystem.full_path_uri(job.paths_out["output"]), + "log": filesystem.full_path_uri(job.paths_out["log"]), + "artifact": filesystem.full_path_uri(job.paths_out["artifact"]), } queue_item = schemas.QueueJob( @@ -117,6 +117,7 @@ def _validate_files(filesystem: FileSystem, paths: list[str]) -> None: def create_job( db: Session, + filesystem: FileSystem, enqueueing_func: Callable[[schemas.QueueJob], None], job: schemas.JobCreate, user_id: int, @@ -146,18 +147,17 @@ def create_job( status_code=status.HTTP_400_BAD_REQUEST, detail=ve, ) - enqueue_job(db_job, enqueueing_func) + enqueue_job(db_job, filesystem, enqueueing_func) db.commit() db.refresh(db_job) return db_job -def delete_job(db: Session, db_job: models.Job) -> models.Job: +def delete_job(db: Session, filesystem: FileSystem, db_job: models.Job) -> models.Job: db.delete(db_job) - user_fs = get_user_filesystem(user_id=db_job.user_id) for path in db_job.paths_out.values(): if path[-1] != "/": path += "/" - 
user_fs.delete(path) + filesystem.delete(path) db.commit() return db_job diff --git a/api/dependencies.py b/api/dependencies.py index ec3c017..6de255c 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -9,7 +9,7 @@ from api import settings from api.core import notifications -from api.core.filesystem import FileSystem, get_user_filesystem +from api.core.filesystem import FileSystem, user_filesystem_getter from api.schemas.job import QueueJob @@ -52,11 +52,22 @@ async def current_user_global_dep( return current_user -async def filesystem_dep( +async def filesystem_getter_dep() -> Callable[[str], FileSystem]: + """Get the user's filesystem getter.""" + return user_filesystem_getter( + user_data_root_path=settings.user_data_root_path, + filesystem=settings.filesystem, + s3_region=settings.s3_region, + s3_bucket=settings.s3_bucket, + ) + + +async def user_filesystem_dep( + filesystem_getter: Callable[[str], FileSystem] = Depends(filesystem_getter_dep), current_user: CognitoClaims = Depends(current_user_dep), ) -> FileSystem: """Get the user's filesystem.""" - return get_user_filesystem(current_user.username) + return filesystem_getter(current_user.username) class APIKeyDependency: diff --git a/api/endpoints/auth.py b/api/endpoints/auth.py index 883de5c..d6addd7 100644 --- a/api/endpoints/auth.py +++ b/api/endpoints/auth.py @@ -4,13 +4,15 @@ as the authentication is handled by the Cognito service. 
""" +from typing import Callable + import boto3 from fastapi import APIRouter, Depends, HTTPException, status from fastapi.responses import JSONResponse from fastapi.security import OAuth2PasswordRequestForm from api.core.aws import calculate_secret_hash -from api.core.filesystem import get_user_filesystem +from api.core.filesystem import FileSystem from api.schemas.token import TokenResponse from api.schemas.user import User, UserGroups from api.settings import cognito_client_id, cognito_secret, cognito_user_pool_id @@ -25,7 +27,9 @@ description="Register a new user", ) def register_user( - user: OAuth2PasswordRequestForm = Depends(), groups: list[UserGroups] | None = None + user: OAuth2PasswordRequestForm = Depends(), + filesystem_getter_dep: Callable[[str], FileSystem] = Depends(), + groups: list[UserGroups] | None = None, ) -> User: client = boto3.client("cognito-idp") try: @@ -52,7 +56,7 @@ def register_user( Password=user.password, Permanent=True, ) - filesystem = get_user_filesystem(response["User"]["Username"]) + filesystem = filesystem_getter_dep(response["User"]["Username"]) filesystem.init() except client.exceptions.ClientError as e: if e.response["Error"]["Code"] == "UsernameExistsException": diff --git a/api/endpoints/files.py b/api/endpoints/files.py index 0194102..e95e433 100644 --- a/api/endpoints/files.py +++ b/api/endpoints/files.py @@ -6,7 +6,7 @@ from api import models from api.core.filesystem import FileSystem -from api.dependencies import filesystem_dep +from api.dependencies import user_filesystem_dep from api.schemas import file as file_schemas router = APIRouter() @@ -19,7 +19,7 @@ description="Download a file", ) def download_file( - file_path: str, filesystem: FileSystem = Depends(filesystem_dep) + file_path: str, filesystem: FileSystem = Depends(user_filesystem_dep) ) -> FileResponse | StreamingResponse: try: return filesystem.download(file_path) @@ -33,7 +33,9 @@ def download_file( description="Get request parameters (pre-signed URL) to 
download a file", ) def get_download_presigned_url( - file_path: str, request: Request, filesystem: FileSystem = Depends(filesystem_dep) + file_path: str, + request: Request, + filesystem: FileSystem = Depends(user_filesystem_dep), ) -> file_schemas.FileHTTPRequest: try: return filesystem.download_url( @@ -52,7 +54,7 @@ def list_files( base_path: str = "", show_dirs: bool = True, recursive: bool = False, - filesystem: FileSystem = Depends(filesystem_dep), + filesystem: FileSystem = Depends(user_filesystem_dep), ) -> list[file_schemas.FileInfo]: try: return sorted( @@ -73,7 +75,7 @@ def upload_file( f_type: models.UploadFileTypes, base_path: str, file: UploadFile, - filesystem: FileSystem = Depends(filesystem_dep), + filesystem: FileSystem = Depends(user_filesystem_dep), ) -> file_schemas.FileInfo: base_path = f"{f_type.value}/" + base_path file_path = os.path.join(base_path, file.filename or "unnamed") @@ -91,7 +93,7 @@ def get_upload_presigned_url( f_type: models.UploadFileTypes, base_path: str, request: Request, - filesystem: FileSystem = Depends(filesystem_dep), + filesystem: FileSystem = Depends(user_filesystem_dep), ) -> file_schemas.FileHTTPRequest: base_path = f"{f_type.value}/" + base_path return filesystem.create_file_url( @@ -107,7 +109,7 @@ def get_upload_presigned_url( def create_directory( f_type: models.UploadFileTypes, base_path: str, - filesystem: FileSystem = Depends(filesystem_dep), + filesystem: FileSystem = Depends(user_filesystem_dep), ) -> None: return filesystem.create_directory(f"{f_type.value}/{base_path}/") @@ -120,7 +122,7 @@ def create_directory( def rename_file( file_path: str, file: file_schemas.FileUpdate, - filesystem: FileSystem = Depends(filesystem_dep), + filesystem: FileSystem = Depends(user_filesystem_dep), ) -> file_schemas.FileInfo: try: filesystem.rename(file_path, file.path) @@ -141,6 +143,6 @@ def rename_file( description="Delete a file or directory", ) def delete_file( - file_path: str, filesystem: FileSystem = 
Depends(filesystem_dep) + file_path: str, filesystem: FileSystem = Depends(user_filesystem_dep) ) -> None: filesystem.delete(file_path) diff --git a/api/endpoints/jobs.py b/api/endpoints/jobs.py index 1d047cd..28e6299 100644 --- a/api/endpoints/jobs.py +++ b/api/endpoints/jobs.py @@ -4,8 +4,9 @@ from sqlalchemy.orm import Session import api.database as database +from api.core.filesystem import FileSystem from api.crud import job as crud -from api.dependencies import enqueueing_function_dep +from api.dependencies import enqueueing_function_dep, user_filesystem_dep from api.schemas.job import Job, JobCreate, QueueJob from api.settings import application_config @@ -60,11 +61,13 @@ def start_job( request: Request, job: JobCreate, db: Session = Depends(database.get_db), + filesystem: FileSystem = Depends(user_filesystem_dep), enqueueing_func: Callable[[QueueJob], None] = Depends(enqueueing_function_dep), ) -> Job: try: return crud.create_job( db, + filesystem, enqueueing_func, job, user_id=request.state.current_user.username, @@ -81,11 +84,14 @@ def start_job( description="Delete a job", ) def delete_job( - request: Request, job_id: int, db: Session = Depends(database.get_db) + request: Request, + job_id: int, + db: Session = Depends(database.get_db), + filesystem: FileSystem = Depends(user_filesystem_dep), ) -> None: db_job = crud.get_job(db, job_id) if db_job is None or db_job.user_id != request.state.current_user.username: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Job not found" ) - crud.delete_job(db, db_job) + crud.delete_job(db, filesystem, db_job) diff --git a/api/settings.py b/api/settings.py index ad57b3d..88d99c2 100644 --- a/api/settings.py +++ b/api/settings.py @@ -27,7 +27,7 @@ def _load_possibly_aws_secret(name: str) -> str | None: if os.environ.get("DATABASE_SECRET"): # set and not None database_secret = _load_possibly_aws_secret("DATABASE_SECRET") database_url = database_url.format(database_secret) -filesystem = 
os.environ.get("FILESYSTEM") +filesystem = os.environ.get("FILESYSTEM", "local") s3_bucket = os.environ.get("S3_BUCKET") s3_region = os.environ.get("S3_REGION", "eu-central-1") user_data_root_path = os.environ.get("USER_DATA_ROOT_PATH", "/data") diff --git a/tests/conftest.py b/tests/conftest.py index ffa8f67..6b86382 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import datetime import secrets import time -from typing import Any, Generator +from typing import Any from unittest.mock import MagicMock import boto3 @@ -18,16 +18,10 @@ REGION_NAME: BucketLocationConstraintType = "eu-central-1" -@pytest.fixture(scope="session") -def monkeypatch_module() -> Generator[pytest.MonkeyPatch, Any, None]: - with pytest.MonkeyPatch.context() as mp: - yield mp - - -@pytest.fixture(autouse=True, scope="function") -def enqueueing_func(monkeypatch_module: pytest.MonkeyPatch) -> MagicMock: +@pytest.fixture(autouse=True) +def enqueueing_func(monkeypatch: pytest.MonkeyPatch) -> MagicMock: mock_enqueueing_function = MagicMock() - monkeypatch_module.setitem( + monkeypatch.setitem( app.dependency_overrides, enqueueing_function_dep, # type: ignore lambda: mock_enqueueing_function, diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 8e935fe..21f01f6 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,7 +1,7 @@ import os import shutil from io import BytesIO -from typing import Any, Generator, cast +from typing import Any, Callable, Generator, cast from unittest.mock import Mock import pytest @@ -16,19 +16,19 @@ FileSystem, LocalFilesystem, S3Filesystem, - get_user_filesystem, + user_filesystem_getter, ) from api.database import Base, get_db from api.dependencies import ( APIKeyDependency, current_user_dep, email_sender_dep, - filesystem_dep, + filesystem_getter_dep, workerfacing_api_auth_dep, ) from api.main import app from api.models import Job -from tests.conftest import REGION_NAME, RDSTestingInstance, 
S3TestingBucket +from tests.conftest import RDSTestingInstance, S3TestingBucket @pytest.fixture(scope="session") @@ -104,40 +104,16 @@ def db_session( def base_filesystem( env: str, base_user_dir: str, - monkeypatch_module: pytest.MonkeyPatch, + monkeypatch: pytest.MonkeyPatch, s3_testing_bucket: S3TestingBucket, ) -> Generator[FileSystem, Any, None]: if env == "local": base_user_dir = f"./{base_user_dir}" - - monkeypatch_module.setattr( - settings, - "user_data_root_path", - base_user_dir, - ) - monkeypatch_module.setattr( - settings, - "s3_region", - REGION_NAME, - ) - monkeypatch_module.setattr( - settings, - "filesystem", - "local" if env == "local" else "s3", - ) - - if env == "local": shutil.rmtree(base_user_dir, ignore_errors=True) yield LocalFilesystem(base_user_dir) shutil.rmtree(base_user_dir, ignore_errors=True) elif env == "aws": - # Update settings to use the actual unique bucket name created by S3TestingBucket - monkeypatch_module.setattr( - settings, - "s3_bucket", - s3_testing_bucket.bucket_name, - ) yield S3Filesystem( base_user_dir, s3_testing_bucket.s3_client, s3_testing_bucket.bucket_name ) @@ -148,15 +124,29 @@ def base_filesystem( @pytest.fixture -def user_filesystem(base_filesystem: FileSystem, username: str) -> FileSystem: - return get_user_filesystem(username) +def filesystem_getter( + base_filesystem: FileSystem, + base_user_dir: str, + s3_testing_bucket: S3TestingBucket, +) -> Callable[[str], FileSystem]: + return user_filesystem_getter( + user_data_root_path=base_user_dir, + filesystem="s3" if isinstance(base_filesystem, S3Filesystem) else "local", + s3_region=s3_testing_bucket.region_name, + s3_bucket=s3_testing_bucket.bucket_name, + ) + + +@pytest.fixture +def user_filesystem( + filesystem_getter: Callable[[str], FileSystem], username: str +) -> FileSystem: + return filesystem_getter(username) @pytest.fixture(autouse=True) -def override_db_dep( - db_session: Session, monkeypatch_module: pytest.MonkeyPatch -) -> None: - 
monkeypatch_module.setitem( +def override_db_dep(db_session: Session, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem( app.dependency_overrides, # type: ignore get_db, lambda: db_session, @@ -164,50 +154,50 @@ def override_db_dep( @pytest.fixture(autouse=True) -def override_filesystem_dep( - user_filesystem: FileSystem, monkeypatch_module: pytest.MonkeyPatch +def override_user_filesystem_getter( + filesystem_getter: Callable[[str], FileSystem], monkeypatch: pytest.MonkeyPatch ) -> None: - monkeypatch_module.setitem( + monkeypatch.setitem( app.dependency_overrides, # type: ignore - filesystem_dep, - lambda: user_filesystem, + filesystem_getter_dep, + lambda: filesystem_getter, ) -@pytest.fixture(autouse=True, scope="session") +@pytest.fixture(autouse=True) def override_auth( - monkeypatch_module: pytest.MonkeyPatch, username: str, user_email: str + monkeypatch: pytest.MonkeyPatch, username: str, user_email: str ) -> None: - monkeypatch_module.setitem( + monkeypatch.setitem( app.dependency_overrides, # type: ignore current_user_dep, lambda: CognitoClaims(**{"cognito:username": username, "email": user_email}), ) -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(autouse=True) def override_internal_api_key_secret( - monkeypatch_module: pytest.MonkeyPatch, internal_api_key_secret: str + monkeypatch: pytest.MonkeyPatch, internal_api_key_secret: str ) -> None: - monkeypatch_module.setitem( + monkeypatch.setitem( app.dependency_overrides, # type: ignore workerfacing_api_auth_dep, APIKeyDependency(internal_api_key_secret), ) -@pytest.fixture(scope="session", autouse=True) -def override_email_sender(monkeypatch_module: pytest.MonkeyPatch) -> None: - monkeypatch_module.setitem( +@pytest.fixture(autouse=True) +def override_email_sender(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem( app.dependency_overrides, # type: ignore email_sender_dep, lambda: notifications.DummyEmailSender(), ) -@pytest.fixture(scope="session", autouse=True) 
+@pytest.fixture(autouse=True) def override_application_config( - monkeypatch_module: pytest.MonkeyPatch, application: dict[str, str] + monkeypatch: pytest.MonkeyPatch, application: dict[str, str] ) -> None: application_config = Mock() application_config.config = { @@ -236,7 +226,7 @@ def override_application_config( }, }, } - monkeypatch_module.setattr(settings, "application_config", application_config) + monkeypatch.setattr(settings, "application_config", application_config) @pytest.fixture diff --git a/tests/unit/core/test_filesystem.py b/tests/unit/core/test_filesystem.py index 9bb81fd..00f63ec 100644 --- a/tests/unit/core/test_filesystem.py +++ b/tests/unit/core/test_filesystem.py @@ -57,10 +57,8 @@ def data_file1( def test_list_directory_file( self, filesystem: FileSystem, - monkeypatch: pytest.MonkeyPatch, data_file1_name: str, ) -> None: - monkeypatch.setattr(filesystem, "isdir", lambda path: False) with pytest.raises(NotADirectoryError): filesystem.list_directory(data_file1_name) From deb41ea94d048eb3ca9934f081a594b7a4dae535 Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Sun, 1 Mar 2026 23:09:06 +0100 Subject: [PATCH 2/8] Backup local DB to S3 --- .vscode/settings.json | 3 + api/core/auth.py | 33 ++++++ api/core/database.py | 121 +++++++++++++++++++ api/core/filesystem.py | 22 +--- api/database.py | 26 ----- api/dependencies.py | 79 +++++++------ api/endpoints/job_update.py | 5 +- api/endpoints/jobs.py | 24 ++-- api/main.py | 43 +++++-- api/models.py | 8 +- api/schemas/job.py | 90 ++++++-------- api/settings.py | 29 ++--- tests/conftest.py | 48 ++++++-- tests/integration/conftest.py | 149 +++++++++++------------- tests/integration/endpoints/conftest.py | 13 +++ tests/integration/test_main.py | 105 +++++++++++++++++ tests/unit/core/test_filesystem.py | 10 +- 17 files changed, 544 insertions(+), 264 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 api/core/auth.py create mode 100644 api/core/database.py delete mode 100644 
api/database.py create mode 100644 tests/integration/endpoints/conftest.py create mode 100644 tests/integration/test_main.py diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..c9ebf2d --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python-envs.defaultEnvManager": "ms-python.python:system" +} \ No newline at end of file diff --git a/api/core/auth.py b/api/core/auth.py new file mode 100644 index 0000000..9d11104 --- /dev/null +++ b/api/core/auth.py @@ -0,0 +1,33 @@ +from typing import Any + +from fastapi import Header, HTTPException +from fastapi.security import HTTPAuthorizationCredentials +from fastapi_cloudauth.cognito import CognitoClaims, CognitoCurrentUser # type: ignore +from pydantic import Field + + +# https://github.com/iwpnd/fastapi-key-auth/blob/main/fastapi_key_auth/dependency/authorizer.py +class APIKeyDependency: + def __init__(self, key: str | None): + self.key = key + + def __call__(self, x_api_key: str | None = Header(...)) -> str | None: + if x_api_key != self.key: + raise HTTPException(status_code=401, detail="unauthorized") + return x_api_key + + +class GroupClaims(CognitoClaims): # type: ignore + cognito_groups: list[str] | None = Field(alias="cognito:groups") + + +class UserGroupCognitoCurrentUser(CognitoCurrentUser): # type: ignore + user_info = GroupClaims + + async def call(self, http_auth: HTTPAuthorizationCredentials) -> Any: + user_info = await super().call(http_auth) + if "users" not in (getattr(user_info, "cognito_groups") or []): + raise HTTPException( + status_code=403, detail="Not a member of the 'users' group" + ) + return user_info diff --git a/api/core/database.py b/api/core/database.py new file mode 100644 index 0000000..57904fa --- /dev/null +++ b/api/core/database.py @@ -0,0 +1,121 @@ +import gzip +import os +import sqlite3 +import tempfile +import time +from typing import Any, cast + +from botocore.exceptions import ClientError +from mypy_boto3_s3 import S3Client 
+from sqlalchemy import Engine, create_engine + +from api.models import Base + + +class Database: + """Database wrapper.""" + + def __init__( + self, + db_url: str, + connect_kwargs: dict[str, Any] | None = None, + ): + self.db_url = db_url + self.connect_kwargs = connect_kwargs or {} + + @property + def engine(self) -> Engine: + if hasattr(self, "_engine"): + return cast(Engine, self._engine) # type: ignore[has-type] + retries = 0 + while True: + try: + engine = create_engine(self.db_url, connect_args=self.connect_kwargs) + # Attempt to create a connection or perform any necessary operations + engine.connect() + self._engine = engine + return engine # Connection successful + except Exception as e: + if retries >= 10: + raise RuntimeError(f"Could not create engine: {str(e)}") + retries += 1 + time.sleep(60) + + def create(self) -> None: + """Create database tables.""" + Base.metadata.create_all(bind=self.engine) + + def backup(self) -> bool: + """Backup the database. To be implemented by subclasses if supported.""" + return False + + +class SqliteDatabase(Database): + """SQLite database wrapper with optional S3 backup support.""" + + BACKUP_KEY = "userapi_sqlite_backup/backup.db.gz" + + def __init__( + self, + db_url: str, + s3_client: S3Client | None = None, + s3_bucket: str | None = None, + ): + if not db_url.startswith("sqlite:///"): + raise ValueError(f"SQLiteRDSJobQueue requires SQLite DB URL, got: {db_url}") + if not ((s3_client is None) == (s3_bucket is None)): + raise ValueError( + "Both s3_client and s3_bucket must be provided for S3 backup/restore, or both must be None." 
+ ) + self.s3_client = s3_client + self.s3_bucket = s3_bucket + super().__init__(db_url, connect_kwargs={"check_same_thread": False}) + + def create(self) -> None: + self._restore_database() + super().create() + + @property + def db_path(self) -> str: + return self.db_url[len("sqlite:///") :] + + def backup(self) -> bool: + """Backup the SQLite database to S3.""" + if not self.s3_bucket or not self.s3_client: + return False + + with tempfile.TemporaryDirectory() as temp_dir: + tmp_backup_path = os.path.join(temp_dir, "backup.db") + tmp_gzip_path = os.path.join(temp_dir, "backup.db.gz") + with sqlite3.connect(self.db_path) as source_conn: + with sqlite3.connect(tmp_backup_path) as backup_conn: + source_conn.backup(backup_conn) + + with open(tmp_backup_path, "rb") as f_in: + with gzip.open(tmp_gzip_path, "wb") as f_out: + f_out.writelines(f_in) + self.s3_client.upload_file(tmp_gzip_path, self.s3_bucket, self.BACKUP_KEY) + return True + + def _restore_database(self) -> bool: + """Restore the SQLite database from S3.""" + if not self.s3_bucket or not self.s3_client: + return False + + try: + self.s3_client.head_object(Bucket=self.s3_bucket, Key=self.BACKUP_KEY) + except ClientError as e: + if e.response["Error"]["Code"] == "404": + return False + raise + + with tempfile.TemporaryDirectory() as temp_dir: + tmp_gzip_path = os.path.join(temp_dir, "backup.db.gz") + tmp_backup_path = os.path.join(temp_dir, "backup.db") + self.s3_client.download_file(self.s3_bucket, self.BACKUP_KEY, tmp_gzip_path) + with gzip.open(tmp_gzip_path, "rb") as f_in: + with open(tmp_backup_path, "wb") as f_out: + f_out.write(f_in.read()) + os.makedirs(os.path.dirname(self.db_path), exist_ok=True) + os.rename(tmp_backup_path, self.db_path) + return True diff --git a/api/core/filesystem.py b/api/core/filesystem.py index 5745a2c..5aded9f 100644 --- a/api/core/filesystem.py +++ b/api/core/filesystem.py @@ -7,11 +7,8 @@ from pathlib import Path, PurePosixPath from typing import Any, BinaryIO, Callable, 
Generator -import boto3 import humanize -from botocore.client import Config from botocore.response import StreamingBody -from botocore.utils import fix_s3_host from fastapi import Request from fastapi.responses import FileResponse, StreamingResponse from mypy_boto3_s3 import S3Client @@ -390,23 +387,16 @@ def download_url( def get_filesystem_with_root( root_path: str, filesystem: str, - s3_region: str, - s3_bucket: str | None, + s3_bucket: str | None = None, + s3_client: S3Client | None = None, ) -> FileSystem: """Get the filesystem to use.""" predef_dirs = [e.value for e in models.UploadFileTypes] + [ e.value for e in models.OutputEndpoints ] if filesystem == "s3": + assert s3_client is not None, "S3 client must be provided for S3 filesystem" assert s3_bucket is not None, "S3 bucket must be provided for S3 filesystem" - s3_client = boto3.client( - "s3", - region_name=s3_region, - endpoint_url=f"https://s3.{s3_region}.amazonaws.com", - config=Config(signature_version="v4", s3={"addressing_style": "path"}), - ) - # this and config=... 
required to avoid DNS problems with new buckets - s3_client.meta.events.unregister("before-sign.s3", fix_s3_host) return S3Filesystem(root_path, s3_client, s3_bucket, predef_dirs=predef_dirs) elif filesystem == "local": return LocalFilesystem(root_path, predef_dirs=predef_dirs) @@ -417,13 +407,13 @@ def get_filesystem_with_root( def user_filesystem_getter( user_data_root_path: str, filesystem: str, - s3_region: str, - s3_bucket: str | None, + s3_bucket: str | None = None, + s3_client: S3Client | None = None, ) -> Callable[[str], FileSystem]: """Get the filesystem to use for a user.""" return lambda user_id: get_filesystem_with_root( str(Path(user_data_root_path) / user_id), filesystem=filesystem, - s3_region=s3_region, s3_bucket=s3_bucket, + s3_client=s3_client, ) diff --git a/api/database.py b/api/database.py deleted file mode 100644 index d2a955b..0000000 --- a/api/database.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Any, Generator - -from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Session, sessionmaker - -from api.settings import database_url - -engine = create_engine( - database_url, - connect_args=( - {"check_same_thread": False} if database_url.startswith("sqlite") else {} - ), -) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - -Base = declarative_base() - - -def get_db() -> Generator[Session, Any, None]: - """Get database session.""" - db = SessionLocal() - try: - yield db - finally: - db.close() diff --git a/api/dependencies.py b/api/dependencies.py index 6de255c..55636d7 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -1,42 +1,60 @@ -from typing import Any, Callable +from typing import Any, Callable, Generator +import boto3 import requests -from fastapi import Depends, Header, HTTPException, Request +from botocore.config import Config +from botocore.utils import fix_s3_host +from fastapi import Depends, HTTPException, Request 
from fastapi.encoders import jsonable_encoder -from fastapi.security import HTTPAuthorizationCredentials -from fastapi_cloudauth.cognito import CognitoClaims, CognitoCurrentUser # type: ignore -from pydantic import BaseModel, Field +from fastapi_cloudauth.cognito import CognitoClaims # type: ignore +from sqlalchemy.orm import Session, sessionmaker from api import settings from api.core import notifications +from api.core.auth import APIKeyDependency, UserGroupCognitoCurrentUser +from api.core.database import Database, SqliteDatabase from api.core.filesystem import FileSystem, user_filesystem_getter from api.schemas.job import QueueJob +# S3 client setup +s3_client = None +if settings.s3_bucket: + s3_client = boto3.client( + "s3", + region_name=settings.s3_region, + endpoint_url=f"https://s3.{settings.s3_region}.amazonaws.com", + config=Config(signature_version="v4", s3={"addressing_style": "path"}), + ) + # this and config=... required to avoid DNS problems with new buckets + s3_client.meta.events.unregister("before-sign.s3", fix_s3_host) -class GroupClaims(CognitoClaims): # type: ignore - """CognitoClaims with added groups claim.""" - cognito_groups: list[str] | None = Field(alias="cognito:groups") +# Database +if settings.database_url.startswith("sqlite"): + db: Database = SqliteDatabase( + db_url=settings.database_url, + s3_client=s3_client, + s3_bucket=settings.s3_bucket, + ) +else: + db = Database(db_url=settings.database_url) -class UserGroupCognitoCurrentUser(CognitoCurrentUser): # type: ignore - """ - Check membership in the 'users' group and add group membership information. 
- """ +async def db_dep() -> Database: + return db - user_info = GroupClaims - async def call( - self, http_auth: HTTPAuthorizationCredentials - ) -> BaseModel | dict[str, Any] | None: - user_info = await super().call(http_auth) - if "users" not in (getattr(user_info, "cognito_groups") or []): - raise HTTPException( - status_code=403, detail="Not a member of the 'users' group" - ) - return user_info # type: ignore +def session_dep(db_dep: Database = Depends(db_dep)) -> Generator[Session, Any, None]: + """Get database session.""" + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=db_dep.engine) + db = SessionLocal() + try: + yield db + finally: + db.close() +# User authentication current_user_dep = UserGroupCognitoCurrentUser( region=settings.cognito_region, userPoolId=settings.cognito_user_pool_id, @@ -52,12 +70,13 @@ async def current_user_global_dep( return current_user +# Filesystem async def filesystem_getter_dep() -> Callable[[str], FileSystem]: """Get the user's filesystem getter.""" return user_filesystem_getter( user_data_root_path=settings.user_data_root_path, filesystem=settings.filesystem, - s3_region=settings.s3_region, + s3_client=s3_client, s3_bucket=settings.s3_bucket, ) @@ -70,20 +89,11 @@ async def user_filesystem_dep( return filesystem_getter(current_user.username) -class APIKeyDependency: - def __init__(self, key: str | None): - """Check API-internal key.""" - self.key = key - - def __call__(self, x_api_key: str | None = Header(...)) -> str | None: - if x_api_key != self.key: - raise HTTPException(status_code=401, detail="unauthorized") - return x_api_key - - +# App-internal authentication (i.e. 
user-facing API <-> worker-facing API) workerfacing_api_auth_dep = APIKeyDependency(settings.internal_api_key_secret) +# Notifications async def email_sender_dep() -> notifications.EmailSender: """Get the email sender.""" service = settings.email_sender_service @@ -110,6 +120,7 @@ async def email_sender_dep() -> notifications.EmailSender: ) +# Job enqueueing to worker-facing API async def enqueueing_function_dep() -> Callable[[QueueJob], None]: def enqueue(queue_item: QueueJob) -> None: resp = requests.post( diff --git a/api/endpoints/job_update.py b/api/endpoints/job_update.py index ef5034b..7bc2a7d 100644 --- a/api/endpoints/job_update.py +++ b/api/endpoints/job_update.py @@ -5,8 +5,7 @@ import api.core.notifications as notifications import api.crud.job as job_crud -from api.database import get_db -from api.dependencies import email_sender_dep, workerfacing_api_auth_dep +from api.dependencies import email_sender_dep, session_dep, workerfacing_api_auth_dep from api.models import JobStates from api.schemas.job_update import JobUpdate @@ -20,7 +19,7 @@ ) def update_job_status( update: JobUpdate, - db: Session = Depends(get_db), + db: Session = Depends(session_dep), email_sender: notifications.EmailSender = Depends(email_sender_dep), ) -> JobStates: db_job = job_crud.get_job(db, update.job_id) diff --git a/api/endpoints/jobs.py b/api/endpoints/jobs.py index 28e6299..f5abd76 100644 --- a/api/endpoints/jobs.py +++ b/api/endpoints/jobs.py @@ -3,10 +3,9 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status from sqlalchemy.orm import Session -import api.database as database from api.core.filesystem import FileSystem from api.crud import job as crud -from api.dependencies import enqueueing_function_dep, user_filesystem_dep +from api.dependencies import enqueueing_function_dep, session_dep, user_filesystem_dep from api.schemas.job import Job, JobCreate, QueueJob from api.settings import application_config @@ -30,25 +29,23 @@ def list_jobs( request: 
Request, offset: int = 0, limit: int = 100, - db: Session = Depends(database.get_db), + db: Session = Depends(session_dep), ) -> list[Job]: - return sorted( - crud.get_jobs(db, request.state.current_user.username, offset, limit), - key=lambda x: x.date_created, - reverse=True, - ) + db_jobs = crud.get_jobs(db, request.state.current_user.username, offset, limit) + jobs = [Job.model_validate(db_job) for db_job in db_jobs] # models -> schemas + return sorted(jobs, key=lambda x: x.date_created, reverse=True) @router.get("/jobs/{job_id}", response_model=Job, description="Describe a job") def describe_job( - request: Request, job_id: int, db: Session = Depends(database.get_db) + request: Request, job_id: int, db: Session = Depends(session_dep) ) -> Job: db_job = crud.get_job(db, job_id) if db_job is None or db_job.user_id != request.state.current_user.username: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Job not found" ) - return db_job + return Job.model_validate(db_job) @router.post( @@ -60,12 +57,12 @@ def describe_job( def start_job( request: Request, job: JobCreate, - db: Session = Depends(database.get_db), + db: Session = Depends(session_dep), filesystem: FileSystem = Depends(user_filesystem_dep), enqueueing_func: Callable[[QueueJob], None] = Depends(enqueueing_function_dep), ) -> Job: try: - return crud.create_job( + db_job = crud.create_job( db, filesystem, enqueueing_func, @@ -73,6 +70,7 @@ def start_job( user_id=request.state.current_user.username, user_email=request.state.current_user.email, ) + return Job.model_validate(db_job) except FileNotFoundError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) @@ -86,7 +84,7 @@ def start_job( def delete_job( request: Request, job_id: int, - db: Session = Depends(database.get_db), + db: Session = Depends(session_dep), filesystem: FileSystem = Depends(user_filesystem_dep), ) -> None: db_job = crud.get_job(db, job_id) diff --git a/api/main.py b/api/main.py index 
756377d..9875641 100644 --- a/api/main.py +++ b/api/main.py @@ -1,3 +1,8 @@ +import asyncio +import logging +from contextlib import asynccontextmanager +from typing import AsyncGenerator + import dotenv dotenv.load_dotenv() @@ -6,11 +11,40 @@ from fastapi.middleware.cors import CORSMiddleware from api import dependencies, settings, tags -from api.database import Base, engine +from api.core.database import Database from api.endpoints import auth, auth_get, files, job_update, jobs from api.exceptions import register_exception_handlers -app = FastAPI(openapi_tags=tags.tags_metadata) +logger = logging.getLogger(__name__) + + +async def cron_backup_database(db: Database) -> None: + while True: + logger.info("Database backup: starting...") + # Run backup in thread pool to avoid blocking event loop; + # Fine instead of making backup async since it runs infrequently. + try: + if await asyncio.to_thread(db.backup): + logger.info("Backed up database.") + except Exception as e: + logger.error(f"Database backup failed with {e}") + await asyncio.sleep(settings.cron_backup_interval) + + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + db = app.dependency_overrides.get(dependencies.db_dep, dependencies.db_dep)() + assert isinstance(db, Database) + db.create() + task_backup = asyncio.create_task(cron_backup_database(db)) + yield + task_backup.cancel() + await asyncio.gather(task_backup, return_exceptions=True) + if db.backup(): + logger.info("Created final backup on shutdown.") + + +app = FastAPI(openapi_tags=tags.tags_metadata, lifespan=lifespan) if settings.frontend_url: app.add_middleware( CORSMiddleware, @@ -43,8 +77,3 @@ @app.get("/") async def root() -> str: return "Welcome to the DECODE OpenCloud User-facing API" - - -@app.on_event("startup") -def on_startup() -> None: - Base.metadata.create_all(bind=engine) diff --git a/api/models.py b/api/models.py index f530cac..284bd89 100644 --- a/api/models.py +++ b/api/models.py @@ -10,9 +10,11 
@@ Text, UniqueConstraint, ) -from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import DeclarativeBase, mapped_column -from api.database import Base + +class Base(DeclarativeBase): + pass class JobStates(enum.Enum): @@ -43,7 +45,7 @@ class UploadFileTypes(enum.Enum): artifact = "artifact" -class Job(Base): # type: ignore +class Job(Base): __tablename__ = "jobs" id = mapped_column(Integer, primary_key=True, index=True) diff --git a/api/schemas/job.py b/api/schemas/job.py index bec6a24..3d9f2dc 100644 --- a/api/schemas/job.py +++ b/api/schemas/job.py @@ -1,7 +1,7 @@ import datetime -from typing import Any +from typing import Self -from pydantic import BaseModel, validator +from pydantic import BaseModel, Field, model_validator from api import settings from api.models import EnvironmentTypes, JobStates, OutputEndpoints @@ -20,32 +20,28 @@ class Application(BaseModel): version: str entrypoint: str - @validator("application") - def application_check(cls: "Application", v: str, values: dict[str, str]) -> str: - allowed = list(settings.application_config.config.keys()) - if v not in allowed: - raise ValueError(f"Application must be one of {allowed}, not {v}.") - return v - - @validator("version") - def version_check(cls: "Application", v: str, values: dict[str, str]) -> str: - if "application" not in values: - raise ValueError("Application must be set before version.") - allowed = settings.application_config.config[values["application"]].keys() - if v not in allowed: - raise ValueError(f"Version must be one of {allowed}, not {v}.") - return v - - @validator("entrypoint") - def entrypoint_check(cls: "Application", v: str, values: dict[str, str]) -> str: - if "application" not in values or "version" not in values: - raise ValueError("Application and version must be set before entrypoint.") - allowed = settings.application_config.config[values["application"]][ - values["version"] - ].keys() - if v not in allowed: - raise ValueError(f"Entrypoint must be one of 
{allowed}, not {v}.") - return v + @model_validator(mode="after") + def application_check(self) -> Self: + allowed_apps = list(settings.application_config.config.keys()) + if self.application not in allowed_apps: + raise ValueError( + f"Application must be one of {allowed_apps}, not {self.application}." + ) + allowed_versions = list( + settings.application_config.config[self.application].keys() + ) + if self.version not in allowed_versions: + raise ValueError( + f"Version must be one of {allowed_versions}, not {self.version}." + ) + allowed_entrypoints = list( + settings.application_config.config[self.application][self.version].keys() + ) + if self.entrypoint not in allowed_entrypoints: + raise ValueError( + f"Entrypoint must be one of {allowed_entrypoints}, not {self.entrypoint}." + ) + return self class InputJobAttributes(BaseModel): @@ -62,36 +58,22 @@ class JobAttributes(BaseModel): class JobBase(BaseModel): job_name: str environment: EnvironmentTypes | None = None - priority: int | None = None + priority: int = Field(0, ge=0, le=5) application: Application attributes: JobAttributes hardware: HardwareSpecs | None = None - @validator("attributes") - def env_check( - cls: "JobBase", v: JobAttributes, values: dict[str, Any] - ) -> JobAttributes: - app = values.get("application") - if not app: - raise ValueError("Application must be set before attributes.") - application = ( - app.application if hasattr(app, "application") else app["application"] - ) - version = app.version if hasattr(app, "version") else app["version"] - entrypoint = app.entrypoint if hasattr(app, "entrypoint") else app["entrypoint"] - config = settings.application_config.config[application][version][entrypoint] - allowed = config["app"]["env"] - if v.env_vars is not None and not all(v_ in allowed for v_ in v.env_vars): - raise ValueError(f"Environment variables must be in {allowed}.") - return v - - @validator("priority") - def priority_check(cls: "JobBase", v: int | None, values: dict[str, Any]) 
-> int: - if v is None: - v = 0 - elif v < 0 or v > 5: - raise ValueError(f"Priority must be between 0 and 5, not {v}.") - return v + @model_validator(mode="after") + def env_check(self) -> Self: + config = settings.application_config.config[self.application.application][ + self.application.version + ][self.application.entrypoint] + allowed_env_vars = config["app"]["env"] + if self.attributes.env_vars is not None and not all( + v_ in allowed_env_vars for v_ in self.attributes.env_vars + ): + raise ValueError(f"Environment variables must be in {allowed_env_vars}.") + return self class JobReadBase(BaseModel): diff --git a/api/settings.py b/api/settings.py index 88d99c2..a4995b9 100644 --- a/api/settings.py +++ b/api/settings.py @@ -1,21 +1,24 @@ import abc import json import os -from typing import Any, cast +from typing import Any import boto3 import yaml -def _load_possibly_aws_secret(name: str) -> str | None: - """Load environment variable and read password if it is a secret from AWS Secrets Manager.""" - value = os.environ.get(name) - if not value: - return value - try: - return cast(str, json.loads(value)["password"]) # AWS Secrets Manager - except json.JSONDecodeError: - return value +def get_secret_from_env(secret_name: str) -> str | None: + secret = os.environ.get(secret_name) + if secret: # exists and not None + try: + secret = json.loads(secret)["password"] # AWS Secrets Manager + except json.JSONDecodeError: + pass + return secret + + +# Cron job intervals +cron_backup_interval = 3600 # 1 hour # Stage @@ -25,7 +28,7 @@ def _load_possibly_aws_secret(name: str) -> str | None: # Data database_url = os.environ.get("DATABASE_URL", "sqlite:///./sql_app.db") if os.environ.get("DATABASE_SECRET"): # set and not None - database_secret = _load_possibly_aws_secret("DATABASE_SECRET") + database_secret = get_secret_from_env("DATABASE_SECRET") database_url = database_url.format(database_secret) filesystem = os.environ.get("FILESYSTEM", "local") s3_bucket = 
os.environ.get("S3_BUCKET") @@ -35,14 +38,14 @@ def _load_possibly_aws_secret(name: str) -> str | None: # Worker-facing API workerfacing_api_url = os.environ.get("WORKERFACING_API_URL", "http://127.0.0.1:8001") -internal_api_key_secret = _load_possibly_aws_secret("INTERNAL_API_KEY_SECRET") +internal_api_key_secret = get_secret_from_env("INTERNAL_API_KEY_SECRET") # Authentication cognito_user_pool_id = os.environ.get("COGNITO_USER_POOL_ID", "") cognito_region = os.environ.get("COGNITO_REGION", "eu-central-1") cognito_client_id = os.environ.get("COGNITO_CLIENT_ID", "") -cognito_secret = _load_possibly_aws_secret("COGNITO_SECRET") +cognito_secret = get_secret_from_env("COGNITO_SECRET") # Email sender diff --git a/tests/conftest.py b/tests/conftest.py index 6b86382..91ec3ab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import datetime import secrets import time -from typing import Any +from typing import Any, Generator from unittest.mock import MagicMock import boto3 @@ -39,7 +39,7 @@ def create(self) -> None: self.add_ingress_rule() self.db_url = self.create_db_url() self.engine = self.get_engine() - self.delete_db_tables() + self.cleanup() def get_engine(self) -> Engine: for _ in range(5): @@ -78,7 +78,20 @@ def add_ingress_rule(self) -> None: else: raise e - def delete_db_tables(self) -> None: + def remove_ingress_rules(self) -> None: + # cleans up earlier tests too (in case of failures) + security_groups = self.ec2_client.describe_security_groups( + GroupNames=[self.vpc_sg_rule_params["GroupName"]] + ) + for sg in security_groups["SecurityGroups"]: + for rule in sg["IpPermissions"]: + if rule.get("FromPort") == 5432 and rule.get("ToPort") == 5432: + self.ec2_client.revoke_security_group_ingress( + GroupId=sg["GroupId"], + IpPermissions=[rule], # type: ignore + ) + + def cleanup(self) -> None: metadata = MetaData() engine = self.engine metadata.reflect(engine) @@ -137,11 +150,11 @@ def create_db_url(self) -> str: address = 
response["DBInstances"][0]["Endpoint"]["Address"] return f"postgresql://{user}:{password}@{address}:5432/{self.db_name}" - def cleanup(self) -> None: - self.delete_db_tables() - self.ec2_client.revoke_security_group_ingress(**self.vpc_sg_rule_params) - def delete(self) -> None: + # never used (AWS tests skipped) + if not hasattr(self, "rds_client"): + return + self.remove_ingress_rules() self.rds_client.delete_db_instance( DBInstanceIdentifier=self.db_name, SkipFinalSnapshot=True, @@ -188,20 +201,33 @@ def cleanup(self) -> bool: return True def delete(self) -> None: + # never used (AWS tests skipped) + if not hasattr(self, "s3_client"): + return exists = self.cleanup() if exists: self.s3_client.delete_bucket(Bucket=self.bucket_name) @pytest.fixture(scope="session") -def rds_testing_instance() -> RDSTestingInstance: - return RDSTestingInstance("decodecloudintegrationtestsuserapi") +def rds_testing_instance() -> Generator[RDSTestingInstance, Any, None]: + # tests themselves must create the instance by calling instance.create(); + # this way, if no test that needs the DB is run, no RDS instance is created + # instance.delete() only deletes the RDS instance if it was created + instance = RDSTestingInstance("decodecloudintegrationtestsworkerapi") + yield instance + instance.delete() @pytest.fixture(scope="session") -def s3_testing_bucket() -> S3TestingBucket: +def s3_testing_bucket() -> Generator[S3TestingBucket, Any, None]: + # tests themselves must create the bucket by calling bucket.create(); + # this way, if no test that needs the bucket is run, no S3 bucket is created + # bucket.delete() only deletes the S3 bucket if it was created bucket_suffix = datetime.datetime.now(datetime.UTC).strftime("%Y%m%d%H%M%S") - return S3TestingBucket(bucket_suffix) + bucket = S3TestingBucket(bucket_suffix) + yield bucket + bucket.delete() @pytest.mark.aws diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 21f01f6..3de5eb5 100644 --- 
a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,27 +1,26 @@ -import os import shutil from io import BytesIO -from typing import Any, Callable, Generator, cast +from typing import Any, Callable, Generator from unittest.mock import Mock import pytest -from fastapi.testclient import TestClient from fastapi_cloudauth.cognito import CognitoClaims # type: ignore -from sqlalchemy import create_engine +from mypy_boto3_s3 import S3Client from sqlalchemy.orm import Session from api import settings from api.core import notifications +from api.core.auth import APIKeyDependency +from api.core.database import Database, SqliteDatabase from api.core.filesystem import ( FileSystem, LocalFilesystem, S3Filesystem, user_filesystem_getter, ) -from api.database import Base, get_db from api.dependencies import ( - APIKeyDependency, current_user_dep, + db_dep, email_sender_dep, filesystem_getter_dep, workerfacing_api_auth_dep, @@ -32,12 +31,12 @@ @pytest.fixture(scope="session") -def username() -> str: +def test_username() -> str: return "test_user" @pytest.fixture(scope="session") -def user_email() -> str: +def test_user_email() -> str: return "user@example.com" @@ -58,67 +57,50 @@ def application() -> dict[str, str]: @pytest.fixture( scope="session", - params=["local", pytest.param("aws", marks=pytest.mark.aws)], + params=["local-fs", pytest.param("aws-fs", marks=pytest.mark.aws)], ) -def env( - request: pytest.FixtureRequest, - rds_testing_instance: RDSTestingInstance, +def base_filesystem( + base_user_dir: str, s3_testing_bucket: S3TestingBucket, -) -> Generator[str, Any, None]: - env = cast(str, request.param) - if env == "aws": - rds_testing_instance.create() + request: pytest.FixtureRequest, +) -> FileSystem: + if request.param == "local-fs": + return LocalFilesystem(base_user_dir) + elif request.param == "aws-fs": s3_testing_bucket.create() - yield env - if env == "aws": - rds_testing_instance.delete() - s3_testing_bucket.delete() - - -@pytest.fixture 
-def db_session( - env: str, rds_testing_instance: RDSTestingInstance -) -> Generator[Session, Any, None]: - if env == "local": - rel_test_db_path = "./test_app.db" - shutil.rmtree(rel_test_db_path, ignore_errors=True) - engine = create_engine( - f"sqlite:///{rel_test_db_path}", connect_args={"check_same_thread": False} + return S3Filesystem( + base_user_dir, s3_testing_bucket.s3_client, s3_testing_bucket.bucket_name ) - elif env == "aws": - engine = rds_testing_instance.engine else: raise NotImplementedError - Base.metadata.create_all(bind=engine) - with Session(engine) as session: - yield session - if env == "local": - os.remove(rel_test_db_path) - elif env == "aws": - rds_testing_instance.cleanup() - - -@pytest.fixture -def base_filesystem( - env: str, - base_user_dir: str, - monkeypatch: pytest.MonkeyPatch, +@pytest.fixture( + scope="session", + params=["local-db", pytest.param("aws-db", marks=pytest.mark.aws)], +) +def db( + base_filesystem: FileSystem, s3_testing_bucket: S3TestingBucket, -) -> Generator[FileSystem, Any, None]: - if env == "local": - base_user_dir = f"./{base_user_dir}" - shutil.rmtree(base_user_dir, ignore_errors=True) - yield LocalFilesystem(base_user_dir) - shutil.rmtree(base_user_dir, ignore_errors=True) - - elif env == "aws": - yield S3Filesystem( - base_user_dir, s3_testing_bucket.s3_client, s3_testing_bucket.bucket_name + rds_testing_instance: RDSTestingInstance, + tmpdir_factory: pytest.TempdirFactory, + request: pytest.FixtureRequest, +) -> Database: + if request.param == "local-db": + test_db_path = tmpdir_factory.mktemp("integration") / "test_app.db" + s3_bucket: str | None = None + s3_client: S3Client | None = None + if isinstance(base_filesystem, S3Filesystem): + s3_bucket = s3_testing_bucket.bucket_name + s3_client = s3_testing_bucket.s3_client + return SqliteDatabase( + db_url=f"sqlite:///{test_db_path}", s3_client=s3_client, s3_bucket=s3_bucket ) - s3_testing_bucket.cleanup() - + elif request.param == "aws-db": + if 
isinstance(base_filesystem, LocalFilesystem): + pytest.skip("Only testing RDS DB in combination with S3 filesystem") + rds_testing_instance.create() + return Database(db_url=rds_testing_instance.db_url) else: raise NotImplementedError @@ -132,46 +114,58 @@ def filesystem_getter( return user_filesystem_getter( user_data_root_path=base_user_dir, filesystem="s3" if isinstance(base_filesystem, S3Filesystem) else "local", - s3_region=s3_testing_bucket.region_name, s3_bucket=s3_testing_bucket.bucket_name, + s3_client=s3_testing_bucket.s3_client, ) @pytest.fixture def user_filesystem( - filesystem_getter: Callable[[str], FileSystem], username: str + filesystem_getter: Callable[[str], FileSystem], test_username: str ) -> FileSystem: - return filesystem_getter(username) + return filesystem_getter(test_username) @pytest.fixture(autouse=True) -def override_db_dep(db_session: Session, monkeypatch: pytest.MonkeyPatch) -> None: +def override_db_dep(db: Database, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setitem( app.dependency_overrides, # type: ignore - get_db, - lambda: db_session, + db_dep, + lambda: db, ) @pytest.fixture(autouse=True) def override_user_filesystem_getter( - filesystem_getter: Callable[[str], FileSystem], monkeypatch: pytest.MonkeyPatch -) -> None: + filesystem_getter: Callable[[str], FileSystem], + base_filesystem: FileSystem, + s3_testing_bucket: S3TestingBucket, + base_user_dir: str, + monkeypatch: pytest.MonkeyPatch, +) -> Generator[None, None, None]: monkeypatch.setitem( app.dependency_overrides, # type: ignore filesystem_getter_dep, lambda: filesystem_getter, ) + yield + # cleanup after every test + if isinstance(base_filesystem, S3Filesystem): + s3_testing_bucket.cleanup() + else: + shutil.rmtree(base_user_dir, ignore_errors=True) @pytest.fixture(autouse=True) def override_auth( - monkeypatch: pytest.MonkeyPatch, username: str, user_email: str + monkeypatch: pytest.MonkeyPatch, test_username: str, test_user_email: str ) -> None: 
monkeypatch.setitem( app.dependency_overrides, # type: ignore current_user_dep, - lambda: CognitoClaims(**{"cognito:username": username, "email": user_email}), + lambda: CognitoClaims( + **{"cognito:username": test_username, "email": test_user_email} + ), ) @@ -234,11 +228,6 @@ def require_auth(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delitem(app.dependency_overrides, current_user_dep) # type: ignore -@pytest.fixture -def client() -> TestClient: - return TestClient(app) - - @pytest.fixture def data_files(user_filesystem: FileSystem) -> dict[str, str]: data_file1_name = "data/test/data_file1.txt" @@ -282,16 +271,16 @@ def job_attrs() -> dict[str, Any]: @pytest.fixture def jobs( - username: str, - user_email: str, + test_username: str, + test_user_email: str, application: dict[str, str], job_attrs: dict[str, Any], db_session: Session, ) -> list[Job]: job1 = Job( id=42, - user_id=username, - user_email=user_email, + user_id=test_username, + user_email=test_user_email, job_name="job_test_1", environment="cloud", application=application, @@ -301,8 +290,8 @@ def jobs( ) job2 = Job( id=50, - user_id=username, - user_email=user_email, + user_id=test_username, + user_email=test_user_email, job_name="job_test_2", environment=None, application=application, diff --git a/tests/integration/endpoints/conftest.py b/tests/integration/endpoints/conftest.py new file mode 100644 index 0000000..3f8a43f --- /dev/null +++ b/tests/integration/endpoints/conftest.py @@ -0,0 +1,13 @@ +from typing import Generator + +import pytest +from fastapi.testclient import TestClient + +from api.main import app + + +@pytest.fixture +def client() -> Generator[TestClient, None, None]: + # run everything in lifespan context + with TestClient(app) as client: + yield client diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py new file mode 100644 index 0000000..df05855 --- /dev/null +++ b/tests/integration/test_main.py @@ -0,0 +1,105 @@ +import gzip +import sqlite3 
+import tempfile +import time +from typing import cast + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session + +from api import settings +from api.core.database import Database, SqliteDatabase +from api.core.filesystem import FileSystem, S3Filesystem +from api.dependencies import db_dep +from api.main import app +from api.models import Job +from tests.conftest import S3TestingBucket + + +@pytest.fixture +def client() -> TestClient: + return TestClient(app) + + +class TestCronBackupDatabase: + @pytest.fixture(autouse=True) + def skip_if_not_sqlite_s3(self, db: Database, base_filesystem: FileSystem) -> None: + """Skip tests if not using SQLite DB with S3 filesystem.""" + if not isinstance(db, SqliteDatabase) or not isinstance( + base_filesystem, S3Filesystem + ): + pytest.skip("Backup tests only run with SQLite DB and S3 filesystem") + + @pytest.fixture(autouse=True) + def setup_backup_cron_interval(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Set backup cron interval to 1 second for faster testing.""" + monkeypatch.setattr(settings, "cron_backup_interval", 1) + + def get_backup_nrows(self, s3_testing_bucket: S3TestingBucket) -> int: + """Helper to get number of rows in backup database.""" + response = s3_testing_bucket.s3_client.get_object( + Bucket=s3_testing_bucket.bucket_name, + Key=SqliteDatabase.BACKUP_KEY, + ) + backup_data_gzip = response["Body"].read() + backup_data = gzip.decompress(backup_data_gzip) + with tempfile.NamedTemporaryFile(suffix=".db") as tmp_file: + tmp_file.write(backup_data) + tmp_path = tmp_file.name + conn = sqlite3.connect(tmp_path) + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) FROM jobs") + n_rows = cursor.fetchone()[0] + conn.close() + return cast(int, n_rows) + + def test_sqlite_backup( + self, + db: SqliteDatabase, + jobs: list[Job], + client: TestClient, + s3_testing_bucket: S3TestingBucket, + tmpdir_factory: pytest.TempdirFactory, + monkeypatch: pytest.MonkeyPatch, + ) 
-> None: + """Test the backup and restore functionality of the SqliteDatabase.""" + # Startup: no backup present + with pytest.raises(s3_testing_bucket.s3_client.exceptions.NoSuchKey): + self.get_backup_nrows(s3_testing_bucket) + + with client: + # First start-up: no jobs + time.sleep(2) # wait for backup to run + assert self.get_backup_nrows(s3_testing_bucket) == 0 + + # Enqueue a job and verify it's backed up + with Session(db.engine) as session: + session.add(jobs[0]) + session.commit() + time.sleep(2) # wait for backup to run + assert self.get_backup_nrows(s3_testing_bucket) == 1 + + # Enqueue a second job and shutdown before backup runs + with Session(db.engine) as session: + session.add(jobs[1]) + session.commit() + + # On shutdown, final backup should run with both jobs + assert self.get_backup_nrows(s3_testing_bucket) == 2 + + # New queue (e.g., application started again) should restore from backup + new_db_url = f"sqlite:///{tmpdir_factory.mktemp('integration') / 'restored.db'}" + new_db = SqliteDatabase( + new_db_url, + s3_client=s3_testing_bucket.s3_client, + s3_bucket=s3_testing_bucket.bucket_name, + ) + monkeypatch.setitem( + app.dependency_overrides, # type: ignore + db_dep, + lambda: new_db, + ) + with client: + assert len(client.get("/jobs").json()) == 2 + assert self.get_backup_nrows(s3_testing_bucket) == 2 diff --git a/tests/unit/core/test_filesystem.py b/tests/unit/core/test_filesystem.py index 00f63ec..e8c2909 100644 --- a/tests/unit/core/test_filesystem.py +++ b/tests/unit/core/test_filesystem.py @@ -235,9 +235,11 @@ def test_delete_directory( class TestLocalFilesystem(_TestFilesystem): @pytest.fixture(scope="class") def filesystem(self, base_dir: str) -> Generator[LocalFilesystem, Any, None]: - fs = LocalFilesystem(base_dir) - yield fs - shutil.rmtree(base_dir, ignore_errors=True) + yield LocalFilesystem(base_dir) + try: + shutil.rmtree(base_dir) + except FileNotFoundError: + pass @pytest.fixture def data_file1( @@ -269,7 +271,7 @@ def 
filesystem( yield S3Filesystem( base_dir, s3_testing_bucket.s3_client, s3_testing_bucket.bucket_name ) - s3_testing_bucket.delete() + s3_testing_bucket.cleanup() @pytest.fixture def data_file1( From 04b66cee8043789d4b6eb4a26d9fc1f84df8da35 Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Sun, 1 Mar 2026 23:56:23 +0100 Subject: [PATCH 3/8] fix tests --- api/core/database.py | 5 ++++ api/endpoints/auth_get.py | 3 ++- api/endpoints/jobs.py | 8 +++--- tests/integration/conftest.py | 25 ++++++++++++++++--- tests/integration/endpoints/test_files.py | 14 ++++++++--- .../integration/endpoints/test_job_update.py | 1 + 6 files changed, 44 insertions(+), 12 deletions(-) diff --git a/api/core/database.py b/api/core/database.py index 57904fa..8a7eaf4 100644 --- a/api/core/database.py +++ b/api/core/database.py @@ -49,6 +49,11 @@ def backup(self) -> bool: """Backup the database. To be implemented by subclasses if supported.""" return False + def empty(self) -> None: + """Empty the database by dropping and recreating all tables.""" + Base.metadata.drop_all(bind=self.engine) + Base.metadata.create_all(bind=self.engine) + class SqliteDatabase(Database): """SQLite database wrapper with optional S3 backup support.""" diff --git a/api/endpoints/auth_get.py b/api/endpoints/auth_get.py index 6290182..f8b152e 100644 --- a/api/endpoints/auth_get.py +++ b/api/endpoints/auth_get.py @@ -2,7 +2,8 @@ from fastapi import APIRouter, Depends -from api.dependencies import GroupClaims, current_user_dep +from api.core.auth import GroupClaims +from api.dependencies import current_user_dep from api.schemas.user import User from api.settings import cognito_client_id, cognito_region, cognito_user_pool_id diff --git a/api/endpoints/jobs.py b/api/endpoints/jobs.py index f5abd76..003c84d 100644 --- a/api/endpoints/jobs.py +++ b/api/endpoints/jobs.py @@ -32,7 +32,9 @@ def list_jobs( db: Session = Depends(session_dep), ) -> list[Job]: db_jobs = crud.get_jobs(db, request.state.current_user.username, offset, 
limit) - jobs = [Job.model_validate(db_job) for db_job in db_jobs] # models -> schemas + jobs = [ + Job.model_validate(db_job, from_attributes=True) for db_job in db_jobs + ] # models -> schemas return sorted(jobs, key=lambda x: x.date_created, reverse=True) @@ -45,7 +47,7 @@ def describe_job( raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Job not found" ) - return Job.model_validate(db_job) + return Job.model_validate(db_job, from_attributes=True) @router.post( @@ -70,7 +72,7 @@ def start_job( user_id=request.state.current_user.username, user_email=request.state.current_user.email, ) - return Job.model_validate(db_job) + return Job.model_validate(db_job, from_attributes=True) except FileNotFoundError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 3de5eb5..83d7e52 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -111,11 +111,12 @@ def filesystem_getter( base_user_dir: str, s3_testing_bucket: S3TestingBucket, ) -> Callable[[str], FileSystem]: + s3_fs = isinstance(base_filesystem, S3Filesystem) return user_filesystem_getter( user_data_root_path=base_user_dir, - filesystem="s3" if isinstance(base_filesystem, S3Filesystem) else "local", - s3_bucket=s3_testing_bucket.bucket_name, - s3_client=s3_testing_bucket.s3_client, + filesystem="s3" if s3_fs else "local", + s3_bucket=s3_testing_bucket.bucket_name if s3_fs else None, + s3_client=s3_testing_bucket.s3_client if s3_fs else None, ) @@ -127,12 +128,22 @@ def user_filesystem( @pytest.fixture(autouse=True) -def override_db_dep(db: Database, monkeypatch: pytest.MonkeyPatch) -> None: +def override_db_dep( + db: Database, + rds_testing_instance: RDSTestingInstance, + monkeypatch: pytest.MonkeyPatch, +) -> Generator[None, None, None]: monkeypatch.setitem( app.dependency_overrides, # type: ignore db_dep, lambda: db, ) + yield + # Cleanup after every test + if 
isinstance(db, SqliteDatabase): + db.empty() + else: + rds_testing_instance.cleanup() @pytest.fixture(autouse=True) @@ -269,6 +280,12 @@ def job_attrs() -> dict[str, Any]: } +@pytest.fixture +def db_session(db: Database) -> Generator[Session, None, None]: + with Session(db.engine) as session: + yield session + + @pytest.fixture def jobs( test_username: str, diff --git a/tests/integration/endpoints/test_files.py b/tests/integration/endpoints/test_files.py index 85c71fa..7084176 100644 --- a/tests/integration/endpoints/test_files.py +++ b/tests/integration/endpoints/test_files.py @@ -4,6 +4,8 @@ import requests from fastapi.testclient import TestClient +from api.core.filesystem import FileSystem, LocalFilesystem + ENDPOINT = "/files" @@ -167,7 +169,7 @@ def test_download_file_happy(client: TestClient, data_files: dict[str, str]) -> def test_get_url_file_happy( - env: str, client: TestClient, data_files: dict[str, str] + base_filesystem: FileSystem, client: TestClient, data_files: dict[str, str] ) -> None: data_file1_name, data_file1_contents = list(data_files.items())[0] response = client.get(f"{ENDPOINT}/{data_file1_name}/url") @@ -175,13 +177,15 @@ def test_get_url_file_happy( request_params = response.json() if "authorization" in request_params["headers"]: del request_params["headers"]["authorization"] - request_client = client if env == "local" else requests + request_client = ( + client if isinstance(base_filesystem, LocalFilesystem) else requests + ) response = request_client.request(**request_params) assert response.status_code == 200, response.text assert response.content.decode("utf-8") == data_file1_contents -def test_post_url_file_happy(env: str, client: TestClient) -> None: +def test_post_url_file_happy(base_filesystem: FileSystem, client: TestClient) -> None: data_file1_name = "data/test/data_file1.txt" data_file1_contents = "data file1 contents" response = client.post(f"{ENDPOINT}/{os.path.dirname(data_file1_name)}//url") @@ -196,7 +200,9 @@ def 
test_post_url_file_happy(env: str, client: TestClient) -> None: "text/plain", ) } - request_client = client if env == "local" else requests + request_client = ( + client if isinstance(base_filesystem, LocalFilesystem) else requests + ) request_client.request(**request_params, files=files) response = client.get( f"{ENDPOINT}//", params={"recursive": True, "show_dirs": False} diff --git a/tests/integration/endpoints/test_job_update.py b/tests/integration/endpoints/test_job_update.py index 491b668..ed2cfed 100644 --- a/tests/integration/endpoints/test_job_update.py +++ b/tests/integration/endpoints/test_job_update.py @@ -38,6 +38,7 @@ def test_job_status_update( headers={"x-api-key": internal_api_key_secret}, ) assert response.status_code == 200 + db_session.refresh(jobs[0]) job = db_session.query(Job).filter(Job.id == jobs[0].id).first() assert job is not None assert job.status == "running" From 2296ff8724dd3a36624d4ef9b49d222561c0255c Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Mon, 2 Mar 2026 00:07:19 +0100 Subject: [PATCH 4/8] fix main test --- tests/integration/conftest.py | 11 +++++++---- tests/integration/test_main.py | 6 +++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 83d7e52..a2c3613 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -287,12 +287,11 @@ def db_session(db: Database) -> Generator[Session, None, None]: @pytest.fixture -def jobs( +def job_defs( test_username: str, test_user_email: str, application: dict[str, str], job_attrs: dict[str, Any], - db_session: Session, ) -> list[Job]: job1 = Job( id=42, @@ -316,10 +315,14 @@ def jobs( hardware={}, paths_out={"output": "out", "log": "log", "artifact": "model"}, ) - for job in [job1, job2]: + return [job1, job2] + + +def jobs(job_defs: list[Job], db_session: Session) -> list[Job]: + for job in job_defs: db_session.add(job) db_session.commit() - return [job1, job2] + return job_defs 
@pytest.fixture diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py index df05855..3d1eaa7 100644 --- a/tests/integration/test_main.py +++ b/tests/integration/test_main.py @@ -57,7 +57,7 @@ def get_backup_nrows(self, s3_testing_bucket: S3TestingBucket) -> int: def test_sqlite_backup( self, db: SqliteDatabase, - jobs: list[Job], + job_defs: list[Job], client: TestClient, s3_testing_bucket: S3TestingBucket, tmpdir_factory: pytest.TempdirFactory, @@ -75,14 +75,14 @@ def test_sqlite_backup( # Enqueue a job and verify it's backed up with Session(db.engine) as session: - session.add(jobs[0]) + session.add(job_defs[0]) session.commit() time.sleep(2) # wait for backup to run assert self.get_backup_nrows(s3_testing_bucket) == 1 # Enqueue a second job and shutdown before backup runs with Session(db.engine) as session: - session.add(jobs[1]) + session.add(job_defs[1]) session.commit() # On shutdown, final backup should run with both jobs From 3fd66423dfa940ea97140db4a5ec8e56a38a3b95 Mon Sep 17 00:00:00 2001 From: Arthur Jaques Date: Mon, 2 Mar 2026 00:23:33 +0100 Subject: [PATCH 5/8] Add pytest fixture for job definitions --- tests/integration/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index a2c3613..ec162e5 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -318,6 +318,7 @@ def job_defs( return [job1, job2] +@pytest.fixture def jobs(job_defs: list[Job], db_session: Session) -> list[Job]: for job in job_defs: db_session.add(job) From 1864f4cdcb0cd319956c5b38c814b8e2dfa36da2 Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Wed, 4 Mar 2026 23:36:04 +0100 Subject: [PATCH 6/8] fix tests, maybe --- tests/integration/endpoints/test_job_update.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/integration/endpoints/test_job_update.py b/tests/integration/endpoints/test_job_update.py index ed2cfed..41b4c38 100644 --- 
a/tests/integration/endpoints/test_job_update.py +++ b/tests/integration/endpoints/test_job_update.py @@ -10,7 +10,9 @@ ENDPOINT = "/_job_status" -def test_job_status_init(db_session: Session, jobs: list[Job]) -> None: +def test_job_status_init( + db_session: Session, client: TestClient, jobs: list[Job] +) -> None: job = db_session.query(Job).first() assert job is not None assert job.status == "queued" From 46b8a34cb8bb0d5c547a07172223bac1aa591c37 Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Thu, 5 Mar 2026 19:27:17 +0100 Subject: [PATCH 7/8] Bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c5da029..a46d465 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "api" -version = "0.1.3" +version = "0.2.0" description = "User-facing API of DECODE OpenCloud." authors = ["Arthur Jaques "] readme = "README.md" From 093d0879ec4068c57d8cf985dacb180068dcd07c Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Thu, 5 Mar 2026 23:40:28 +0100 Subject: [PATCH 8/8] [FIX] correct lifespan --- api/main.py | 2 +- tests/integration/conftest.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/api/main.py b/api/main.py index 9875641..dd779d4 100644 --- a/api/main.py +++ b/api/main.py @@ -33,7 +33,7 @@ async def cron_backup_database(db: Database) -> None: @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: - db = app.dependency_overrides.get(dependencies.db_dep, dependencies.db_dep)() + db = await app.dependency_overrides.get(dependencies.db_dep, dependencies.db_dep)() assert isinstance(db, Database) db.create() task_backup = asyncio.create_task(cron_backup_database(db)) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index ec162e5..d2b6023 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -133,10 +133,13 @@ def override_db_dep( rds_testing_instance: 
RDSTestingInstance, monkeypatch: pytest.MonkeyPatch, ) -> Generator[None, None, None]: + async def _override_db() -> Database: + return db + monkeypatch.setitem( app.dependency_overrides, # type: ignore db_dep, - lambda: db, + _override_db, ) yield # Cleanup after every test