From 4a3384e237bbc184308b2390442687b88d4499f2 Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Fri, 30 Jan 2026 21:55:33 +0100 Subject: [PATCH 01/22] replace on_event (deprecated) --- workerfacing_api/main.py | 58 ++++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/workerfacing_api/main.py b/workerfacing_api/main.py index 9c851d1..ec78699 100644 --- a/workerfacing_api/main.py +++ b/workerfacing_api/main.py @@ -1,13 +1,49 @@ +import asyncio +from contextlib import asynccontextmanager +from typing import AsyncGenerator + import dotenv from fastapi import Depends, FastAPI -from fastapi_utils.tasks import repeat_every dotenv.load_dotenv() from workerfacing_api import dependencies, settings, tags from workerfacing_api.endpoints import access, files, jobs, jobs_post -workerfacing_app = FastAPI(openapi_tags=tags.tags_metadata) +queue = dependencies.queue_dep() + + +async def find_failed_jobs() -> dict[str, int]: + print("Silent fails check: starting...") + try: + max_retries = settings.max_retries + timeout_failure = settings.timeout_failure + n_retry, n_fail = queue.handle_timeouts(max_retries, timeout_failure) + print(f"Silent fails check: {n_retry} re-queued, {n_fail} failed.") + return {"n_retry": n_retry, "n_fail": n_fail} + except Exception as e: + print(f"Silent fails check: failed with {e}") + return {"n_retry": 0, "n_fail": 0} + + +async def repeat_find_failed_jobs() -> None: + while True: + await find_failed_jobs() + await asyncio.sleep(60) + + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + task = asyncio.create_task(repeat_find_failed_jobs()) + yield + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +workerfacing_app = FastAPI(openapi_tags=tags.tags_metadata, lifespan=lifespan) workerfacing_app.include_router( jobs.router, @@ -27,24 +63,6 @@ ) -queue = dependencies.queue_dep() - - -@workerfacing_app.on_event("startup") # type: ignore 
-@repeat_every(seconds=60, raise_exceptions=True) -async def find_failed_jobs() -> dict[str, int]: - print("Silent fails check: starting...") - try: - max_retries = settings.max_retries - timeout_failure = settings.timeout_failure - n_retry, n_fail = queue.handle_timeouts(max_retries, timeout_failure) - print(f"Silent fails check: {n_retry} re-queued, {n_fail} failed.") - return {"n_retry": n_retry, "n_fail": n_fail} - except Exception as e: - print(f"Silent fails check: failed with {e}") - return {"n_retry": 0, "n_fail": 0} - - @workerfacing_app.get("/") async def root() -> dict[str, str]: return {"message": "Welcome to the DECODE OpenCloud Worker-facing API"} From 8a1544dce9cb4b0fa23098be63a24c9cf0202fe1 Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Tue, 3 Feb 2026 19:59:19 +0100 Subject: [PATCH 02/22] Cron job to backup sqlite db to S3 --- workerfacing_api/core/queue.py | 141 +++++++++++++++++++++++-------- workerfacing_api/dependencies.py | 34 +++++--- workerfacing_api/main.py | 47 ++++++----- 3 files changed, 156 insertions(+), 66 deletions(-) diff --git a/workerfacing_api/core/queue.py b/workerfacing_api/core/queue.py index a9e668b..6771114 100644 --- a/workerfacing_api/core/queue.py +++ b/workerfacing_api/core/queue.py @@ -1,22 +1,27 @@ import datetime +import gzip import json import os import pickle +import subprocess +import tempfile import threading import time from abc import ABC, abstractmethod +from contextlib import nullcontext from types import TracebackType from typing import Any, Type import botocore.exceptions +from botocore.exceptions import ClientError from deprecated import deprecated from dict_hash import sha256 +from mypy_boto3_s3 import S3Client from mypy_boto3_sqs import SQSClient from sqlalchemy import create_engine, inspect, not_ from sqlalchemy.engine import Engine from sqlalchemy.orm import Query, Session -from workerfacing_api import settings from workerfacing_api.crud import job_tracking from workerfacing_api.exceptions import 
JobDeletedException, JobNotAssignedException from workerfacing_api.schemas.queue_jobs import ( @@ -49,25 +54,6 @@ def __exit__( self.lock.release() -class MockUpdateLock: - """ - Mock context manager. - Used for RDSQueue on databases that are not SQLite, - since locking is already achieved via `with_for_update`. - """ - - def __enter__(self) -> None: - pass - - def __exit__( - self, - exc_type: Type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - pass - - class JobQueue(ABC): """Abstract multi-environment job queue.""" @@ -332,26 +318,30 @@ class RDSJobQueue(JobQueue): Allows job tracking. """ - def __init__(self, db_url: str, max_retries: int = 10, retry_wait: int = 60): + def __init__( + self, + db_url: str, + retry_different: bool = True, + connect_kwargs: dict[str, Any] | None = None, + locking_context: UpdateLock | None = None, + ): self.db_url = db_url - self.update_lock = ( - UpdateLock() if self.db_url.startswith("sqlite") else MockUpdateLock() - ) - self.engine = self._get_engine(self.db_url, max_retries, retry_wait) + self.retry_different = retry_different + self.update_lock = locking_context or nullcontext() + self.engine = self._get_engine(self.db_url, connect_kwargs or {}) self.table_name = QueuedJob.__tablename__ - def _get_engine(self, db_url: str, max_retries: int, retry_wait: int) -> Engine: + def _get_engine( + self, + db_url: str, + connect_kwargs: dict[str, Any], + retry_wait: int = 60, + max_retries: int = 10, + ) -> Engine: retries = 0 while True: try: - engine = create_engine( - db_url, - connect_args=( - {"check_same_thread": False} - if db_url.startswith("sqlite") - else {} - ), - ) + engine = create_engine(db_url, connect_args=connect_kwargs) # Attempt to create a connection or perform any necessary operations engine.connect() return engine # Connection successful @@ -426,7 +416,7 @@ def filter_sort_query(query: Query[QueuedJob]) -> QueuedJob | None: (QueuedJob.gpu_mem <= 
filter.gpu_mem) | (QueuedJob.gpu_mem.is_(None)), ) - if settings.retry_different: + if self.retry_different: # only if worker did not already try running this job query = query.filter(not_(QueuedJob.workers.contains(hostname))) query = query.order_by(QueuedJob.priority.desc()).order_by( @@ -570,3 +560,84 @@ def handle_timeouts( pass session.commit() return n_retry, n_failed + + def backup(self) -> bool: + """Backup the database. To be implemented by subclasses if supported.""" + return False + + +class SQLiteRDSJobQueue(RDSJobQueue): + """SQLite-specific RDS job queue with optional S3 backup support. + + Extends RDSJobQueue with specifics of SQLite databases. + Allows S3 backup and restore functionality. + """ + + BACKUP_KEY = "workerapi_sqlite_backup/backup.db.gz" + + def __init__( + self, + db_url: str, + retry_different: bool = True, + s3_client: S3Client | None = None, + s3_bucket: str | None = None, + ): + if not db_url.startswith("sqlite:///"): + raise ValueError(f"SQLiteRDSJobQueue requires SQLite DB URL, got: {db_url}") + if not (s3_client is None == s3_bucket is None): + raise ValueError( + "Both s3_client and s3_bucket must be provided for S3 backup/restore, or both must be None." 
+ ) + self.s3_client = s3_client + self.s3_bucket = s3_bucket + self.db_url = db_url # Needed for _restore_database + self._restore_database() + super().__init__( + db_url, + retry_different=retry_different, + connect_kwargs={"check_same_thread": False}, + locking_context=UpdateLock(), + ) + + @property + def db_path(self) -> str: + return self.db_url[len("sqlite:///") :] + + def backup(self) -> bool: + """Backup the SQLite database to S3.""" + if not self.s3_bucket or not self.s3_client: + return False + + with tempfile.TemporaryDirectory() as temp_dir: + tmp_backup_path = os.path.join(temp_dir, "backup.db") + tmp_gzip_path = os.path.join(temp_dir, "backup.db.gz") + backup_cmd = ["sqlite3", self.db_path, f".backup {tmp_backup_path}"] + subprocess.run(backup_cmd, text=True, check=True) + with open(tmp_backup_path, "rb") as f_in: + with gzip.open(tmp_gzip_path, "wb") as f_out: + f_out.writelines(f_in) + self.s3_client.upload_file(tmp_gzip_path, self.s3_bucket, self.BACKUP_KEY) + return True + + def _restore_database(self) -> bool: + """Restore the SQLite database from S3.""" + if not self.s3_bucket or not self.s3_client: + return False + + try: + self.s3_client.head_object(Bucket=self.s3_bucket, Key=self.BACKUP_KEY) + except ClientError as e: + if e.response["Error"]["Code"] == "404": + return False + raise + + with tempfile.TemporaryDirectory() as temp_dir: + tmp_gzip_path = os.path.join(temp_dir, "backup.db.gz") + tmp_backup_path = os.path.join(temp_dir, "backup.db") + self.s3_client.download_file(self.s3_bucket, self.BACKUP_KEY, tmp_gzip_path) + with gzip.open(tmp_gzip_path, "rb") as f_in: + with open(tmp_backup_path, "wb") as f_out: + f_out.write(f_in.read()) + os.makedirs(os.path.dirname(self.db_path), exist_ok=True) + os.rename(tmp_backup_path, self.db_path) + return True diff --git a/workerfacing_api/dependencies.py b/workerfacing_api/dependencies.py index d2fc7e6..661b500 100644 --- a/workerfacing_api/dependencies.py +++ b/workerfacing_api/dependencies.py @@ 
-11,9 +11,28 @@ from workerfacing_api import settings from workerfacing_api.core import filesystem, queue +s3_client = None +if settings.s3_bucket: + s3_client = boto3.client( + "s3", + region_name=settings.s3_region, + config=Config(signature_version="v4", s3={"addressing_style": "path"}), + ) + # this and config=... required to avoid DNS problems with new buckets + s3_client.meta.events.unregister("before-sign.s3", fix_s3_host) + # Queue queue_db_url = settings.queue_db_url -queue_ = queue.RDSJobQueue(queue_db_url) +retry_different = settings.retry_different +if queue_db_url.startswith("sqlite"): + queue_: queue.RDSJobQueue = queue.SQLiteRDSJobQueue( + db_url=queue_db_url, + retry_different=retry_different, + s3_client=s3_client, + s3_bucket=settings.s3_bucket, + ) +else: + queue_ = queue.RDSJobQueue(db_url=queue_db_url, retry_different=retry_different) queue_.create(err_on_exists=False) @@ -67,18 +86,11 @@ async def current_user_global_dep( return current_user -# Files +# Filesystem async def filesystem_dep() -> filesystem.FileSystem: if settings.filesystem == "s3": - s3_client = boto3.client( - "s3", - region_name=settings.s3_region, - config=Config(signature_version="v4", s3={"addressing_style": "path"}), - ) - # this and config=... 
required to avoid DNS problems with new buckets - s3_client.meta.events.unregister("before-sign.s3", fix_s3_host) - if settings.s3_bucket is None: - raise ValueError("S3 bucket not configured") + if s3_client is None or settings.s3_bucket is None: + raise ValueError("S3 bucket or client not configured") return filesystem.S3Filesystem(s3_client, settings.s3_bucket) elif settings.filesystem == "local": if settings.user_data_root_path is None: diff --git a/workerfacing_api/main.py b/workerfacing_api/main.py index ec78699..e7b6276 100644 --- a/workerfacing_api/main.py +++ b/workerfacing_api/main.py @@ -13,34 +13,41 @@ queue = dependencies.queue_dep() -async def find_failed_jobs() -> dict[str, int]: - print("Silent fails check: starting...") - try: - max_retries = settings.max_retries - timeout_failure = settings.timeout_failure - n_retry, n_fail = queue.handle_timeouts(max_retries, timeout_failure) - print(f"Silent fails check: {n_retry} re-queued, {n_fail} failed.") - return {"n_retry": n_retry, "n_fail": n_fail} - except Exception as e: - print(f"Silent fails check: failed with {e}") - return {"n_retry": 0, "n_fail": 0} +async def cron_handle_timeouts() -> dict[str, int]: + while True: + await asyncio.sleep(300) # every 5 minutes + print("Silent fails check: starting...") + try: + max_retries = settings.max_retries + timeout_failure = settings.timeout_failure + n_retry, n_fail = queue.handle_timeouts(max_retries, timeout_failure) + print(f"Silent fails check: {n_retry} re-queued, {n_fail} failed.") + return {"n_retry": n_retry, "n_fail": n_fail} + except Exception as e: + print(f"Silent fails check: failed with {e}") + return {"n_retry": 0, "n_fail": 0} -async def repeat_find_failed_jobs() -> None: +async def cron_backup_database() -> bool: while True: - await find_failed_jobs() - await asyncio.sleep(60) + await asyncio.sleep(3600) # every hour + # Run backup in thread pool to avoid blocking event loop; + # Fine instead of making backup async since it runs 
infrequently. + if await asyncio.to_thread(queue.backup): + print("Backed up database.") + return True + return False @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: - task = asyncio.create_task(repeat_find_failed_jobs()) + task_failed_jobs = asyncio.create_task(cron_handle_timeouts()) + task_backup = asyncio.create_task(cron_backup_database()) yield - task.cancel() - try: - await task - except asyncio.CancelledError: - pass + task_failed_jobs.cancel() + task_backup.cancel() + if queue.backup(): # final backup on shutdown + print("Created final backup on shutdown.") workerfacing_app = FastAPI(openapi_tags=tags.tags_metadata, lifespan=lifespan) From 7cbfdf19b5162f2746f8fa4baae5836bdf048a4a Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Tue, 3 Feb 2026 20:05:48 +0100 Subject: [PATCH 03/22] Simplify dependencies.py --- tests/integration/conftest.py | 3 +-- workerfacing_api/core/auth.py | 33 +++++++++++++++++++++++++ workerfacing_api/dependencies.py | 42 +++++--------------------------- 3 files changed, 40 insertions(+), 38 deletions(-) create mode 100644 workerfacing_api/core/auth.py diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index fd58219..4f95d9a 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -5,11 +5,10 @@ from tests.conftest import RDSTestingInstance, S3TestingBucket from workerfacing_api import settings +from workerfacing_api.core.auth import APIKeyDependency, GroupClaims from workerfacing_api.core.filesystem import FileSystem, LocalFilesystem, S3Filesystem from workerfacing_api.core.queue import RDSJobQueue from workerfacing_api.dependencies import ( - APIKeyDependency, - GroupClaims, authorizer, current_user_dep, filesystem_dep, diff --git a/workerfacing_api/core/auth.py b/workerfacing_api/core/auth.py new file mode 100644 index 0000000..6fd0237 --- /dev/null +++ b/workerfacing_api/core/auth.py @@ -0,0 +1,33 @@ +from typing import Any + +from fastapi import 
Header, HTTPException +from fastapi.security import HTTPAuthorizationCredentials +from fastapi_cloudauth.cognito import CognitoClaims, CognitoCurrentUser # type: ignore +from pydantic import Field + + +# https://github.com/iwpnd/fastapi-key-auth/blob/main/fastapi_key_auth/dependency/authorizer.py +class APIKeyDependency: + def __init__(self, key: str | None): + self.key = key + + def __call__(self, x_api_key: str | None = Header(...)) -> str | None: + if x_api_key != self.key: + raise HTTPException(status_code=401, detail="unauthorized") + return x_api_key + + +class GroupClaims(CognitoClaims): # type: ignore + cognito_groups: list[str] | None = Field(alias="cognito:groups") + + +class WorkerGroupCognitoCurrentUser(CognitoCurrentUser): # type: ignore + user_info = GroupClaims + + async def call(self, http_auth: HTTPAuthorizationCredentials) -> Any: + user_info = await super().call(http_auth) + if "workers" not in (getattr(user_info, "cognito_groups") or []): + raise HTTPException( + status_code=403, detail="Not a member of the 'workers' group" + ) + return user_info diff --git a/workerfacing_api/dependencies.py b/workerfacing_api/dependencies.py index 661b500..414ae35 100644 --- a/workerfacing_api/dependencies.py +++ b/workerfacing_api/dependencies.py @@ -1,16 +1,13 @@ -from typing import Any - import boto3 from botocore.config import Config from botocore.utils import fix_s3_host -from fastapi import Depends, Header, HTTPException, Request -from fastapi.security import HTTPAuthorizationCredentials -from fastapi_cloudauth.cognito import CognitoClaims, CognitoCurrentUser # type: ignore -from pydantic import Field +from fastapi import Depends, Request +from fastapi_cloudauth.cognito import CognitoClaims # type: ignore from workerfacing_api import settings -from workerfacing_api.core import filesystem, queue +from workerfacing_api.core import auth, filesystem, queue +# S3 client setup s3_client = None if settings.s3_bucket: s3_client = boto3.client( @@ -41,38 +38,11 @@ 
def queue_dep() -> queue.RDSJobQueue: # App-internal authentication (i.e. user-facing API <-> worker-facing API) -# https://github.com/iwpnd/fastapi-key-auth/blob/main/fastapi_key_auth/dependency/authorizer.py -class APIKeyDependency: - def __init__(self, key: str | None): - self.key = key - - def __call__(self, x_api_key: str | None = Header(...)) -> str | None: - if x_api_key != self.key: - raise HTTPException(status_code=401, detail="unauthorized") - return x_api_key - - -authorizer = APIKeyDependency(key=settings.internal_api_key_secret) +authorizer = auth.APIKeyDependency(key=settings.internal_api_key_secret) # Worker authentication -class GroupClaims(CognitoClaims): # type: ignore - cognito_groups: list[str] | None = Field(alias="cognito:groups") - - -class WorkerGroupCognitoCurrentUser(CognitoCurrentUser): # type: ignore - user_info = GroupClaims - - async def call(self, http_auth: HTTPAuthorizationCredentials) -> Any: - user_info = await super().call(http_auth) - if "workers" not in (getattr(user_info, "cognito_groups") or []): - raise HTTPException( - status_code=403, detail="Not a member of the 'workers' group" - ) - return user_info - - -current_user_dep = WorkerGroupCognitoCurrentUser( +current_user_dep = auth.WorkerGroupCognitoCurrentUser( region=settings.cognito_region, userPoolId=settings.cognito_user_pool_id, client_id=settings.cognito_client_id, From bd99ff3f4658237f5ea747d1072f11ba727c262e Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Wed, 4 Feb 2026 16:38:35 +0100 Subject: [PATCH 04/22] Correct cron + add tests --- tests/integration/conftest.py | 169 +++++++++++------- tests/integration/endpoints/conftest.py | 13 +- tests/integration/endpoints/test_files.py | 54 +++--- tests/integration/endpoints/test_jobs.py | 82 +-------- tests/integration/endpoints/test_jobs_post.py | 8 +- tests/integration/test_main.py | 163 +++++++++++++++++ tests/unit/core/test_queue.py | 2 +- workerfacing_api/core/queue.py | 50 +++--- workerfacing_api/dependencies.py | 1 
- workerfacing_api/main.py | 24 +-- workerfacing_api/settings.py | 5 + 11 files changed, 342 insertions(+), 229 deletions(-) create mode 100644 tests/integration/test_main.py diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 4f95d9a..c0412f4 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,13 +1,14 @@ +import datetime import shutil -from typing import Any, Generator, cast +from typing import Generator, cast import pytest +from mypy_boto3_s3 import S3Client from tests.conftest import RDSTestingInstance, S3TestingBucket -from workerfacing_api import settings from workerfacing_api.core.auth import APIKeyDependency, GroupClaims from workerfacing_api.core.filesystem import FileSystem, LocalFilesystem, S3Filesystem -from workerfacing_api.core.queue import RDSJobQueue +from workerfacing_api.core.queue import RDSJobQueue, SQLiteRDSJobQueue from workerfacing_api.dependencies import ( authorizer, current_user_dep, @@ -15,6 +16,16 @@ queue_dep, ) from workerfacing_api.main import workerfacing_app +from workerfacing_api.schemas.queue_jobs import ( + AppSpecs, + EnvironmentTypes, + HandlerSpecs, + HardwareSpecs, + JobSpecs, + MetaSpecs, + PathsUploadSpecs, + SubmittedJob, +) @pytest.fixture(scope="session") @@ -23,8 +34,8 @@ def test_username() -> str: @pytest.fixture(scope="session") -def base_dir() -> str: - return "int_test_dir" +def base_dir(tmp_path_factory: pytest.TempPathFactory) -> str: + return str(tmp_path_factory.mktemp("int_test_dir")) @pytest.fixture(scope="session") @@ -34,99 +45,92 @@ def internal_api_key_secret() -> str: @pytest.fixture( scope="session", - params=["local", pytest.param("aws", marks=pytest.mark.aws)], + params=["local-fs", pytest.param("aws-fs", marks=pytest.mark.aws)], ) -def env( - request: pytest.FixtureRequest, - rds_testing_instance: RDSTestingInstance, - s3_testing_bucket: S3TestingBucket, -) -> Generator[str, Any, None]: - env = cast(str, request.param) - if env == "aws": - 
rds_testing_instance.create() - s3_testing_bucket.create() - yield env - if env == "aws": - rds_testing_instance.cleanup() - s3_testing_bucket.cleanup() - - -@pytest.fixture(scope="session") def base_filesystem( - env: str, base_dir: str, - monkeypatch_module: pytest.MonkeyPatch, s3_testing_bucket: S3TestingBucket, -) -> Generator[FileSystem, Any, None]: - monkeypatch_module.setattr( - settings, - "user_data_root_path", - base_dir, - ) - monkeypatch_module.setattr( - settings, - "filesystem", - "local" if env == "local" else "s3", - ) - - if env == "local": - shutil.rmtree(base_dir, ignore_errors=True) - yield LocalFilesystem(base_dir, base_dir) - shutil.rmtree(base_dir, ignore_errors=True) - - elif env == "aws": - # Update settings to use the actual unique bucket name created by S3TestingBucket - monkeypatch_module.setattr( - settings, - "s3_bucket", - s3_testing_bucket.bucket_name, - ) - yield S3Filesystem(s3_testing_bucket.s3_client, s3_testing_bucket.bucket_name) - s3_testing_bucket.cleanup() - + request: pytest.FixtureRequest, +) -> FileSystem: + if request.param == "local-fs": + return LocalFilesystem(base_dir, base_dir) + elif request.param == "aws-fs": + s3_testing_bucket.create() + return S3Filesystem(s3_testing_bucket.s3_client, s3_testing_bucket.bucket_name) else: raise NotImplementedError -@pytest.fixture(scope="session") +@pytest.fixture( + scope="session", + params=["local-queue", pytest.param("aws-queue", marks=pytest.mark.aws)], +) def queue( - env: str, + base_filesystem: FileSystem, rds_testing_instance: RDSTestingInstance, tmpdir_factory: pytest.TempdirFactory, -) -> Generator[RDSJobQueue, Any, None]: - if env == "local": - queue = RDSJobQueue( - f"sqlite:///{tmpdir_factory.mktemp('integration')}/local.db" + request: pytest.FixtureRequest, +) -> RDSJobQueue: + retry_different = False # allow retries on same worker + if request.param == "local-queue": + queue_path = tmpdir_factory.mktemp("integration") / "local.db" + s3_bucket: str | None = None + 
s3_client: S3Client | None = None + if isinstance(base_filesystem, S3Filesystem): + s3_bucket = base_filesystem.bucket + s3_client = base_filesystem.s3_client + return SQLiteRDSJobQueue( + f"sqlite:///{queue_path}", + retry_different=retry_different, + s3_client=s3_client, + s3_bucket=s3_bucket, ) + elif request.param == "aws-queue": + rds_testing_instance.create() + return RDSJobQueue(rds_testing_instance.db_url, retry_different=retry_different) else: - queue = RDSJobQueue(rds_testing_instance.db_url) - queue.create(err_on_exists=True) - yield queue + raise NotImplementedError -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(autouse=True) def override_filesystem_dep( - base_filesystem: FileSystem, monkeypatch_module: pytest.MonkeyPatch -) -> None: + base_filesystem: FileSystem, + s3_testing_bucket: S3TestingBucket, + base_dir: str, + monkeypatch_module: pytest.MonkeyPatch, +) -> Generator[None, None, None]: monkeypatch_module.setitem( workerfacing_app.dependency_overrides, # type: ignore filesystem_dep, lambda: base_filesystem, ) + yield + # cleanup after every test + if isinstance(base_filesystem, S3Filesystem): + s3_testing_bucket.cleanup() + else: + shutil.rmtree(base_dir, ignore_errors=True) -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(autouse=True) def override_queue_dep( - queue: RDSJobQueue, monkeypatch_module: pytest.MonkeyPatch -) -> None: + queue: RDSJobQueue, + rds_testing_instance: RDSTestingInstance, + monkeypatch_module: pytest.MonkeyPatch, +) -> Generator[None, None, None]: monkeypatch_module.setitem( workerfacing_app.dependency_overrides, # type: ignore queue_dep, lambda: queue, ) + yield + if isinstance(queue, SQLiteRDSJobQueue): + queue.delete() + else: + rds_testing_instance.cleanup() -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(autouse=True) def override_auth(monkeypatch_module: pytest.MonkeyPatch, test_username: str) -> None: monkeypatch_module.setitem( 
workerfacing_app.dependency_overrides, # type: ignore @@ -141,7 +145,7 @@ def override_auth(monkeypatch_module: pytest.MonkeyPatch, test_username: str) -> ) -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(autouse=True) def override_internal_api_key_secret( monkeypatch_module: pytest.MonkeyPatch, internal_api_key_secret: str ) -> str: @@ -151,3 +155,32 @@ def override_internal_api_key_secret( APIKeyDependency(internal_api_key_secret), ) return internal_api_key_secret + + +@pytest.fixture +def base_job(base_filesystem: FileSystem, test_username: str) -> SubmittedJob: + time_now = datetime.datetime.now(datetime.timezone.utc).isoformat() + if isinstance(base_filesystem, S3Filesystem): + base_path = f"s3://{base_filesystem.bucket}" + else: + base_path = cast(LocalFilesystem, base_filesystem).base_post_path + paths_upload = PathsUploadSpecs( + output=f"{base_path}/{test_username}/test_out/1", + log=f"{base_path}/{test_username}/test_log/1", + artifact=f"{base_path}/{test_username}/test_arti/1", + ) + return SubmittedJob( + job=JobSpecs( + app=AppSpecs(cmd=["cmd"], env={"env": "var"}), + handler=HandlerSpecs(image_url="u", files_up={"output": "out"}), + hardware=HardwareSpecs(), + meta=MetaSpecs( + job_id=1, + date_created=time_now, + ), + ), + environment=EnvironmentTypes.local, + group=None, + priority=1, + paths_upload=paths_upload, + ) diff --git a/tests/integration/endpoints/conftest.py b/tests/integration/endpoints/conftest.py index 6a8fd0d..23a69e3 100644 --- a/tests/integration/endpoints/conftest.py +++ b/tests/integration/endpoints/conftest.py @@ -1,6 +1,6 @@ import abc from dataclasses import dataclass, field -from typing import Any +from typing import Any, Generator import pytest from fastapi.testclient import TestClient @@ -9,6 +9,13 @@ from workerfacing_api.main import workerfacing_app +@pytest.fixture +def client() -> Generator[TestClient, None, None]: + # run everything in lifespan context + with TestClient(workerfacing_app) as client: + 
yield client + + @dataclass class EndpointParams: method: str @@ -24,10 +31,6 @@ class _TestEndpoint(abc.ABC): def passing_params(self, *args: Any, **kwargs: Any) -> list[EndpointParams]: raise NotImplementedError - @pytest.fixture(scope="session") - def client(self) -> TestClient: - return TestClient(workerfacing_app) - def test_required_auth( self, monkeypatch: pytest.MonkeyPatch, diff --git a/tests/integration/endpoints/test_files.py b/tests/integration/endpoints/test_files.py index 69d7283..62ee28b 100644 --- a/tests/integration/endpoints/test_files.py +++ b/tests/integration/endpoints/test_files.py @@ -1,6 +1,5 @@ import os from io import BytesIO -from typing import cast import pytest import requests @@ -10,67 +9,60 @@ from workerfacing_api.core.filesystem import FileSystem, S3Filesystem -@pytest.fixture(scope="session") +@pytest.fixture def data_file1_name(base_dir: str) -> str: return f"{base_dir}/data/test/data_file1.txt" -@pytest.fixture(scope="session") -def data_file1_path(env: str, data_file1_name: str, base_filesystem: FileSystem) -> str: - if env == "aws": - base_filesystem = cast(S3Filesystem, base_filesystem) +@pytest.fixture +def data_file1_path(data_file1_name: str, base_filesystem: FileSystem) -> str: + if isinstance(base_filesystem, S3Filesystem): return f"s3://{base_filesystem.bucket}/{data_file1_name}" return data_file1_name -@pytest.fixture(scope="session") +@pytest.fixture def data_file1_contents() -> str: return "data_file1" -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(autouse=True) def data_file1( - env: str, - base_filesystem: FileSystem, - data_file1_name: str, - data_file1_contents: str, + base_filesystem: FileSystem, data_file1_name: str, data_file1_contents: str ) -> None: - if env == "local": - os.makedirs(os.path.dirname(data_file1_name), exist_ok=True) - with open(data_file1_name, "w") as f: - f.write(data_file1_contents) - else: - base_filesystem = cast(S3Filesystem, base_filesystem) + if 
isinstance(base_filesystem, S3Filesystem): base_filesystem.s3_client.put_object( Bucket=base_filesystem.bucket, Key=data_file1_name, Body=BytesIO(data_file1_contents.encode("utf-8")), ) + else: + os.makedirs(os.path.dirname(data_file1_name), exist_ok=True) + with open(data_file1_name, "w") as f: + f.write(data_file1_contents) class TestFiles(_TestEndpoint): endpoint = "/files" - @pytest.fixture(scope="session") + @pytest.fixture def passing_params(self, data_file1_path: str) -> list[EndpointParams]: - return [ - EndpointParams("get", f"{data_file1_path}/url"), - ] + return [EndpointParams("get", f"{data_file1_path}/url")] def test_get_file( self, - env: str, data_file1_path: str, data_file1_contents: str, client: TestClient, + base_filesystem: FileSystem, ) -> None: - if env == "local": + if isinstance(base_filesystem, S3Filesystem): file_resp = client.get(f"{self.endpoint}/{data_file1_path}/download") - assert file_resp.status_code == 200 - assert file_resp.content.decode("utf-8") == data_file1_contents + assert file_resp.status_code == 403 else: file_resp = client.get(f"{self.endpoint}/{data_file1_path}/download") - assert file_resp.status_code == 403 + assert file_resp.status_code == 200 + assert file_resp.content.decode("utf-8") == data_file1_contents def test_get_file_not_exists( self, data_file1_path: str, client: TestClient @@ -84,21 +76,21 @@ def test_get_file_not_permitted(self, client: TestClient) -> None: def test_get_file_url( self, - env: str, data_file1_path: str, data_file1_contents: str, client: TestClient, + base_filesystem: FileSystem, ) -> None: req = f"{self.endpoint}/{data_file1_path}/url" url_resp = client.get(req) assert url_resp.status_code == 200 - if env == "local": - assert req.replace("/url", "/download") in url_resp.text - else: + if isinstance(base_filesystem, S3Filesystem): assert ( requests.request(**url_resp.json()).content.decode("utf-8") == data_file1_contents ) + else: + assert req.replace("/url", "/download") in url_resp.text 
def test_get_file_url_not_exists( self, data_file1_path: str, client: TestClient diff --git a/tests/integration/endpoints/test_jobs.py b/tests/integration/endpoints/test_jobs.py index 3243c44..81f193f 100644 --- a/tests/integration/endpoints/test_jobs.py +++ b/tests/integration/endpoints/test_jobs.py @@ -1,4 +1,3 @@ -import datetime import os import time from io import BytesIO @@ -9,91 +8,22 @@ import requests from fastapi.testclient import TestClient -from tests.conftest import RDSTestingInstance from tests.integration.endpoints.conftest import EndpointParams, _TestEndpoint from workerfacing_api.core.filesystem import FileSystem, LocalFilesystem, S3Filesystem -from workerfacing_api.core.queue import RDSJobQueue +from workerfacing_api.core.queue import RDSJobQueue, SQLiteRDSJobQueue from workerfacing_api.crud import job_tracking from workerfacing_api.exceptions import JobDeletedException -from workerfacing_api.schemas.queue_jobs import ( - AppSpecs, - EnvironmentTypes, - HandlerSpecs, - HardwareSpecs, - JobSpecs, - MetaSpecs, - PathsUploadSpecs, - SubmittedJob, -) +from workerfacing_api.schemas.queue_jobs import EnvironmentTypes, SubmittedJob from workerfacing_api.schemas.rds_models import JobStates -@pytest.fixture(scope="session") -def app() -> AppSpecs: - return AppSpecs(cmd=["cmd"], env={"env": "var"}) - - -@pytest.fixture(scope="session") -def handler() -> HandlerSpecs: - return HandlerSpecs(image_url="u", files_up={"output": "out"}) - - -@pytest.fixture(scope="session") -def paths_upload( - env: str, test_username: str, base_filesystem: FileSystem -) -> PathsUploadSpecs: - if env == "local": - base_path = cast(LocalFilesystem, base_filesystem).base_post_path - else: - base_path = f"s3://{cast(S3Filesystem, base_filesystem).bucket}" - return PathsUploadSpecs( - output=f"{base_path}/{test_username}/test_out/1", - log=f"{base_path}/{test_username}/test_log/1", - artifact=f"{base_path}/{test_username}/test_arti/1", - ) - - class TestJobs(_TestEndpoint): endpoint 
= "/jobs" - @pytest.fixture(scope="session") + @pytest.fixture def passing_params(self) -> list[EndpointParams]: return [EndpointParams("get", params={"memory": 1})] - @pytest.fixture(scope="function", autouse=True) - def cleanup_queue( - self, - queue: RDSJobQueue, - env: str, - rds_testing_instance: RDSTestingInstance, - ) -> None: - if env == "local": - queue.delete() - else: - rds_testing_instance.cleanup() - queue.create() - - @pytest.fixture(scope="function") - def base_job( - self, app: AppSpecs, handler: HandlerSpecs, paths_upload: PathsUploadSpecs - ) -> SubmittedJob: - time_now = datetime.datetime.now(datetime.timezone.utc).isoformat() - return SubmittedJob( - job=JobSpecs( - app=app, - handler=handler, - hardware=HardwareSpecs(), - meta=MetaSpecs( - job_id=1, - date_created=time_now, - ), - ), - environment=EnvironmentTypes.local, - group=None, - priority=1, - paths_upload=paths_upload, - ) - def test_get_jobs( self, queue: RDSJobQueue, @@ -295,7 +225,6 @@ def mock_update_job(*args: Any, **kwargs: Any) -> None: def test_job_files_post( self, - env: str, queue: RDSJobQueue, base_filesystem: FileSystem, base_job: SubmittedJob, @@ -309,7 +238,7 @@ def test_job_files_post( params={"type": "output", "base_path": "test"}, ) assert res.status_code == 201 - if env == "local": + if isinstance(queue, SQLiteRDSJobQueue): req_base = client else: req_base = requests # type: ignore @@ -324,8 +253,7 @@ def test_job_files_post( }, ) res.raise_for_status() - if env == "local": - base_filesystem = cast(LocalFilesystem, base_filesystem) + if isinstance(base_filesystem, LocalFilesystem): assert os.path.exists( f"{base_filesystem.base_post_path}/{test_username}/test_out/1/test/file.txt" ) diff --git a/tests/integration/endpoints/test_jobs_post.py b/tests/integration/endpoints/test_jobs_post.py index 03387bc..e1ffb38 100644 --- a/tests/integration/endpoints/test_jobs_post.py +++ b/tests/integration/endpoints/test_jobs_post.py @@ -12,10 +12,8 @@ endpoint = "/_jobs" 
-@pytest.fixture(scope="function") -def queue_enqueue( - monkeypatch_module: pytest.MonkeyPatch, -) -> MagicMock: +@pytest.fixture +def queue_enqueue(monkeypatch_module: pytest.MonkeyPatch) -> MagicMock: queue = MagicMock() queue.enqueue = MagicMock() monkeypatch_module.setitem( @@ -26,7 +24,7 @@ def queue_enqueue( return queue.enqueue -@pytest.fixture(scope="function") +@pytest.fixture def queue_job() -> dict[str, Any]: return { "job": { diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py new file mode 100644 index 0000000..6c78bba --- /dev/null +++ b/tests/integration/test_main.py @@ -0,0 +1,163 @@ +import gzip +import sqlite3 +import tempfile +import time +from typing import cast + +import pytest +from fastapi.testclient import TestClient + +from tests.conftest import S3TestingBucket +from workerfacing_api import settings +from workerfacing_api.core.filesystem import FileSystem, S3Filesystem +from workerfacing_api.core.queue import RDSJobQueue, SQLiteRDSJobQueue +from workerfacing_api.dependencies import queue_dep +from workerfacing_api.main import workerfacing_app +from workerfacing_api.schemas.queue_jobs import SubmittedJob +from workerfacing_api.schemas.rds_models import JobStates + + +@pytest.fixture +def client() -> TestClient: + return TestClient(workerfacing_app) + + +class TestCronHandleTimeouts: + @pytest.fixture(autouse=True) + def setup_timeout_failure(self, monkeypatch_module: pytest.MonkeyPatch) -> None: + """Set timeout_failure to 1 second for faster testing.""" + monkeypatch_module.setattr(settings, "timeout_failure", 1) + + @pytest.fixture(autouse=True) + def setup_max_retries(self, monkeypatch_module: pytest.MonkeyPatch) -> None: + """Set max retries to 1 for faster testing.""" + monkeypatch_module.setattr(settings, "max_retries", 1) + + @pytest.fixture(autouse=True) + def setup_cron_interval(self, monkeypatch_module: pytest.MonkeyPatch) -> None: + """Set cron interval to 1 second for faster testing.""" + 
monkeypatch_module.setattr(settings, "cron_timeout_check_interval", 1)
+
+    def test_handle_timeouts(
+        self,
+        queue: RDSJobQueue,
+        base_job: SubmittedJob,
+        client: TestClient,
+    ) -> None:
+        with client:
+            # Push the job
+            queue.enqueue(base_job)
+
+            # Pull the job
+            get_params = {"memory": 1}
+            job_id = base_job.job.meta.job_id
+            assert len(client.get("/jobs", params=get_params).json()) == 1
+            assert queue.get_job(job_id).status == JobStates.pulled.value
+
+            # Job kept alive by periodic status updates
+            for _ in range(4):
+                time.sleep(1)
+                client.put(
+                    f"/jobs/{job_id}/status",
+                    params={"status": "running", "runtime_details": "Processing..."},
+                )
+                assert len(client.get("/jobs", params=get_params).json()) == 0
+                assert queue.get_job(job_id).status == JobStates.running.value
+
+            # Let timeout
+            time.sleep(4)
+            assert queue.get_job(job_id).status == JobStates.queued.value
+
+            # Pull again
+            assert len(client.get("/jobs", params=get_params).json()) == 1
+            assert queue.get_job(job_id).status == JobStates.pulled.value
+            assert queue.get_job(job_id).num_retries == 1
+
+            # Let timeout and fail
+            time.sleep(2)
+            assert queue.get_job(job_id).status == JobStates.error.value
+
+
+class TestCronBackupDatabase:
+    @pytest.fixture(autouse=True)
+    def skip_if_not_sqlite_s3(
+        self, queue: RDSJobQueue, base_filesystem: FileSystem
+    ) -> None:
+        """Skip tests if not using SQLite queue with S3 filesystem."""
+        if not isinstance(queue, SQLiteRDSJobQueue) or not isinstance(
+            base_filesystem, S3Filesystem
+        ):
+            pytest.skip("Backup tests only run with SQLite queue and S3 filesystem")
+
+    @pytest.fixture(autouse=True)
+    def setup_backup_cron_interval(
+        self, monkeypatch_module: pytest.MonkeyPatch
+    ) -> None:
+        """Set backup cron interval to 1 second for faster testing."""
+        monkeypatch_module.setattr(settings, "cron_backup_interval", 1)
+
+    def get_backup_nrows(self, s3_testing_bucket: S3TestingBucket) -> int:
+        """Helper to get number of rows in backup database."""
+        response
= s3_testing_bucket.s3_client.get_object( + Bucket=s3_testing_bucket.bucket_name, + Key=SQLiteRDSJobQueue.BACKUP_KEY, + ) + backup_data_gzip = response["Body"].read() + backup_data = gzip.decompress(backup_data_gzip) + with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as tmp_file: + tmp_file.write(backup_data) + tmp_path = tmp_file.name + conn = sqlite3.connect(tmp_path) + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) FROM queued_jobs") + n_rows = cursor.fetchone()[0] + conn.close() + return cast(int, n_rows) + + def test_sqlite_backup( + self, + queue: SQLiteRDSJobQueue, + base_job: SubmittedJob, + client: TestClient, + s3_testing_bucket: S3TestingBucket, + tmpdir_factory: pytest.TempdirFactory, + monkeypatch_module: pytest.MonkeyPatch, + ) -> None: + """Test the backup and restore functionality of the SQLiteRDSJobQueue.""" + # Startup: no backup present + with pytest.raises(s3_testing_bucket.s3_client.exceptions.NoSuchKey): + self.get_backup_nrows(s3_testing_bucket) + + with client: + # First start-up: no jobs + time.sleep(2) # wait for backup to run + assert self.get_backup_nrows(s3_testing_bucket) == 0 + + # Enqueue a job and verify it's backed up + queue.enqueue(base_job) + time.sleep(2) # wait for backup to run + assert self.get_backup_nrows(s3_testing_bucket) == 1 + + # Enqueue a second job and shutdown before backup runs + queue.enqueue(base_job) + + # On shutdown, final backup should run with both jobs + assert self.get_backup_nrows(s3_testing_bucket) == 2 + + # New queue (e.g., application started again) should restore from backup + new_db_url = f"sqlite:///{tmpdir_factory.mktemp('integration') / 'restored.db'}" + new_queue = SQLiteRDSJobQueue( + new_db_url, + s3_client=s3_testing_bucket.s3_client, + s3_bucket=s3_testing_bucket.bucket_name, + ) + monkeypatch_module.setitem( + workerfacing_app.dependency_overrides, # type: ignore + queue_dep, + lambda: new_queue, + ) + with client: + assert ( + len(client.get("/jobs", 
params={"memory": 1, "limit": 5}).json()) == 2 + ) + assert self.get_backup_nrows(s3_testing_bucket) == 2 diff --git a/tests/unit/core/test_queue.py b/tests/unit/core/test_queue.py index 5d527f7..cbd72c1 100644 --- a/tests/unit/core/test_queue.py +++ b/tests/unit/core/test_queue.py @@ -69,7 +69,7 @@ def queue( success = False for _ in range(10): # i.p. SQS, RDS, etc. might need some time to delete try: - base_queue.create(err_on_exists=True) + base_queue.create() success = True break except Exception: diff --git a/workerfacing_api/core/queue.py b/workerfacing_api/core/queue.py index 6771114..5bb25bf 100644 --- a/workerfacing_api/core/queue.py +++ b/workerfacing_api/core/queue.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from contextlib import nullcontext from types import TracebackType -from typing import Any, Type +from typing import Any, Type, cast import botocore.exceptions from botocore.exceptions import ClientError @@ -18,7 +18,7 @@ from dict_hash import sha256 from mypy_boto3_s3 import S3Client from mypy_boto3_sqs import SQSClient -from sqlalchemy import create_engine, inspect, not_ +from sqlalchemy import create_engine, not_ from sqlalchemy.engine import Engine from sqlalchemy.orm import Query, Session @@ -58,7 +58,7 @@ class JobQueue(ABC): """Abstract multi-environment job queue.""" @abstractmethod - def create(self, err_on_exists: bool = True) -> None: + def create(self) -> None: """Create the initialized queue.""" raise NotImplementedError @@ -122,9 +122,7 @@ def __init__(self, queue_path: str): self.queue_path = queue_path self.update_lock = UpdateLock() - def create(self, err_on_exists: bool = True) -> None: - if os.path.exists(self.queue_path) and err_on_exists: - raise ValueError("A queue at this path already exists.") + def create(self) -> None: queue: dict[EnvironmentTypes, list[SubmittedJob]] = { env: [] for env in EnvironmentTypes } @@ -217,7 +215,7 @@ def __init__(self, sqs_client: SQSClient): except 
self.sqs_client.exceptions.QueueDoesNotExist: pass - def create(self, err_on_exists: bool = True) -> None: + def create(self) -> None: for environment, queue_name in self.queue_names.items(): try: res = self.sqs_client.create_queue( @@ -230,10 +228,7 @@ def create(self, err_on_exists: bool = True) -> None: ) self.queue_urls[environment] = res["QueueUrl"] except self.sqs_client.exceptions.QueueNameExists: - if err_on_exists: - raise ValueError( - f"A queue with the name {queue_name} already exists." - ) + pass def delete(self) -> None: for queue_url in self.queue_urls.values(): @@ -328,33 +323,28 @@ def __init__( self.db_url = db_url self.retry_different = retry_different self.update_lock = locking_context or nullcontext() - self.engine = self._get_engine(self.db_url, connect_kwargs or {}) + self.connect_kwargs = connect_kwargs or {} self.table_name = QueuedJob.__tablename__ - def _get_engine( - self, - db_url: str, - connect_kwargs: dict[str, Any], - retry_wait: int = 60, - max_retries: int = 10, - ) -> Engine: + @property + def engine(self) -> Engine: + if hasattr(self, "_engine"): + return cast(Engine, self._engine) # type: ignore[has-type] retries = 0 while True: try: - engine = create_engine(db_url, connect_args=connect_kwargs) + engine = create_engine(self.db_url, connect_args=self.connect_kwargs) # Attempt to create a connection or perform any necessary operations engine.connect() + self._engine = engine return engine # Connection successful except Exception as e: - if retries >= max_retries: + if retries >= 10: raise RuntimeError(f"Could not create engine: {str(e)}") retries += 1 - time.sleep(retry_wait) + time.sleep(60) - def create(self, err_on_exists: bool = True) -> None: - inspector = inspect(self.engine) - if inspector.has_table(self.table_name) and err_on_exists: - raise ValueError(f"A table with the name {self.table_name} already exists.") + def create(self) -> None: Base.metadata.create_all(self.engine) def delete(self) -> None: @@ -584,14 +574,12 
@@ def __init__( ): if not db_url.startswith("sqlite:///"): raise ValueError(f"SQLiteRDSJobQueue requires SQLite DB URL, got: {db_url}") - if not (s3_client is None == s3_bucket is None): + if not ((s3_client is None) == (s3_bucket is None)): raise ValueError( "Both s3_client and s3_bucket must be provided for S3 backup/restore, or both must be None." ) self.s3_client = s3_client self.s3_bucket = s3_bucket - self.db_url = db_url # Needed for _restore_database - self._restore_database() super().__init__( db_url, retry_different=retry_different, @@ -599,6 +587,10 @@ def __init__( locking_context=UpdateLock(), ) + def create(self) -> None: + self._restore_database() + super().create() + @property def db_path(self) -> str: return self.db_url[len("sqlite:///") :] diff --git a/workerfacing_api/dependencies.py b/workerfacing_api/dependencies.py index 414ae35..09f4c1c 100644 --- a/workerfacing_api/dependencies.py +++ b/workerfacing_api/dependencies.py @@ -30,7 +30,6 @@ ) else: queue_ = queue.RDSJobQueue(db_url=queue_db_url, retry_different=retry_different) -queue_.create(err_on_exists=False) def queue_dep() -> queue.RDSJobQueue: diff --git a/workerfacing_api/main.py b/workerfacing_api/main.py index e7b6276..e790699 100644 --- a/workerfacing_api/main.py +++ b/workerfacing_api/main.py @@ -8,41 +8,41 @@ dotenv.load_dotenv() from workerfacing_api import dependencies, settings, tags +from workerfacing_api.core.queue import RDSJobQueue from workerfacing_api.endpoints import access, files, jobs, jobs_post -queue = dependencies.queue_dep() - -async def cron_handle_timeouts() -> dict[str, int]: +async def cron_handle_timeouts(queue: RDSJobQueue) -> None: while True: - await asyncio.sleep(300) # every 5 minutes + await asyncio.sleep(settings.cron_timeout_check_interval) print("Silent fails check: starting...") try: max_retries = settings.max_retries timeout_failure = settings.timeout_failure n_retry, n_fail = queue.handle_timeouts(max_retries, timeout_failure) print(f"Silent fails 
check: {n_retry} re-queued, {n_fail} failed.") - return {"n_retry": n_retry, "n_fail": n_fail} except Exception as e: print(f"Silent fails check: failed with {e}") - return {"n_retry": 0, "n_fail": 0} -async def cron_backup_database() -> bool: +async def cron_backup_database(queue: RDSJobQueue) -> None: while True: - await asyncio.sleep(3600) # every hour + await asyncio.sleep(settings.cron_backup_interval) # Run backup in thread pool to avoid blocking event loop; # Fine instead of making backup async since it runs infrequently. if await asyncio.to_thread(queue.backup): print("Backed up database.") - return True - return False @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: - task_failed_jobs = asyncio.create_task(cron_handle_timeouts()) - task_backup = asyncio.create_task(cron_backup_database()) + queue = app.dependency_overrides.get( + dependencies.queue_dep, dependencies.queue_dep + )() + assert isinstance(queue, RDSJobQueue) + queue.create() + task_failed_jobs = asyncio.create_task(cron_handle_timeouts(queue)) + task_backup = asyncio.create_task(cron_backup_database(queue)) yield task_failed_jobs.cancel() task_backup.cancel() diff --git a/workerfacing_api/settings.py b/workerfacing_api/settings.py index 8d06d19..7831df5 100644 --- a/workerfacing_api/settings.py +++ b/workerfacing_api/settings.py @@ -12,6 +12,11 @@ def get_secret_from_env(secret_name: str) -> str | None: return secret +# Cron job intervals +cron_timeout_check_interval = 300 # 5 minutes +cron_backup_interval = 3600 # 1 hour + + # Data filesystem = os.environ.get("FILESYSTEM") # filesystem s3_bucket = os.environ.get("S3_BUCKET") From e064707b9f7ebed4ed7c287948fe4435670b1e3d Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Wed, 4 Feb 2026 20:41:15 +0100 Subject: [PATCH 05/22] temp --- tests/integration/conftest.py | 5 +++-- tests/integration/test_main.py | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/integration/conftest.py 
b/tests/integration/conftest.py index c0412f4..0c0782d 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -67,6 +67,7 @@ def base_filesystem( ) def queue( base_filesystem: FileSystem, + s3_testing_bucket: S3TestingBucket, rds_testing_instance: RDSTestingInstance, tmpdir_factory: pytest.TempdirFactory, request: pytest.FixtureRequest, @@ -77,8 +78,8 @@ def queue( s3_bucket: str | None = None s3_client: S3Client | None = None if isinstance(base_filesystem, S3Filesystem): - s3_bucket = base_filesystem.bucket - s3_client = base_filesystem.s3_client + s3_bucket = s3_testing_bucket.bucket_name + s3_client = s3_testing_bucket.s3_client return SQLiteRDSJobQueue( f"sqlite:///{queue_path}", retry_different=retry_different, diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py index 6c78bba..25107a4 100644 --- a/tests/integration/test_main.py +++ b/tests/integration/test_main.py @@ -25,8 +25,8 @@ def client() -> TestClient: class TestCronHandleTimeouts: @pytest.fixture(autouse=True) def setup_timeout_failure(self, monkeypatch_module: pytest.MonkeyPatch) -> None: - """Set timeout_failure to 1 second for faster testing.""" - monkeypatch_module.setattr(settings, "timeout_failure", 1) + """Set timeout_failure to 2 seconds for faster testing.""" + monkeypatch_module.setattr(settings, "timeout_failure", 2) @pytest.fixture(autouse=True) def setup_max_retries(self, monkeypatch_module: pytest.MonkeyPatch) -> None: @@ -74,7 +74,7 @@ def test_handle_timeouts( assert queue.get_job(job_id).num_retries == 1 # Let timeout and fail - time.sleep(2) + time.sleep(4) assert queue.get_job(job_id).status == JobStates.error.value @@ -104,7 +104,7 @@ def get_backup_nrows(self, s3_testing_bucket: S3TestingBucket) -> int: ) backup_data_gzip = response["Body"].read() backup_data = gzip.decompress(backup_data_gzip) - with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as tmp_file: + with tempfile.NamedTemporaryFile(suffix=".db") as tmp_file: 
tmp_file.write(backup_data) tmp_path = tmp_file.name conn = sqlite3.connect(tmp_path) From 02144f183142aac1210a159a8ea74f385bb14e0f Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Thu, 5 Feb 2026 00:18:57 +0100 Subject: [PATCH 06/22] No session-scoped monkeypatch --- tests/conftest.py | 12 +++------- tests/integration/conftest.py | 16 +++++++------- tests/integration/endpoints/test_jobs_post.py | 4 ++-- tests/integration/test_main.py | 22 +++++++++---------- 4 files changed, 23 insertions(+), 31 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index c1da77f..43790e4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,16 +17,10 @@ REGION_NAME: BucketLocationConstraintType = "eu-central-1" -@pytest.fixture(scope="session") -def monkeypatch_module() -> Generator[pytest.MonkeyPatch, Any, None]: - with pytest.MonkeyPatch.context() as mp: - yield mp - - -@pytest.fixture(autouse=True, scope="session") -def patch_update_job(monkeypatch_module: pytest.MonkeyPatch) -> MagicMock: +@pytest.fixture(autouse=True) +def patch_update_job(monkeypatch: pytest.MonkeyPatch) -> MagicMock: mock_update_job = MagicMock() - monkeypatch_module.setattr(job_tracking, "update_job", mock_update_job) + monkeypatch.setattr(job_tracking, "update_job", mock_update_job) return mock_update_job diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 0c0782d..b50617e 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -98,9 +98,9 @@ def override_filesystem_dep( base_filesystem: FileSystem, s3_testing_bucket: S3TestingBucket, base_dir: str, - monkeypatch_module: pytest.MonkeyPatch, + monkeypatch: pytest.MonkeyPatch, ) -> Generator[None, None, None]: - monkeypatch_module.setitem( + monkeypatch.setitem( workerfacing_app.dependency_overrides, # type: ignore filesystem_dep, lambda: base_filesystem, @@ -117,9 +117,9 @@ def override_filesystem_dep( def override_queue_dep( queue: RDSJobQueue, rds_testing_instance: RDSTestingInstance, 
- monkeypatch_module: pytest.MonkeyPatch, + monkeypatch: pytest.MonkeyPatch, ) -> Generator[None, None, None]: - monkeypatch_module.setitem( + monkeypatch.setitem( workerfacing_app.dependency_overrides, # type: ignore queue_dep, lambda: queue, @@ -132,8 +132,8 @@ def override_queue_dep( @pytest.fixture(autouse=True) -def override_auth(monkeypatch_module: pytest.MonkeyPatch, test_username: str) -> None: - monkeypatch_module.setitem( +def override_auth(monkeypatch: pytest.MonkeyPatch, test_username: str) -> None: + monkeypatch.setitem( workerfacing_app.dependency_overrides, # type: ignore current_user_dep, lambda: GroupClaims( @@ -148,9 +148,9 @@ def override_auth(monkeypatch_module: pytest.MonkeyPatch, test_username: str) -> @pytest.fixture(autouse=True) def override_internal_api_key_secret( - monkeypatch_module: pytest.MonkeyPatch, internal_api_key_secret: str + monkeypatch: pytest.MonkeyPatch, internal_api_key_secret: str ) -> str: - monkeypatch_module.setitem( + monkeypatch.setitem( workerfacing_app.dependency_overrides, # type: ignore authorizer, APIKeyDependency(internal_api_key_secret), diff --git a/tests/integration/endpoints/test_jobs_post.py b/tests/integration/endpoints/test_jobs_post.py index e1ffb38..88ac4f2 100644 --- a/tests/integration/endpoints/test_jobs_post.py +++ b/tests/integration/endpoints/test_jobs_post.py @@ -13,10 +13,10 @@ @pytest.fixture -def queue_enqueue(monkeypatch_module: pytest.MonkeyPatch) -> MagicMock: +def queue_enqueue(monkeypatch: pytest.MonkeyPatch) -> MagicMock: queue = MagicMock() queue.enqueue = MagicMock() - monkeypatch_module.setitem( + monkeypatch.setitem( workerfacing_app.dependency_overrides, # type: ignore queue_dep, lambda: queue, diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py index 25107a4..8653b16 100644 --- a/tests/integration/test_main.py +++ b/tests/integration/test_main.py @@ -24,19 +24,19 @@ def client() -> TestClient: class TestCronHandleTimeouts: @pytest.fixture(autouse=True) - 
def setup_timeout_failure(self, monkeypatch_module: pytest.MonkeyPatch) -> None:
+    def setup_timeout_failure(self, monkeypatch: pytest.MonkeyPatch) -> None:
         """Set timeout_failure to 2 seconds for faster testing."""
-        monkeypatch_module.setattr(settings, "timeout_failure", 2)
+        monkeypatch.setattr(settings, "timeout_failure", 2)
 
     @pytest.fixture(autouse=True)
-    def setup_max_retries(self, monkeypatch_module: pytest.MonkeyPatch) -> None:
+    def setup_max_retries(self, monkeypatch: pytest.MonkeyPatch) -> None:
         """Set max retries to 1 for faster testing."""
-        monkeypatch_module.setattr(settings, "max_retries", 1)
+        monkeypatch.setattr(settings, "max_retries", 1)
 
     @pytest.fixture(autouse=True)
-    def setup_cron_interval(self, monkeypatch_module: pytest.MonkeyPatch) -> None:
+    def setup_cron_interval(self, monkeypatch: pytest.MonkeyPatch) -> None:
         """Set cron interval to 1 second for faster testing."""
-        monkeypatch_module.setattr(settings, "cron_timeout_check_interval", 1)
+        monkeypatch.setattr(settings, "cron_timeout_check_interval", 1)
 
     def test_handle_timeouts(
         self,
@@ -90,11 +90,9 @@ def skip_if_not_sqlite_s3(
         pytest.skip("Backup tests only run with SQLite queue and S3 filesystem")
 
     @pytest.fixture(autouse=True)
-    def setup_backup_cron_interval(
-        self, monkeypatch_module: pytest.MonkeyPatch
-    ) -> None:
+    def setup_backup_cron_interval(self, monkeypatch: pytest.MonkeyPatch) -> None:
         """Set backup cron interval to 1 second for faster testing."""
-        monkeypatch_module.setattr(settings, "cron_backup_interval", 1)
+        monkeypatch.setattr(settings, "cron_backup_interval", 1)
 
     def get_backup_nrows(self, s3_testing_bucket: S3TestingBucket) -> int:
         """Helper to get number of rows in backup database."""
@@ -121,7 +119,7 @@ def test_sqlite_backup(
         client: TestClient,
         s3_testing_bucket: S3TestingBucket,
         tmpdir_factory: pytest.TempdirFactory,
-        monkeypatch_module: pytest.MonkeyPatch,
+        monkeypatch: pytest.MonkeyPatch,
     ) -> None:
         """Test the backup and restore functionality of the
SQLiteRDSJobQueue.""" # Startup: no backup present @@ -151,7 +149,7 @@ def test_sqlite_backup( s3_client=s3_testing_bucket.s3_client, s3_bucket=s3_testing_bucket.bucket_name, ) - monkeypatch_module.setitem( + monkeypatch.setitem( workerfacing_app.dependency_overrides, # type: ignore queue_dep, lambda: new_queue, From c4bbbd51ab73ef59d87383df6fc90c1cfc57c81c Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Thu, 5 Feb 2026 10:50:55 +0100 Subject: [PATCH 07/22] temp --- .github/workflows/code-checks.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml index 9553e78..10bbe2d 100644 --- a/.github/workflows/code-checks.yaml +++ b/.github/workflows/code-checks.yaml @@ -51,7 +51,7 @@ jobs: aws-region: eu-central-1 - name: Run tests run: | - poetry run pytest -m "aws or not(aws)" --junitxml=pytest.xml --cov-report=term-missing --cov=workerfacing_api | tee pytest-coverage.txt + poetry run pytest -m "aws or not(aws)" --durations=20 --junitxml=pytest.xml --cov-report=term-missing --cov=workerfacing_api | tee pytest-coverage.txt echo "test_exit_code=${PIPESTATUS[0]}" >> $GITHUB_ENV - name: Coverage comment uses: MishaKav/pytest-coverage-comment@main From ce15b4c88f8c88664bd3875035fd079e78811a1d Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Thu, 5 Feb 2026 17:17:23 +0100 Subject: [PATCH 08/22] task group + skip some tests --- tests/integration/conftest.py | 2 ++ workerfacing_api/main.py | 11 +++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index b50617e..656b4e3 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -87,6 +87,8 @@ def queue( s3_bucket=s3_bucket, ) elif request.param == "aws-queue": + if isinstance(base_filesystem, LocalFilesystem): + pytest.skip("Only testing RDS queue in combination with S3 filesystem") rds_testing_instance.create() return 
RDSJobQueue(rds_testing_instance.db_url, retry_different=retry_different) else: diff --git a/workerfacing_api/main.py b/workerfacing_api/main.py index e790699..957f0b3 100644 --- a/workerfacing_api/main.py +++ b/workerfacing_api/main.py @@ -41,12 +41,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: )() assert isinstance(queue, RDSJobQueue) queue.create() - task_failed_jobs = asyncio.create_task(cron_handle_timeouts(queue)) - task_backup = asyncio.create_task(cron_backup_database(queue)) - yield - task_failed_jobs.cancel() - task_backup.cancel() - if queue.backup(): # final backup on shutdown + async with asyncio.TaskGroup() as tg: # cancels and waits on exit + tg.create_task(cron_handle_timeouts(queue)) + tg.create_task(cron_backup_database(queue)) + yield + if queue.backup(): print("Created final backup on shutdown.") From 328b26b569c55d61f4aa049ebe206922d694cf55 Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Thu, 5 Feb 2026 18:01:40 +0100 Subject: [PATCH 09/22] correct (no TaskGroup) --- workerfacing_api/main.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/workerfacing_api/main.py b/workerfacing_api/main.py index 957f0b3..2ee4dd7 100644 --- a/workerfacing_api/main.py +++ b/workerfacing_api/main.py @@ -41,10 +41,12 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: )() assert isinstance(queue, RDSJobQueue) queue.create() - async with asyncio.TaskGroup() as tg: # cancels and waits on exit - tg.create_task(cron_handle_timeouts(queue)) - tg.create_task(cron_backup_database(queue)) - yield + task_failed_jobs = asyncio.create_task(cron_handle_timeouts(queue)) + task_backup = asyncio.create_task(cron_backup_database(queue)) + yield + task_failed_jobs.cancel() + task_backup.cancel() + await asyncio.gather(task_failed_jobs, task_backup, return_exceptions=True) if queue.backup(): print("Created final backup on shutdown.") From 2ad4d0acfe362db9306b2aab6cf5e098245822a1 Mon Sep 17 00:00:00 2001 From: Copilot 
<198982749+Copilot@users.noreply.github.com> Date: Thu, 5 Feb 2026 18:23:03 +0100 Subject: [PATCH 10/22] Use Python sqlite3.backup() API instead of CLI subprocess (#82) --- workerfacing_api/core/queue.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/workerfacing_api/core/queue.py b/workerfacing_api/core/queue.py index 5bb25bf..34381d1 100644 --- a/workerfacing_api/core/queue.py +++ b/workerfacing_api/core/queue.py @@ -3,7 +3,7 @@ import json import os import pickle -import subprocess +import sqlite3 import tempfile import threading import time @@ -603,8 +603,10 @@ def backup(self) -> bool: with tempfile.TemporaryDirectory() as temp_dir: tmp_backup_path = os.path.join(temp_dir, "backup.db") tmp_gzip_path = os.path.join(temp_dir, "backup.db.gz") - backup_cmd = ["sqlite3", self.db_path, f".backup {tmp_backup_path}"] - subprocess.run(backup_cmd, text=True, check=True) + with sqlite3.connect(self.db_path) as source_conn: + with sqlite3.connect(tmp_backup_path) as backup_conn: + source_conn.backup(backup_conn) + with open(tmp_backup_path, "rb") as f_in: with gzip.open(tmp_gzip_path, "wb") as f_out: f_out.writelines(f_in) From 3c92c54b478e9e17fa4fb22ea66cce9492414c32 Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Thu, 5 Feb 2026 18:28:07 +0100 Subject: [PATCH 11/22] reformat --- workerfacing_api/core/queue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/workerfacing_api/core/queue.py b/workerfacing_api/core/queue.py index 34381d1..0a60fa3 100644 --- a/workerfacing_api/core/queue.py +++ b/workerfacing_api/core/queue.py @@ -606,7 +606,7 @@ def backup(self) -> bool: with sqlite3.connect(self.db_path) as source_conn: with sqlite3.connect(tmp_backup_path) as backup_conn: source_conn.backup(backup_conn) - + with open(tmp_backup_path, "rb") as f_in: with gzip.open(tmp_gzip_path, "wb") as f_out: f_out.writelines(f_in) From 351c2cdcd624b330fc069ddce6484b8e1960aad1 Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Thu, 5 Feb 2026 
18:55:09 +0100 Subject: [PATCH 12/22] fix tests, hopefully --- tests/integration/endpoints/test_jobs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/endpoints/test_jobs.py b/tests/integration/endpoints/test_jobs.py index 81f193f..4650d9f 100644 --- a/tests/integration/endpoints/test_jobs.py +++ b/tests/integration/endpoints/test_jobs.py @@ -10,7 +10,7 @@ from tests.integration.endpoints.conftest import EndpointParams, _TestEndpoint from workerfacing_api.core.filesystem import FileSystem, LocalFilesystem, S3Filesystem -from workerfacing_api.core.queue import RDSJobQueue, SQLiteRDSJobQueue +from workerfacing_api.core.queue import RDSJobQueue from workerfacing_api.crud import job_tracking from workerfacing_api.exceptions import JobDeletedException from workerfacing_api.schemas.queue_jobs import EnvironmentTypes, SubmittedJob @@ -238,7 +238,7 @@ def test_job_files_post( params={"type": "output", "base_path": "test"}, ) assert res.status_code == 201 - if isinstance(queue, SQLiteRDSJobQueue): + if isinstance(base_filesystem, LocalFilesystem): req_base = client else: req_base = requests # type: ignore From 2692834f6c2611eecbe9eeb3a43f27c4fed49a87 Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Thu, 5 Feb 2026 19:07:52 +0100 Subject: [PATCH 13/22] print -> logging --- workerfacing_api/main.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/workerfacing_api/main.py b/workerfacing_api/main.py index 2ee4dd7..1e438a8 100644 --- a/workerfacing_api/main.py +++ b/workerfacing_api/main.py @@ -1,4 +1,5 @@ import asyncio +import logging from contextlib import asynccontextmanager from typing import AsyncGenerator @@ -7,6 +8,8 @@ dotenv.load_dotenv() +logger = logging.getLogger(__name__) + from workerfacing_api import dependencies, settings, tags from workerfacing_api.core.queue import RDSJobQueue from workerfacing_api.endpoints import access, files, jobs, jobs_post @@ -15,14 +18,14 @@ async def 
cron_handle_timeouts(queue: RDSJobQueue) -> None: while True: await asyncio.sleep(settings.cron_timeout_check_interval) - print("Silent fails check: starting...") + logger.info("Silent fails check: starting...") try: max_retries = settings.max_retries timeout_failure = settings.timeout_failure n_retry, n_fail = queue.handle_timeouts(max_retries, timeout_failure) - print(f"Silent fails check: {n_retry} re-queued, {n_fail} failed.") + logger.info(f"Silent fails check: {n_retry} re-queued, {n_fail} failed.") except Exception as e: - print(f"Silent fails check: failed with {e}") + logger.error(f"Silent fails check: failed with {e}") async def cron_backup_database(queue: RDSJobQueue) -> None: @@ -31,7 +34,7 @@ async def cron_backup_database(queue: RDSJobQueue) -> None: # Run backup in thread pool to avoid blocking event loop; # Fine instead of making backup async since it runs infrequently. if await asyncio.to_thread(queue.backup): - print("Backed up database.") + logger.info("Backed up database.") @asynccontextmanager @@ -48,7 +51,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: task_backup.cancel() await asyncio.gather(task_failed_jobs, task_backup, return_exceptions=True) if queue.backup(): - print("Created final backup on shutdown.") + logger.info("Created final backup on shutdown.") workerfacing_app = FastAPI(openapi_tags=tags.tags_metadata, lifespan=lifespan) From fd2b701f1bf0fd1a63e0d30cac73a265495c04e8 Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Thu, 5 Feb 2026 19:12:40 +0100 Subject: [PATCH 14/22] fault-tolerant cron jobs --- workerfacing_api/main.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/workerfacing_api/main.py b/workerfacing_api/main.py index 1e438a8..326a95a 100644 --- a/workerfacing_api/main.py +++ b/workerfacing_api/main.py @@ -31,10 +31,14 @@ async def cron_handle_timeouts(queue: RDSJobQueue) -> None: async def cron_backup_database(queue: RDSJobQueue) -> None: while True: await 
asyncio.sleep(settings.cron_backup_interval) + logger.info("Database backup: starting...") # Run backup in thread pool to avoid blocking event loop; # Fine instead of making backup async since it runs infrequently. - if await asyncio.to_thread(queue.backup): - logger.info("Backed up database.") + try: + if await asyncio.to_thread(queue.backup): + logger.info("Backed up database.") + except Exception as e: + logger.error(f"Database backup failed with {e}") @asynccontextmanager From bb25cda0fa45ca106d0e45ed2e5d268893581ee7 Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Thu, 5 Feb 2026 22:24:59 +0100 Subject: [PATCH 15/22] working rds tests? --- tests/conftest.py | 9 +++------ workerfacing_api/main.py | 4 ++-- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 43790e4..9b64a48 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,7 +34,7 @@ def create(self) -> None: self.add_ingress_rule() self.db_url = self.create_db_url() self.engine = self.get_engine() - self.delete_db_tables() + self.cleanup() def get_engine(self) -> Engine: for _ in range(5): @@ -73,7 +73,7 @@ def add_ingress_rule(self) -> None: else: raise e - def delete_db_tables(self) -> None: + def cleanup(self) -> None: metadata = MetaData() engine = self.engine metadata.reflect(engine) @@ -132,14 +132,11 @@ def create_db_url(self) -> str: address = response["DBInstances"][0]["Endpoint"]["Address"] return f"postgresql://{user}:{password}@{address}:5432/{self.db_name}" - def cleanup(self) -> None: - self.delete_db_tables() - self.ec2_client.revoke_security_group_ingress(**self.vpc_sg_rule_params) - def delete(self) -> None: # never used (AWS tests skipped) if not hasattr(self, "rds_client"): return + self.ec2_client.revoke_security_group_ingress(**self.vpc_sg_rule_params) self.rds_client.delete_db_instance( DBInstanceIdentifier=self.db_name, SkipFinalSnapshot=True, diff --git a/workerfacing_api/main.py b/workerfacing_api/main.py index 
326a95a..c4be533 100644 --- a/workerfacing_api/main.py +++ b/workerfacing_api/main.py @@ -17,7 +17,6 @@ async def cron_handle_timeouts(queue: RDSJobQueue) -> None: while True: - await asyncio.sleep(settings.cron_timeout_check_interval) logger.info("Silent fails check: starting...") try: max_retries = settings.max_retries @@ -26,11 +25,11 @@ async def cron_handle_timeouts(queue: RDSJobQueue) -> None: logger.info(f"Silent fails check: {n_retry} re-queued, {n_fail} failed.") except Exception as e: logger.error(f"Silent fails check: failed with {e}") + await asyncio.sleep(settings.cron_timeout_check_interval) async def cron_backup_database(queue: RDSJobQueue) -> None: while True: - await asyncio.sleep(settings.cron_backup_interval) logger.info("Database backup: starting...") # Run backup in thread pool to avoid blocking event loop; # Fine instead of making backup async since it runs infrequently. @@ -39,6 +38,7 @@ async def cron_backup_database(queue: RDSJobQueue) -> None: logger.info("Backed up database.") except Exception as e: logger.error(f"Database backup failed with {e}") + await asyncio.sleep(settings.cron_backup_interval) @asynccontextmanager From f3dba65922a40b4bd4302690f33009366959060b Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Fri, 6 Feb 2026 02:13:24 +0100 Subject: [PATCH 16/22] temp log failure --- tests/conftest.py | 15 ++++++++++++++- tests/integration/test_main.py | 4 +++- workerfacing_api/core/queue.py | 4 ++++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 9b64a48..0bd4a58 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -73,6 +73,19 @@ def add_ingress_rule(self) -> None: else: raise e + def remove_ingress_rules(self) -> None: + # cleans up earlier tests too (in case of failures) + security_groups = self.ec2_client.describe_security_groups( + GroupNames=[self.vpc_sg_rule_params["GroupName"]] + ) + for sg in security_groups["SecurityGroups"]: + for rule in sg["IpPermissions"]: + 
if rule.get("FromPort") == 5432 and rule.get("ToPort") == 5432: + self.ec2_client.revoke_security_group_ingress( + GroupId=sg["GroupId"], + IpPermissions=[rule], # type: ignore + ) + def cleanup(self) -> None: metadata = MetaData() engine = self.engine @@ -136,7 +149,7 @@ def delete(self) -> None: # never used (AWS tests skipped) if not hasattr(self, "rds_client"): return - self.ec2_client.revoke_security_group_ingress(**self.vpc_sg_rule_params) + self.remove_ingress_rules() self.rds_client.delete_db_instance( DBInstanceIdentifier=self.db_name, SkipFinalSnapshot=True, diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py index 8653b16..384919e 100644 --- a/tests/integration/test_main.py +++ b/tests/integration/test_main.py @@ -66,7 +66,9 @@ def test_handle_timeouts( # Let timeout time.sleep(4) - assert queue.get_job(job_id).status == JobStates.queued.value + assert queue.get_job(job_id).status == JobStates.queued.value, [ + el.__dict__ for el in queue.get_all() + ] # Pull again assert len(client.get("/jobs", params=get_params).json()) == 1 diff --git a/workerfacing_api/core/queue.py b/workerfacing_api/core/queue.py index 0a60fa3..19b6b08 100644 --- a/workerfacing_api/core/queue.py +++ b/workerfacing_api/core/queue.py @@ -350,6 +350,10 @@ def create(self) -> None: def delete(self) -> None: Base.metadata.drop_all(self.engine) + def get_all(self) -> Any: + with Session(self.engine) as session: + return session.query(QueuedJob).all() + def enqueue(self, job: SubmittedJob) -> None: with Session(self.engine) as session: session.add( From 6e77253c68e62cc4fc3e769419644e6e95dfa85e Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Fri, 6 Feb 2026 03:20:39 +0100 Subject: [PATCH 17/22] verbose test output --- .github/workflows/code-checks.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml index 10bbe2d..cbf5699 100644 --- a/.github/workflows/code-checks.yaml +++ 
b/.github/workflows/code-checks.yaml @@ -51,7 +51,7 @@ jobs: aws-region: eu-central-1 - name: Run tests run: | - poetry run pytest -m "aws or not(aws)" --durations=20 --junitxml=pytest.xml --cov-report=term-missing --cov=workerfacing_api | tee pytest-coverage.txt + poetry run pytest -m "aws or not(aws)" -vv --durations=20 --junitxml=pytest.xml --cov-report=term-missing --cov=workerfacing_api | tee pytest-coverage.txt echo "test_exit_code=${PIPESTATUS[0]}" >> $GITHUB_ENV - name: Coverage comment uses: MishaKav/pytest-coverage-comment@main From d14a112b93dbd209b2284e0109db32e1e916b2d4 Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Fri, 6 Feb 2026 03:52:34 +0100 Subject: [PATCH 18/22] deprecate SQS tests --- .github/workflows/code-checks.yaml | 2 +- pyproject.toml | 3 ++- tests/unit/core/test_queue.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml index cbf5699..643ee4d 100644 --- a/.github/workflows/code-checks.yaml +++ b/.github/workflows/code-checks.yaml @@ -51,7 +51,7 @@ jobs: aws-region: eu-central-1 - name: Run tests run: | - poetry run pytest -m "aws or not(aws)" -vv --durations=20 --junitxml=pytest.xml --cov-report=term-missing --cov=workerfacing_api | tee pytest-coverage.txt + poetry run pytest -m "aws and not(deprecated) or not(aws)" -vv --durations=20 --junitxml=pytest.xml --cov-report=term-missing --cov=workerfacing_api | tee pytest-coverage.txt echo "test_exit_code=${PIPESTATUS[0]}" >> $GITHUB_ENV - name: Coverage comment uses: MishaKav/pytest-coverage-comment@main diff --git a/pyproject.toml b/pyproject.toml index 4ee6cd6..aff5cbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,8 @@ serve = "scripts.serve:main" [tool.pytest.ini_options] markers = [ - "aws: requires aws credentials" + "aws: requires aws credentials", + "deprecated: tests for deprecated features", ] addopts = "-m 'not aws'" diff --git a/tests/unit/core/test_queue.py 
b/tests/unit/core/test_queue.py index cbd72c1..440750a 100644 --- a/tests/unit/core/test_queue.py +++ b/tests/unit/core/test_queue.py @@ -171,6 +171,7 @@ def base_queue( base_queue.delete() +@pytest.mark.deprecated class TestSQSQueue(_TestJobQueue): @pytest.fixture( params=[True, pytest.param(False, marks=pytest.mark.aws)], scope="class" From 5a9349d3f1c80a72bcb467ab0585ad6d7e2a4f59 Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Fri, 6 Feb 2026 11:39:17 +0100 Subject: [PATCH 19/22] better debugging --- .github/workflows/code-checks.yaml | 2 +- tests/integration/test_main.py | 36 ++++++++++++++++++++++-------- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml index 643ee4d..e128928 100644 --- a/.github/workflows/code-checks.yaml +++ b/.github/workflows/code-checks.yaml @@ -51,7 +51,7 @@ jobs: aws-region: eu-central-1 - name: Run tests run: | - poetry run pytest -m "aws and not(deprecated) or not(aws)" -vv --durations=20 --junitxml=pytest.xml --cov-report=term-missing --cov=workerfacing_api | tee pytest-coverage.txt + poetry run pytest -m "aws and not(deprecated) or not(aws)" -vv -s --durations=20 --junitxml=pytest.xml --cov-report=term-missing --cov=workerfacing_api | tee pytest-coverage.txt echo "test_exit_code=${PIPESTATUS[0]}" >> $GITHUB_ENV - name: Coverage comment uses: MishaKav/pytest-coverage-comment@main diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py index 384919e..aa17d2b 100644 --- a/tests/integration/test_main.py +++ b/tests/integration/test_main.py @@ -44,15 +44,20 @@ def test_handle_timeouts( base_job: SubmittedJob, client: TestClient, ) -> None: + job_id = base_job.job.meta.job_id with client: # Push the job queue.enqueue(base_job) + job = queue.get_job(job_id) + assert job.status == JobStates.queued.value + assert job.num_retries == 0 # Pull the job get_params = {"memory": 1} - job_id = base_job.job.meta.job_id assert 
len(client.get("/jobs", params=get_params).json()) == 1 - assert queue.get_job(job_id).status == JobStates.pulled.value + job = queue.get_job(job_id) + assert job.status == JobStates.pulled.value + assert job.num_retries == 0 # Job kept alive by periodic status updates for _ in range(4): @@ -62,22 +67,35 @@ def test_handle_timeouts( params={"status": "running", "runtime_details": "Processing..."}, ) assert len(client.get("/jobs", params=get_params).json()) == 0 - assert queue.get_job(job_id).status == JobStates.running.value + job = queue.get_job(job_id) + assert job.status == JobStates.running.value + assert job.num_retries == 0 # Let timeout time.sleep(4) - assert queue.get_job(job_id).status == JobStates.queued.value, [ - el.__dict__ for el in queue.get_all() - ] + job = queue.get_job(job_id) + if not job.status == JobStates.queued.value: + all_jobs = queue.get_all() + print(f"N_jobs={all_jobs}") + for job_ in all_jobs: + print( + f"Job {job_.job_id}: status={job_.status}, num_retries={job_.num_retries}" + ) + print(job.__dict__) + assert job.status == JobStates.queued.value + assert job.num_retries == 0 # Pull again assert len(client.get("/jobs", params=get_params).json()) == 1 - assert queue.get_job(job_id).status == JobStates.pulled.value - assert queue.get_job(job_id).num_retries == 1 + job = queue.get_job(job_id) + assert job.status == JobStates.pulled.value + assert job.num_retries == 1 # Let timeout and fail time.sleep(4) - assert queue.get_job(job_id).status == JobStates.error.value + job = queue.get_job(job_id) + assert job.status == JobStates.error.value + assert job.num_retries == 1 class TestCronBackupDatabase: From fad39608834a28a3cd4bd692a931e78682f1570f Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Fri, 6 Feb 2026 12:00:50 +0100 Subject: [PATCH 20/22] now real fix --- .github/workflows/code-checks.yaml | 2 +- tests/integration/test_main.py | 10 +--------- workerfacing_api/core/queue.py | 7 +++++-- 3 files changed, 7 insertions(+), 12 deletions(-) 
diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml index e128928..643ee4d 100644 --- a/.github/workflows/code-checks.yaml +++ b/.github/workflows/code-checks.yaml @@ -51,7 +51,7 @@ jobs: aws-region: eu-central-1 - name: Run tests run: | - poetry run pytest -m "aws and not(deprecated) or not(aws)" -vv -s --durations=20 --junitxml=pytest.xml --cov-report=term-missing --cov=workerfacing_api | tee pytest-coverage.txt + poetry run pytest -m "aws and not(deprecated) or not(aws)" -vv --durations=20 --junitxml=pytest.xml --cov-report=term-missing --cov=workerfacing_api | tee pytest-coverage.txt echo "test_exit_code=${PIPESTATUS[0]}" >> $GITHUB_ENV - name: Coverage comment uses: MishaKav/pytest-coverage-comment@main diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py index aa17d2b..ddc74c1 100644 --- a/tests/integration/test_main.py +++ b/tests/integration/test_main.py @@ -74,16 +74,8 @@ def test_handle_timeouts( # Let timeout time.sleep(4) job = queue.get_job(job_id) - if not job.status == JobStates.queued.value: - all_jobs = queue.get_all() - print(f"N_jobs={all_jobs}") - for job_ in all_jobs: - print( - f"Job {job_.job_id}: status={job_.status}, num_retries={job_.num_retries}" - ) - print(job.__dict__) assert job.status == JobStates.queued.value - assert job.num_retries == 0 + assert job.num_retries == 1 # Pull again assert len(client.get("/jobs", params=get_params).json()) == 1 diff --git a/workerfacing_api/core/queue.py b/workerfacing_api/core/queue.py index 19b6b08..81ed92d 100644 --- a/workerfacing_api/core/queue.py +++ b/workerfacing_api/core/queue.py @@ -525,7 +525,11 @@ def handle_timeouts( < time_now - datetime.timedelta(seconds=timeout_failure) ), ) - jobs_retry = jobs_timeout.filter(QueuedJob.num_retries < max_retries) + # Evaluate both queries before modifying any jobs to avoid race condition + jobs_retry = jobs_timeout.filter(QueuedJob.num_retries < max_retries).all() + jobs_failed = 
jobs_timeout.filter( + QueuedJob.num_retries >= max_retries + ).all() for job in jobs_retry: # TODO: increase priority? job.num_retries += 1 @@ -540,7 +544,6 @@ def handle_timeouts( except JobDeletedException: # job probably deleted by user, skip updating status pass - jobs_failed = jobs_timeout.filter(QueuedJob.num_retries >= max_retries) for job in jobs_failed: try: self.update_job_status( From 201696d34b66b5813b145e9069c93e2d71d014f8 Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Fri, 6 Feb 2026 16:54:08 +0100 Subject: [PATCH 21/22] Bigger time intervals to avoid flaky tests --- .github/workflows/code-checks.yaml | 2 +- tests/integration/test_main.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml index 643ee4d..b071fb5 100644 --- a/.github/workflows/code-checks.yaml +++ b/.github/workflows/code-checks.yaml @@ -51,7 +51,7 @@ jobs: aws-region: eu-central-1 - name: Run tests run: | - poetry run pytest -m "aws and not(deprecated) or not(aws)" -vv --durations=20 --junitxml=pytest.xml --cov-report=term-missing --cov=workerfacing_api | tee pytest-coverage.txt + poetry run pytest -m "aws and not(deprecated) or not(aws)" --junitxml=pytest.xml --cov-report=term-missing --cov=workerfacing_api | tee pytest-coverage.txt echo "test_exit_code=${PIPESTATUS[0]}" >> $GITHUB_ENV - name: Coverage comment uses: MishaKav/pytest-coverage-comment@main diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py index ddc74c1..1468ec8 100644 --- a/tests/integration/test_main.py +++ b/tests/integration/test_main.py @@ -25,8 +25,8 @@ def client() -> TestClient: class TestCronHandleTimeouts: @pytest.fixture(autouse=True) def setup_timeout_failure(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Set timeout_failure to 2 seconds for faster testing.""" - monkeypatch.setattr(settings, "timeout_failure", 2) + """Set timeout_failure to 5 seconds for faster testing (but 
sufficient margin).""" + monkeypatch.setattr(settings, "timeout_failure", 5) @pytest.fixture(autouse=True) def setup_max_retries(self, monkeypatch: pytest.MonkeyPatch) -> None: @@ -60,8 +60,7 @@ def test_handle_timeouts( assert job.num_retries == 0 # Job kept alive by periodic status updates - for _ in range(4): - time.sleep(1) + for _ in range(5): client.put( f"/jobs/{job_id}/status", params={"status": "running", "runtime_details": "Processing..."}, @@ -70,9 +69,10 @@ def test_handle_timeouts( job = queue.get_job(job_id) assert job.status == JobStates.running.value assert job.num_retries == 0 + time.sleep(2) - # Let timeout - time.sleep(4) + # Let timeout (wait longer than timeout_failure) + time.sleep(10) job = queue.get_job(job_id) assert job.status == JobStates.queued.value assert job.num_retries == 1 @@ -83,8 +83,8 @@ def test_handle_timeouts( assert job.status == JobStates.pulled.value assert job.num_retries == 1 - # Let timeout and fail - time.sleep(4) + # Let timeout and fail (wait longer than timeout_failure) + time.sleep(10) job = queue.get_job(job_id) assert job.status == JobStates.error.value assert job.num_retries == 1 From fe04bad5c56fb7b6ff9ed51b4ea290df1be79696 Mon Sep 17 00:00:00 2001 From: nolan1999 Date: Sat, 28 Feb 2026 22:42:39 +0100 Subject: [PATCH 22/22] bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index aff5cbc..87fc4df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "workerfacing-api" -version = "0.1.0" +version = "0.2.0" description = "Worker-facing API of DECODE OpenCloud." authors = ["Arthur Jaques "] readme = "README.md"