Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Dockerfile.cloudrun
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ RUN cd /tmp/audio-separator-src \
"python-multipart>=0.0.6" \
"filetype>=1.2.0" \
"google-cloud-storage>=2.0.0" \
"google-cloud-firestore>=2.0.0" \
&& rm -rf /tmp/audio-separator-src

# Set up CUDA library paths
Expand Down
4 changes: 2 additions & 2 deletions audio_separator/remote/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def separate_audio(
data["custom_output_names"] = json.dumps(custom_output_names)

try:
# Increase timeout for large files (5 minutes)
# Server returns immediately with task_id; 60s is generous for submission
# When using gcs_uri (no file upload), we still need multipart/form-data
# encoding because FastAPI requires it for endpoints with File()/Form() params.
# Passing a dummy empty file field forces requests to use multipart encoding.
Expand All @@ -158,7 +158,7 @@ def separate_audio(
f"{self.api_url}/separate",
files=files,
data=data,
timeout=300,
timeout=60,
)
response.raise_for_status()
return response.json()
Expand Down
121 changes: 85 additions & 36 deletions audio_separator/remote/deploy_cloudrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,40 @@
MODEL_BUCKET = os.environ.get("MODEL_BUCKET", "")
PORT = int(os.environ.get("PORT", "8080"))

# In-memory job status tracking (one instance handles one job at a time on Cloud Run GPU)
job_status_store: dict[str, dict] = {}


# Track model readiness
models_ready = False

# --- Async job infrastructure ---
gpu_semaphore = threading.Semaphore(1)

OUTPUT_BUCKET = os.environ.get("OUTPUT_BUCKET", "nomadkaraoke-audio-separator-outputs")
GCP_PROJECT = os.environ.get("GCP_PROJECT", "nomadkaraoke")

_job_store = None
_output_store = None


def get_job_store():
    """Return the process-wide Firestore job store, creating it on first use."""
    global _job_store
    if _job_store is not None:
        return _job_store
    # Imported lazily so environments without google-cloud-firestore can still
    # import this module (e.g. local runs that never touch job status).
    from audio_separator.remote.job_store import FirestoreJobStore

    _job_store = FirestoreJobStore(project=GCP_PROJECT)
    return _job_store


def get_output_store():
    """Return the process-wide GCS output store, creating it on first use."""
    global _output_store
    if _output_store is not None:
        return _output_store
    # Imported lazily so environments without google-cloud-storage can still
    # import this module.
    from audio_separator.remote.output_store import GCSOutputStore

    _output_store = GCSOutputStore(bucket_name=OUTPUT_BUCKET, project=GCP_PROJECT)
    return _output_store


def generate_file_hash(filename: str) -> str:
"""Generate a short, stable hash for a filename to use in download URLs."""
Expand Down Expand Up @@ -188,19 +216,26 @@ def separate_audio_sync(

def update_status(status: str, progress: int = 0, error: str = None, files: dict = None):
status_data = {
"task_id": task_id,
"status": status,
"progress": progress,
"original_filename": filename,
"models_used": models_used,
"total_models": len(models) if models else 1,
"current_model_index": 0,
"files": files or {},
}
if files is not None:
status_data["files"] = files
if error:
status_data["error"] = error
job_status_store[task_id] = status_data

try:
get_job_store().update(task_id, status_data)
except Exception as e:
logger.warning(f"[{task_id}] Failed to update Firestore status: {e}")

# Wait for GPU availability
update_status("queued", 0)
logger.info(f"[{task_id}] Waiting for GPU semaphore...")
gpu_semaphore.acquire()
logger.info(f"[{task_id}] GPU semaphore acquired, starting separation")
try:
os.makedirs(f"{STORAGE_DIR}/outputs/{task_id}", exist_ok=True)
output_dir = f"{STORAGE_DIR}/outputs/{task_id}"
Expand Down Expand Up @@ -329,6 +364,9 @@ def update_status(status: str, progress: int = 0, error: str = None, files: dict
fname = os.path.basename(f)
all_output_files[generate_file_hash(fname)] = fname

# Upload outputs to GCS for cross-instance access
get_output_store().upload_task_outputs(task_id, output_dir)

update_status("completed", 100, files=all_output_files)
logger.info(f"Separation completed. {len(all_output_files)} output files.")
return {"task_id": task_id, "status": "completed", "files": all_output_files, "models_used": models_used}
Expand All @@ -338,13 +376,16 @@ def update_status(status: str, progress: int = 0, error: str = None, files: dict
traceback.print_exc()
update_status("error", 0, error=str(e))

# Clean up on error
return {"task_id": task_id, "status": "error", "error": str(e), "models_used": models_used}

finally:
gpu_semaphore.release()
logger.info(f"[{task_id}] GPU semaphore released")
# Clean up local files (outputs are in GCS now)
output_dir = f"{STORAGE_DIR}/outputs/{task_id}"
if os.path.exists(output_dir):
shutil.rmtree(output_dir, ignore_errors=True)

return {"task_id": task_id, "status": "error", "error": str(e), "models_used": models_used}


# --- FastAPI Application ---

Expand Down Expand Up @@ -451,9 +492,10 @@ async def separate_audio(
filename = file.filename

task_id = str(uuid.uuid4())
instance_id = os.environ.get("K_REVISION", "local")

# Set initial status
job_status_store[task_id] = {
# Write initial status to Firestore
get_job_store().set(task_id, {
"task_id": task_id,
"status": "submitted",
"progress": 0,
Expand All @@ -462,12 +504,12 @@ async def separate_audio(
"total_models": 1 if preset else (len(models_list) if models_list else 1),
"current_model_index": 0,
"files": {},
}
"instance_id": instance_id,
})

# Run separation in a background thread to not block the event loop
# but keep the request alive (Cloud Run keeps the instance warm)
# Fire-and-forget: run separation in background thread
loop = asyncio.get_event_loop()
await loop.run_in_executor(
loop.run_in_executor(
None,
lambda: separate_audio_sync(
audio_data,
Expand Down Expand Up @@ -509,8 +551,15 @@ async def separate_audio(
),
)

# Return the final status (completed or error)
return job_status_store.get(task_id, {"task_id": task_id, "status": "error", "error": "Job lost"})
# Return immediately — client polls /status/{task_id}
return {
"task_id": task_id,
"status": "submitted",
"progress": 0,
"original_filename": filename,
"models_used": [f"preset:{preset}"] if preset else (models_list or ["default"]),
"total_models": 1 if preset else (len(models_list) if models_list else 1),
}

except HTTPException:
raise
Expand All @@ -521,8 +570,9 @@ async def separate_audio(
@web_app.get("/status/{task_id}")
async def get_job_status(task_id: str) -> dict:
"""Get the status of a separation job."""
if task_id in job_status_store:
return job_status_store[task_id]
result = get_job_store().get(task_id)
if result:
return result
return {
"task_id": task_id,
"status": "not_found",
Expand All @@ -535,32 +585,20 @@ async def get_job_status(task_id: str) -> dict:
async def download_file(task_id: str, file_hash: str) -> Response:
"""Download a separated audio file using its hash identifier."""
try:
# Look up filename from job status
status_data = job_status_store.get(task_id)
status_data = get_job_store().get(task_id)
if not status_data:
raise HTTPException(status_code=404, detail="Task not found")

files_dict = status_data.get("files", {})

# Handle both dict (hash→filename) and list (legacy) formats
actual_filename = None
if isinstance(files_dict, dict):
actual_filename = files_dict.get(file_hash)
elif isinstance(files_dict, list):
for fname in files_dict:
if generate_file_hash(fname) == file_hash:
actual_filename = fname
break

if not actual_filename:
raise HTTPException(status_code=404, detail=f"File with hash {file_hash} not found")

file_path = f"{STORAGE_DIR}/outputs/{task_id}/{actual_filename}"
if not os.path.exists(file_path):
raise HTTPException(status_code=404, detail=f"File not found on disk: {actual_filename}")

with open(file_path, "rb") as f:
file_data = f.read()
file_data = get_output_store().get_file_bytes(task_id, actual_filename)

detected_type = filetype.guess(file_data)
content_type = detected_type.mime if detected_type and detected_type.mime else "application/octet-stream"
Expand Down Expand Up @@ -667,9 +705,20 @@ async def root() -> dict:

@web_app.on_event("startup")
async def startup_event():
"""Download models from GCS on startup."""
"""Clean up local storage and download models from GCS on startup."""
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(f"{STORAGE_DIR}/outputs", exist_ok=True)

# Wipe local outputs from previous instance
outputs_dir = f"{STORAGE_DIR}/outputs"
if os.path.exists(outputs_dir):
shutil.rmtree(outputs_dir, ignore_errors=True)
os.makedirs(outputs_dir, exist_ok=True)

# Clean up old Firestore jobs (>1 hour)
try:
get_job_store().cleanup_old_jobs(max_age_seconds=3600)
except Exception as e:
logger.warning(f"Failed to clean up old jobs: {e}")

# Download models in background thread to not block startup probe
thread = threading.Thread(target=download_models_from_gcs, daemon=True)
Expand Down
73 changes: 73 additions & 0 deletions audio_separator/remote/job_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Firestore-backed job status store for audio separation jobs.

Replaces the in-memory dict so any Cloud Run instance can read/write job status.
"""
import logging
import time
from typing import Optional

logger = logging.getLogger("audio-separator-api")

COLLECTION = "audio_separation_jobs"


class FirestoreJobStore:
    """Job status store backed by Firestore.

    Provides dict-like get/set interface for job status documents.

    Each job is one document in the ``COLLECTION`` collection, keyed by
    ``task_id``. Every write stamps ``updated_at`` with a server-side
    timestamp so ``cleanup_old_jobs`` can age documents out without relying
    on clock agreement between Cloud Run instances.
    """

    def __init__(self, project: str = "nomadkaraoke"):
        # Imported lazily so this module can be imported in environments
        # where google-cloud-firestore is not installed.
        from google.cloud import firestore

        # Keep a reference to the module itself so SERVER_TIMESTAMP is
        # reachable from the other methods without re-importing.
        self._firestore = firestore
        self._db = firestore.Client(project=project)
        self._collection = self._db.collection(COLLECTION)

    def set(self, task_id: str, data: dict) -> None:
        """Create or overwrite a job status document."""
        # Copy before mutating so the caller's dict is left untouched.
        data = {**data, "updated_at": self._firestore.SERVER_TIMESTAMP}
        # Only stamp created_at on first write; a caller-supplied value wins.
        if "created_at" not in data:
            data["created_at"] = self._firestore.SERVER_TIMESTAMP
        self._collection.document(task_id).set(data)

    def get(self, task_id: str) -> Optional[dict]:
        """Get job status. Returns None if not found."""
        doc = self._collection.document(task_id).get()
        if doc.exists:
            return doc.to_dict()
        return None

    def update(self, task_id: str, fields: dict) -> None:
        """Merge fields into an existing document.

        NOTE(review): Firestore ``update`` raises NotFound if the document
        does not exist yet — callers are expected to have called ``set``
        first (or to catch the error, as the status-update path does).
        """
        fields = {**fields, "updated_at": self._firestore.SERVER_TIMESTAMP}
        self._collection.document(task_id).update(fields)

    def delete(self, task_id: str) -> None:
        """Delete a job status document."""
        self._collection.document(task_id).delete()

    def __contains__(self, task_id: str) -> bool:
        """Check if a task exists."""
        doc = self._collection.document(task_id).get()
        return doc.exists

    def cleanup_old_jobs(self, max_age_seconds: int = 3600) -> int:
        """Delete completed/errored jobs older than max_age_seconds. Returns count deleted.

        Only terminal jobs ("completed"/"error") are removed; in-flight jobs
        are never touched regardless of age.

        NOTE(review): this compound query (``in`` on status + range on
        updated_at) likely requires a Firestore composite index — confirm
        one is deployed, otherwise stream() will raise. Positional ``where``
        args are also deprecated in newer client versions in favour of
        ``filter=FieldFilter(...)``.
        """
        cutoff = time.time() - max_age_seconds
        from datetime import datetime, timezone
        # Firestore timestamps are timezone-aware; compare against a UTC datetime.
        cutoff_dt = datetime.fromtimestamp(cutoff, tz=timezone.utc)

        deleted = 0
        query = (
            self._collection
            .where("status", "in", ["completed", "error"])
            .where("updated_at", "<", cutoff_dt)
        )
        for doc in query.stream():
            doc.reference.delete()
            deleted += 1

        if deleted:
            logger.info(f"Cleaned up {deleted} old job(s) from Firestore")
        return deleted
58 changes: 58 additions & 0 deletions audio_separator/remote/output_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""GCS-backed output file store for audio separation results.

Uploads separation output files to GCS so any Cloud Run instance can serve downloads.
"""
import logging
import os

logger = logging.getLogger("audio-separator-api")


class GCSOutputStore:
    """Manages separation output files in GCS.

    Object layout: ``gs://<bucket>/<task_id>/<filename>``. Outputs live in a
    shared bucket rather than on instance-local disk, so any Cloud Run
    instance can serve a download for any task.

    Fix: the object-path f-strings previously rendered a literal
    ``{task_id}/(unknown)`` instead of interpolating the filename, so every
    file of a task mapped to the same GCS object — uploads clobbered each
    other and downloads returned the wrong bytes. Paths now include the
    actual filename.
    """

    def __init__(self, bucket_name: str = "nomadkaraoke-audio-separator-outputs", project: str = "nomadkaraoke"):
        # Imported lazily so this module can be imported in environments
        # where google-cloud-storage is not installed.
        from google.cloud import storage

        self._client = storage.Client(project=project)
        self._bucket = self._client.bucket(bucket_name)

    @staticmethod
    def _object_path(task_id: str, filename: str) -> str:
        """Return the bucket-relative object path for one output file."""
        return f"{task_id}/{filename}"

    def upload_task_outputs(self, task_id: str, local_dir: str) -> list[str]:
        """Upload all files in local_dir to GCS under {task_id}/ prefix.

        Subdirectories are skipped; only regular files are uploaded.
        Returns list of uploaded filenames.
        """
        uploaded = []
        for filename in os.listdir(local_dir):
            local_path = os.path.join(local_dir, filename)
            if not os.path.isfile(local_path):
                continue
            gcs_path = self._object_path(task_id, filename)
            blob = self._bucket.blob(gcs_path)
            blob.upload_from_filename(local_path)
            uploaded.append(filename)
            logger.info(f"Uploaded {filename} to gs://{self._bucket.name}/{gcs_path}")
        return uploaded

    def get_file_bytes(self, task_id: str, filename: str) -> bytes:
        """Download file content as bytes (for HTTP responses).

        Raises google.cloud.exceptions.NotFound if the object does not exist.
        """
        blob = self._bucket.blob(self._object_path(task_id, filename))
        return blob.download_as_bytes()

    def download_file(self, task_id: str, filename: str, local_path: str) -> str:
        """Download a file from GCS to a local path. Returns local_path."""
        blob = self._bucket.blob(self._object_path(task_id, filename))
        blob.download_to_filename(local_path)
        return local_path

    def delete_task_outputs(self, task_id: str) -> int:
        """Delete all output files for a task. Returns count deleted."""
        deleted = 0
        for blob in self._bucket.list_blobs(prefix=f"{task_id}/"):
            blob.delete()
            deleted += 1
        if deleted:
            logger.info(f"Deleted {deleted} output file(s) for task {task_id}")
        return deleted
Loading
Loading