Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"python-envs.defaultEnvManager": "ms-python.python:system"
}
33 changes: 33 additions & 0 deletions api/core/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import secrets
from typing import Any

from fastapi import Header, HTTPException
from fastapi.security import HTTPAuthorizationCredentials
from fastapi_cloudauth.cognito import CognitoClaims, CognitoCurrentUser  # type: ignore
from pydantic import Field


# https://github.com/iwpnd/fastapi-key-auth/blob/main/fastapi_key_auth/dependency/authorizer.py
class APIKeyDependency:
    """FastAPI dependency that validates the ``x-api-key`` request header."""

    def __init__(self, key: str | None):
        # Expected API key. None means the service has no key configured;
        # in that case every request is rejected (fail closed).
        self.key = key

    def __call__(self, x_api_key: str | None = Header(...)) -> str | None:
        """Check the provided header value against the configured key.

        Returns:
            The header value when it matches.

        Raises:
            HTTPException: 401 when the key does not match or none is configured.
        """
        # secrets.compare_digest gives a constant-time comparison so an
        # attacker cannot recover the key byte-by-byte via response timing.
        if (
            x_api_key is None
            or self.key is None
            or not secrets.compare_digest(x_api_key, self.key)
        ):
            raise HTTPException(status_code=401, detail="unauthorized")
        return x_api_key


class GroupClaims(CognitoClaims):  # type: ignore
    """Cognito JWT claims extended with the token's group memberships."""

    # Populated from the "cognito:groups" claim; None when the token
    # carries no group information.
    cognito_groups: list[str] | None = Field(alias="cognito:groups")


class UserGroupCognitoCurrentUser(CognitoCurrentUser):  # type: ignore
    """Cognito current-user dependency that also requires 'users' group membership."""

    user_info = GroupClaims

    async def call(self, http_auth: HTTPAuthorizationCredentials) -> Any:
        # Delegate token verification to the base class, then enforce
        # group membership on the decoded claims.
        claims = await super().call(http_auth)
        groups = claims.cognito_groups or []
        if "users" not in groups:
            raise HTTPException(
                status_code=403, detail="Not a member of the 'users' group"
            )
        return claims
126 changes: 126 additions & 0 deletions api/core/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import gzip
import os
import sqlite3
import tempfile
import time
from typing import Any, cast

from botocore.exceptions import ClientError
from mypy_boto3_s3 import S3Client
from sqlalchemy import Engine, create_engine

from api.models import Base


class Database:
    """Database wrapper around a lazily created SQLAlchemy engine.

    The engine is created on first access of :attr:`engine`, retrying the
    initial connection so the service can start before the database is ready.
    """

    def __init__(
        self,
        db_url: str,
        connect_kwargs: dict[str, Any] | None = None,
    ):
        """
        Args:
            db_url: SQLAlchemy database URL.
            connect_kwargs: Extra arguments passed as ``connect_args`` to
                ``create_engine``.
        """
        self.db_url = db_url
        self.connect_kwargs = connect_kwargs or {}

    @property
    def engine(self) -> Engine:
        """Lazily created, cached SQLAlchemy engine.

        Retries the initial connection up to 10 times, sleeping 60 seconds
        between attempts.

        Raises:
            RuntimeError: If no connection can be made after all retries.
        """
        if hasattr(self, "_engine"):
            return cast(Engine, self._engine)  # type: ignore[has-type]
        retries = 0
        while True:
            try:
                engine = create_engine(self.db_url, connect_args=self.connect_kwargs)
                # Probe connectivity; close the probe connection immediately
                # instead of leaking it (the original never closed it).
                with engine.connect():
                    pass
                self._engine = engine
                return engine  # Connection successful
            except Exception as e:
                if retries >= 10:
                    # Chain the original exception for debuggability.
                    raise RuntimeError(f"Could not create engine: {str(e)}") from e
                retries += 1
                time.sleep(60)

    def create(self) -> None:
        """Create database tables."""
        Base.metadata.create_all(bind=self.engine)

    def backup(self) -> bool:
        """Backup the database. To be implemented by subclasses if supported."""
        return False

    def empty(self) -> None:
        """Empty the database by dropping and recreating all tables."""
        Base.metadata.drop_all(bind=self.engine)
        Base.metadata.create_all(bind=self.engine)


class SqliteDatabase(Database):
    """SQLite database wrapper with optional S3 backup support."""

    # S3 object key under which the gzipped backup is stored.
    BACKUP_KEY = "userapi_sqlite_backup/backup.db.gz"

    def __init__(
        self,
        db_url: str,
        s3_client: S3Client | None = None,
        s3_bucket: str | None = None,
    ):
        """
        Args:
            db_url: Must be a ``sqlite:///...`` URL.
            s3_client: Client for backup/restore; must be given together
                with ``s3_bucket`` or not at all.
            s3_bucket: Bucket for backup/restore.

        Raises:
            ValueError: On a non-SQLite URL, or when exactly one of
                ``s3_client``/``s3_bucket`` is provided.
        """
        if not db_url.startswith("sqlite:///"):
            # Fixed message: it previously referenced an unrelated class name.
            raise ValueError(f"SqliteDatabase requires SQLite DB URL, got: {db_url}")
        if (s3_client is None) != (s3_bucket is None):
            raise ValueError(
                "Both s3_client and s3_bucket must be provided for S3 backup/restore, or both must be None."
            )
        self.s3_client = s3_client
        self.s3_bucket = s3_bucket
        # check_same_thread=False: the connection may be handed across
        # threads (e.g. FastAPI worker threads).
        super().__init__(db_url, connect_kwargs={"check_same_thread": False})

    def create(self) -> None:
        """Restore from S3 (if configured) then create any missing tables."""
        self._restore_database()
        super().create()

    @property
    def db_path(self) -> str:
        """Filesystem path of the SQLite file, extracted from the URL."""
        return self.db_url[len("sqlite:///") :]

    def backup(self) -> bool:
        """Backup the SQLite database to S3 as a gzipped snapshot.

        Returns:
            True if a backup was uploaded, False when S3 is not configured.
        """
        if not self.s3_bucket or not self.s3_client:
            return False

        with tempfile.TemporaryDirectory() as temp_dir:
            tmp_backup_path = os.path.join(temp_dir, "backup.db")
            tmp_gzip_path = os.path.join(temp_dir, "backup.db.gz")
            # NOTE: sqlite3's context manager only manages transactions; it
            # does NOT close the connection, so close both explicitly.
            source_conn = sqlite3.connect(self.db_path)
            try:
                backup_conn = sqlite3.connect(tmp_backup_path)
                try:
                    # Consistent online copy via SQLite's backup API.
                    source_conn.backup(backup_conn)
                finally:
                    backup_conn.close()
            finally:
                source_conn.close()

            with open(tmp_backup_path, "rb") as f_in:
                with gzip.open(tmp_gzip_path, "wb") as f_out:
                    f_out.writelines(f_in)
            self.s3_client.upload_file(tmp_gzip_path, self.s3_bucket, self.BACKUP_KEY)
        return True

    def _restore_database(self) -> bool:
        """Restore the SQLite database from the S3 backup, if one exists.

        Returns:
            True if a backup was restored, False when S3 is not configured
            or no backup object exists.
        """
        if not self.s3_bucket or not self.s3_client:
            return False

        try:
            self.s3_client.head_object(Bucket=self.s3_bucket, Key=self.BACKUP_KEY)
        except ClientError as e:
            # head_object reports a missing key as an HTTP 404; some stacks
            # surface it as "NotFound"/"NoSuchKey" — treat all as "no backup".
            if e.response["Error"]["Code"] in ("404", "NotFound", "NoSuchKey"):
                return False
            raise

        with tempfile.TemporaryDirectory() as temp_dir:
            tmp_gzip_path = os.path.join(temp_dir, "backup.db.gz")
            self.s3_client.download_file(self.s3_bucket, self.BACKUP_KEY, tmp_gzip_path)
            # Guard against a bare filename: dirname("") would crash makedirs.
            os.makedirs(os.path.dirname(self.db_path) or ".", exist_ok=True)
            # Decompress next to the target and swap atomically; renaming out
            # of the temp dir could cross filesystems and fail.
            tmp_db_path = self.db_path + ".tmp"
            with gzip.open(tmp_gzip_path, "rb") as f_in:
                with open(tmp_db_path, "wb") as f_out:
                    f_out.write(f_in.read())
            os.replace(tmp_db_path, self.db_path)
        return True
46 changes: 25 additions & 21 deletions api/core/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,16 @@
import shutil
import zipfile
from pathlib import Path, PurePosixPath
from typing import Any, BinaryIO, Generator, cast
from typing import Any, BinaryIO, Callable, Generator

import boto3
import humanize
from botocore.client import Config
from botocore.response import StreamingBody
from botocore.utils import fix_s3_host
from fastapi import Request
from fastapi.responses import FileResponse, StreamingResponse
from mypy_boto3_s3 import S3Client
from mypy_boto3_s3.type_defs import ObjectIdentifierTypeDef

from api import models, settings
from api import models
from api.schemas.file import FileHTTPRequest, FileInfo, FileTypes


Expand Down Expand Up @@ -387,29 +384,36 @@ def download_url(
)


def get_filesystem_with_root(
    root_path: str,
    filesystem: str,
    s3_bucket: str | None = None,
    s3_client: S3Client | None = None,
) -> FileSystem:
    """Get the filesystem to use.

    Args:
        root_path: Root directory (or key prefix) of the filesystem.
        filesystem: Backend selector, either "s3" or "local".
        s3_bucket: Bucket name; required when ``filesystem == "s3"``.
        s3_client: S3 client; required when ``filesystem == "s3"``.

    Raises:
        ValueError: For an unknown ``filesystem`` value or missing S3 config.
    """
    predef_dirs = [e.value for e in models.UploadFileTypes] + [
        e.value for e in models.OutputEndpoints
    ]
    if filesystem == "s3":
        # Raise instead of assert: asserts are stripped under `python -O`.
        if s3_client is None:
            raise ValueError("S3 client must be provided for S3 filesystem")
        if s3_bucket is None:
            raise ValueError("S3 bucket must be provided for S3 filesystem")
        return S3Filesystem(root_path, s3_client, s3_bucket, predef_dirs=predef_dirs)
    elif filesystem == "local":
        return LocalFilesystem(root_path, predef_dirs=predef_dirs)
    else:
        raise ValueError("Invalid filesystem setting")


def user_filesystem_getter(
    user_data_root_path: str,
    filesystem: str,
    s3_bucket: str | None = None,
    s3_client: S3Client | None = None,
) -> Callable[[str], FileSystem]:
    """Build a function that maps a user id to that user's filesystem.

    The returned callable resolves ``user_id`` to a filesystem rooted at
    ``<user_data_root_path>/<user_id>`` using the configured backend.
    """

    def _for_user(user_id: str) -> FileSystem:
        return get_filesystem_with_root(
            str(Path(user_data_root_path) / user_id),
            filesystem=filesystem,
            s3_bucket=s3_bucket,
            s3_client=s3_client,
        )

    return _for_user
30 changes: 15 additions & 15 deletions api/crud/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from sqlalchemy.orm import Session

from api import models, settings
from api.core.filesystem import FileSystem, get_user_filesystem
from api.core.filesystem import FileSystem
from api.schemas import job as schemas


def enqueue_job(
job: models.Job, enqueueing_func: Callable[[schemas.QueueJob], None]
job: models.Job,
filesystem: FileSystem,
enqueueing_func: Callable[[schemas.QueueJob], None],
) -> None:
user_fs = get_user_filesystem(user_id=job.user_id)

app = job.application
job_config = settings.application_config.config[app["application"]][app["version"]][
app["entrypoint"]
Expand Down Expand Up @@ -47,14 +47,14 @@ def prepare_files(root_in: str, root_out: str, fs: FileSystem) -> dict[str, str]
f"artifact/{artifact_id}"
for artifact_id in job.attributes["files_down"]["artifact_ids"]
]
_validate_files(user_fs, [config_path] + data_paths + artifact_paths)
_validate_files(filesystem, [config_path] + data_paths + artifact_paths)
roots_down = handler_config["files_down"]
files_down = prepare_files(config_path, roots_down["config_id"], user_fs)
files_down = prepare_files(config_path, roots_down["config_id"], filesystem)
for data_path in data_paths:
files_down.update(prepare_files(data_path, roots_down["data_ids"], user_fs))
files_down.update(prepare_files(data_path, roots_down["data_ids"], filesystem))
for artifact_path in artifact_paths:
files_down.update(
prepare_files(artifact_path, roots_down["artifact_ids"], user_fs)
prepare_files(artifact_path, roots_down["artifact_ids"], filesystem)
)

app_specs = schemas.AppSpecs(
Expand All @@ -76,9 +76,9 @@ def prepare_files(root_in: str, root_out: str, fs: FileSystem) -> dict[str, str]
)

paths_upload = {
"output": user_fs.full_path_uri(job.paths_out["output"]),
"log": user_fs.full_path_uri(job.paths_out["log"]),
"artifact": user_fs.full_path_uri(job.paths_out["artifact"]),
"output": filesystem.full_path_uri(job.paths_out["output"]),
"log": filesystem.full_path_uri(job.paths_out["log"]),
"artifact": filesystem.full_path_uri(job.paths_out["artifact"]),
}

queue_item = schemas.QueueJob(
Expand Down Expand Up @@ -117,6 +117,7 @@ def _validate_files(filesystem: FileSystem, paths: list[str]) -> None:

def create_job(
db: Session,
filesystem: FileSystem,
enqueueing_func: Callable[[schemas.QueueJob], None],
job: schemas.JobCreate,
user_id: int,
Expand Down Expand Up @@ -146,18 +147,17 @@ def create_job(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ve,
)
enqueue_job(db_job, enqueueing_func)
enqueue_job(db_job, filesystem, enqueueing_func)
db.commit()
db.refresh(db_job)
return db_job


def delete_job(db: Session, filesystem: FileSystem, db_job: models.Job) -> models.Job:
    """Delete a job row and remove its output files.

    Args:
        db: Database session; the deletion is committed before returning.
        filesystem: The job owner's filesystem used to remove output files.
        db_job: The job to delete.

    Returns:
        The deleted job object.
    """
    db.delete(db_job)
    for path in db_job.paths_out.values():
        # Ensure a trailing slash so the whole output directory is removed.
        # endswith() also tolerates an empty path (indexing would raise).
        if not path.endswith("/"):
            path += "/"
        filesystem.delete(path)
    db.commit()
    return db_job
26 changes: 0 additions & 26 deletions api/database.py

This file was deleted.

Loading
Loading