diff --git a/AGENTS.md b/AGENTS.md index cdf8497e1..eb348b291 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -17,6 +17,7 @@ ## Coding Style & Naming Conventions - Python targets 3.9+ with 4-space indentation and max line length of 99 (see `ruff.toml`; `E501` is ignored but keep lines readable). - Imports are sorted via Ruff’s isort settings (`dstack` treated as first-party). +- Keep primary/public functions before local helper functions in a module section. - Prefer pydantic-style models in `core/models`. - Tests use `test_*.py` modules and `test_*` functions; fixtures live near usage. diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 209679f0e..03a54ccf2 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -167,8 +167,9 @@ async def lifespan(app: FastAPI): pipeline_manager = None if settings.SERVER_BACKGROUND_PROCESSING_ENABLED: scheduler = start_scheduled_tasks() - pipeline_manager = start_pipeline_tasks() - app.state.pipeline_manager = pipeline_manager + if core_settings.FeatureFlags.PIPELINE_PROCESSING_ENABLED: + pipeline_manager = start_pipeline_tasks() + app.state.pipeline_manager = pipeline_manager else: logger.info("Background processing is disabled") PROBES_SCHEDULER.start() diff --git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py index 355e04247..01feb958d 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py @@ -2,10 +2,10 @@ from dstack._internal.server.background.pipeline_tasks.base import Pipeline from dstack._internal.server.background.pipeline_tasks.compute_groups import ComputeGroupPipeline +from dstack._internal.server.background.pipeline_tasks.gateways import GatewayPipeline from dstack._internal.server.background.pipeline_tasks.placement_groups import ( PlacementGroupPipeline, ) -from 
dstack._internal.settings import FeatureFlags from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -13,12 +13,11 @@ class PipelineManager: def __init__(self) -> None: - self._pipelines: list[Pipeline] = [] - if FeatureFlags.PIPELINE_PROCESSING_ENABLED: - self._pipelines += [ - ComputeGroupPipeline(), - PlacementGroupPipeline(), - ] + self._pipelines: list[Pipeline] = [ + ComputeGroupPipeline(), + GatewayPipeline(), + PlacementGroupPipeline(), + ] self._hinter = PipelineHinter(self._pipelines) def start(self): diff --git a/src/dstack/_internal/server/background/pipeline_tasks/base.py b/src/dstack/_internal/server/background/pipeline_tasks/base.py index 30be480bf..9d016934c 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/base.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/base.py @@ -19,6 +19,10 @@ @dataclass class PipelineItem: + """ + Pipelines can work with this class or its subclass if the worker needs to access extra attributes. + """ + __tablename__: str id: uuid.UUID lock_expires_at: datetime @@ -26,7 +30,14 @@ class PipelineItem: prev_lock_expired: bool +ItemT = TypeVar("ItemT", bound=PipelineItem) + + class PipelineModel(Protocol): + """ + Heartbeater can work with any DB model implementing this protocol. 
+ """ + __tablename__: str __mapper__: ClassVar[Any] __table__: ClassVar[Any] @@ -39,7 +50,7 @@ class PipelineError(Exception): pass -class Pipeline(ABC): +class Pipeline(Generic[ItemT], ABC): def __init__( self, workers_num: int, @@ -57,7 +68,7 @@ def __init__( self._min_processing_interval = min_processing_interval self._lock_timeout = lock_timeout self._heartbeat_trigger = heartbeat_trigger - self._queue = asyncio.Queue[PipelineItem](maxsize=self._queue_maxsize) + self._queue = asyncio.Queue[ItemT](maxsize=self._queue_maxsize) self._tasks: list[asyncio.Task] = [] self._running = False self._shutdown = False @@ -119,27 +130,24 @@ def hint_fetch_model_name(self) -> str: @property @abstractmethod - def _heartbeater(self) -> "Heartbeater": + def _heartbeater(self) -> "Heartbeater[ItemT]": pass @property @abstractmethod - def _fetcher(self) -> "Fetcher": + def _fetcher(self) -> "Fetcher[ItemT]": pass @property @abstractmethod - def _workers(self) -> Sequence["Worker"]: + def _workers(self) -> Sequence["Worker[ItemT]"]: pass -ModelT = TypeVar("ModelT", bound=PipelineModel) - - -class Heartbeater(Generic[ModelT]): +class Heartbeater(Generic[ItemT]): def __init__( self, - model_type: type[ModelT], + model_type: type[PipelineModel], lock_timeout: timedelta, heartbeat_trigger: timedelta, heartbeat_delay: float = 1.0, @@ -147,7 +155,7 @@ def __init__( self._model_type = model_type self._lock_timeout = lock_timeout self._hearbeat_margin = heartbeat_trigger - self._items: dict[uuid.UUID, PipelineItem] = {} + self._items: dict[uuid.UUID, ItemT] = {} self._untrack_lock = asyncio.Lock() self._heartbeat_delay = heartbeat_delay self._running = False @@ -164,10 +172,10 @@ async def start(self): def stop(self): self._running = False - async def track(self, item: PipelineItem): + async def track(self, item: ItemT): self._items[item.id] = item - async def untrack(self, item: PipelineItem): + async def untrack(self, item: ItemT): async with self._untrack_lock: tracked = 
self._items.get(item.id) # Prevent expired fetch iteration to unlock item processed by new iteration. @@ -175,7 +183,7 @@ async def untrack(self, item: PipelineItem): del self._items[item.id] async def heartbeat(self): - items_to_update: list[PipelineItem] = [] + items_to_update: list[ItemT] = [] now = get_current_datetime() items = list(self._items.values()) failed_to_heartbeat_count = 0 @@ -227,16 +235,16 @@ async def heartbeat(self): ) -class Fetcher(ABC): +class Fetcher(Generic[ItemT], ABC): _DEFAULT_FETCH_DELAYS = [0.5, 1, 2, 5] def __init__( self, - queue: asyncio.Queue[PipelineItem], + queue: asyncio.Queue[ItemT], queue_desired_minsize: int, min_processing_interval: timedelta, lock_timeout: timedelta, - heartbeater: Heartbeater, + heartbeater: Heartbeater[ItemT], queue_check_delay: float = 1.0, fetch_delays: Optional[list[float]] = None, ) -> None: @@ -289,7 +297,7 @@ def hint(self): self._fetch_event.set() @abstractmethod - async def fetch(self, limit: int) -> list[PipelineItem]: + async def fetch(self, limit: int) -> list[ItemT]: pass def _next_fetch_delay(self, empty_fetch_count: int) -> float: @@ -298,11 +306,11 @@ def _next_fetch_delay(self, empty_fetch_count: int) -> float: return next_delay * (1 + jitter) -class Worker(ABC): +class Worker(Generic[ItemT], ABC): def __init__( self, - queue: asyncio.Queue[PipelineItem], - heartbeater: Heartbeater, + queue: asyncio.Queue[ItemT], + heartbeater: Heartbeater[ItemT], ) -> None: self._queue = queue self._heartbeater = heartbeater @@ -325,7 +333,7 @@ def stop(self): self._running = False @abstractmethod - async def process(self, item: PipelineItem): + async def process(self, item: ItemT): pass diff --git a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py index 685c5205a..938c6013c 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py +++ 
b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py @@ -25,7 +25,7 @@ from dstack._internal.server.models import ComputeGroupModel, InstanceModel, ProjectModel from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services.compute_groups import compute_group_model_to_compute_group -from dstack._internal.server.services.instances import switch_instance_status +from dstack._internal.server.services.instances import emit_instance_status_change_event from dstack._internal.server.services.locking import get_locker from dstack._internal.utils.common import get_current_datetime, run_async from dstack._internal.utils.logging import get_logger @@ -36,7 +36,7 @@ TERMINATION_RETRY_MAX_DURATION = timedelta(minutes=15) -class ComputeGroupPipeline(Pipeline): +class ComputeGroupPipeline(Pipeline[PipelineItem]): def __init__( self, workers_num: int = 10, @@ -54,7 +54,7 @@ def __init__( lock_timeout=lock_timeout, heartbeat_trigger=heartbeat_trigger, ) - self.__heartbeater = Heartbeater[ComputeGroupModel]( + self.__heartbeater = Heartbeater[PipelineItem]( model_type=ComputeGroupModel, lock_timeout=self._lock_timeout, heartbeat_trigger=self._heartbeat_trigger, @@ -76,11 +76,11 @@ def hint_fetch_model_name(self) -> str: return ComputeGroupModel.__name__ @property - def _heartbeater(self) -> Heartbeater: + def _heartbeater(self) -> Heartbeater[PipelineItem]: return self.__heartbeater @property - def _fetcher(self) -> Fetcher: + def _fetcher(self) -> Fetcher[PipelineItem]: return self.__fetcher @property @@ -88,14 +88,14 @@ def _workers(self) -> Sequence["ComputeGroupWorker"]: return self.__workers -class ComputeGroupFetcher(Fetcher): +class ComputeGroupFetcher(Fetcher[PipelineItem]): def __init__( self, queue: asyncio.Queue[PipelineItem], queue_desired_minsize: int, min_processing_interval: timedelta, lock_timeout: timedelta, - heartbeater: Heartbeater[ComputeGroupModel], + heartbeater: Heartbeater[PipelineItem], 
queue_check_delay: float = 1.0, ) -> None: super().__init__( @@ -161,11 +161,11 @@ async def fetch(self, limit: int) -> list[PipelineItem]: return items -class ComputeGroupWorker(Worker): +class ComputeGroupWorker(Worker[PipelineItem]): def __init__( self, queue: asyncio.Queue[PipelineItem], - heartbeater: Heartbeater[ComputeGroupModel], + heartbeater: Heartbeater[PipelineItem], ) -> None: super().__init__( queue=queue, @@ -235,7 +235,12 @@ async def process(self, item: PipelineItem): .values(**terminate_result.instances_update_map) ) for instance_model in compute_group_model.instances: - switch_instance_status(session, instance_model, InstanceStatus.TERMINATED) + emit_instance_status_change_event( + session=session, + instance_model=instance_model, + old_status=instance_model.status, + new_status=InstanceStatus.TERMINATED, + ) @dataclass diff --git a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py new file mode 100644 index 000000000..c64cd719a --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py @@ -0,0 +1,548 @@ +import asyncio +import uuid +from dataclasses import dataclass, field +from datetime import timedelta +from typing import Optional, Sequence + +from sqlalchemy import delete, or_, select, update +from sqlalchemy.orm import joinedload, load_only + +from dstack._internal.core.backends.base.compute import ComputeWithGatewaySupport +from dstack._internal.core.errors import BackendError, BackendNotAvailable +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.gateways import GatewayStatus +from dstack._internal.server.background.pipeline_tasks.base import ( + Fetcher, + Heartbeater, + Pipeline, + PipelineItem, + UpdateMap, + Worker, + get_processed_update_map, + get_unlock_update_map, +) +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models 
import ( + BackendModel, + GatewayComputeModel, + GatewayModel, + ProjectModel, +) +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services import events +from dstack._internal.server.services import gateways as gateways_services +from dstack._internal.server.services.gateways import emit_gateway_status_change_event +from dstack._internal.server.services.gateways.pool import gateway_connections_pool +from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.services.logging import fmt +from dstack._internal.utils.common import get_current_datetime, run_async +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +@dataclass +class GatewayPipelineItem(PipelineItem): + status: GatewayStatus + to_be_deleted: bool + + +class GatewayPipeline(Pipeline[GatewayPipelineItem]): + def __init__( + self, + workers_num: int = 10, + queue_lower_limit_factor: float = 0.5, + queue_upper_limit_factor: float = 2.0, + min_processing_interval: timedelta = timedelta(seconds=15), + lock_timeout: timedelta = timedelta(seconds=30), + heartbeat_trigger: timedelta = timedelta(seconds=15), + ) -> None: + super().__init__( + workers_num=workers_num, + queue_lower_limit_factor=queue_lower_limit_factor, + queue_upper_limit_factor=queue_upper_limit_factor, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeat_trigger=heartbeat_trigger, + ) + self.__heartbeater = Heartbeater[GatewayPipelineItem]( + model_type=GatewayModel, + lock_timeout=self._lock_timeout, + heartbeat_trigger=self._heartbeat_trigger, + ) + self.__fetcher = GatewayFetcher( + queue=self._queue, + queue_desired_minsize=self._queue_desired_minsize, + min_processing_interval=self._min_processing_interval, + lock_timeout=self._lock_timeout, + heartbeater=self._heartbeater, + ) + self.__workers = [ + GatewayWorker(queue=self._queue, heartbeater=self._heartbeater) + for _ 
in range(self._workers_num) + ] + + @property + def hint_fetch_model_name(self) -> str: + return GatewayModel.__name__ + + @property + def _heartbeater(self) -> Heartbeater[GatewayPipelineItem]: + return self.__heartbeater + + @property + def _fetcher(self) -> Fetcher[GatewayPipelineItem]: + return self.__fetcher + + @property + def _workers(self) -> Sequence["GatewayWorker"]: + return self.__workers + + +class GatewayFetcher(Fetcher[GatewayPipelineItem]): + def __init__( + self, + queue: asyncio.Queue[GatewayPipelineItem], + queue_desired_minsize: int, + min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeater: Heartbeater[GatewayPipelineItem], + queue_check_delay: float = 1.0, + ) -> None: + super().__init__( + queue=queue, + queue_desired_minsize=queue_desired_minsize, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeater=heartbeater, + queue_check_delay=queue_check_delay, + ) + + async def fetch(self, limit: int) -> list[GatewayPipelineItem]: + gateway_lock, _ = get_locker(get_db().dialect_name).get_lockset(GatewayModel.__tablename__) + async with gateway_lock: + async with get_session_ctx() as session: + now = get_current_datetime() + res = await session.execute( + select(GatewayModel) + .where( + or_( + GatewayModel.status.in_( + [GatewayStatus.SUBMITTED, GatewayStatus.PROVISIONING] + ), + GatewayModel.to_be_deleted == True, + ), + or_( + GatewayModel.last_processed_at <= now - self._min_processing_interval, + GatewayModel.last_processed_at == GatewayModel.created_at, + ), + or_( + GatewayModel.lock_expires_at.is_(None), + GatewayModel.lock_expires_at < now, + ), + or_( + GatewayModel.lock_owner.is_(None), + GatewayModel.lock_owner == GatewayPipeline.__name__, + ), + ) + .order_by(GatewayModel.last_processed_at.asc()) + .limit(limit) + .with_for_update(skip_locked=True, key_share=True) + .options( + load_only( + GatewayModel.id, + GatewayModel.lock_token, + GatewayModel.lock_expires_at, + 
GatewayModel.status, + GatewayModel.to_be_deleted, + ) + ) + ) + gateway_models = list(res.scalars().all()) + lock_expires_at = get_current_datetime() + self._lock_timeout + lock_token = uuid.uuid4() + items = [] + for gateway_model in gateway_models: + prev_lock_expired = gateway_model.lock_expires_at is not None + gateway_model.lock_expires_at = lock_expires_at + gateway_model.lock_token = lock_token + gateway_model.lock_owner = GatewayPipeline.__name__ + items.append( + GatewayPipelineItem( + __tablename__=GatewayModel.__tablename__, + id=gateway_model.id, + lock_expires_at=lock_expires_at, + lock_token=lock_token, + prev_lock_expired=prev_lock_expired, + status=gateway_model.status, + to_be_deleted=gateway_model.to_be_deleted, + ) + ) + await session.commit() + return items + + +class GatewayWorker(Worker[GatewayPipelineItem]): + def __init__( + self, + queue: asyncio.Queue[GatewayPipelineItem], + heartbeater: Heartbeater[GatewayPipelineItem], + ) -> None: + super().__init__( + queue=queue, + heartbeater=heartbeater, + ) + + async def process(self, item: GatewayPipelineItem): + if item.to_be_deleted: + await _process_to_be_deleted_item(item) + elif item.status == GatewayStatus.SUBMITTED: + await _process_submitted_item(item) + elif item.status == GatewayStatus.PROVISIONING: + await _process_provisioning_item(item) + + +async def _process_submitted_item(item: GatewayPipelineItem): + async with get_session_ctx() as session: + res = await session.execute( + select(GatewayModel) + .where( + GatewayModel.id == item.id, + GatewayModel.lock_token == item.lock_token, + ) + .options(joinedload(GatewayModel.project).joinedload(ProjectModel.backends)) + .options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) + ) + gateway_model = res.unique().scalar_one_or_none() + if gateway_model is None: + logger.warning( + "Failed to process %s item %s: lock_token mismatch." 
+ " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + return + + result = await _process_submitted_gateway(gateway_model) + update_map = result.update_map | get_processed_update_map() | get_unlock_update_map() + async with get_session_ctx() as session: + gateway_compute_model = result.gateway_compute_model + if gateway_compute_model is not None: + session.add(gateway_compute_model) + await session.flush() + update_map["gateway_compute_id"] = gateway_compute_model.id + res = await session.execute( + update(GatewayModel) + .where( + GatewayModel.id == gateway_model.id, + GatewayModel.lock_token == gateway_model.lock_token, + ) + .values(**update_map) + .returning(GatewayModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + logger.warning( + "Failed to update %s item %s after processing: lock_token changed." + " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + # TODO: Clean up gateway_compute_model. 
+ return + emit_gateway_status_change_event( + session=session, + gateway_model=gateway_model, + old_status=gateway_model.status, + new_status=update_map.get("status", gateway_model.status), + status_message=update_map.get("status_message", gateway_model.status_message), + ) + + +@dataclass +class _SubmittedResult: + update_map: UpdateMap = field(default_factory=dict) + gateway_compute_model: Optional[GatewayComputeModel] = None + + +async def _process_submitted_gateway(gateway_model: GatewayModel) -> _SubmittedResult: + logger.info("%s: started gateway provisioning", fmt(gateway_model)) + configuration = gateways_services.get_gateway_configuration(gateway_model) + try: + ( + backend_model, + backend, + ) = await backends_services.get_project_backend_with_model_by_type_or_error( + project=gateway_model.project, backend_type=configuration.backend + ) + except BackendNotAvailable: + return _SubmittedResult( + update_map={ + "status": GatewayStatus.FAILED, + "status_message": "Backend not available", + } + ) + try: + gateway_compute_model = await gateways_services.create_gateway_compute( + backend_compute=backend.compute(), + project_name=gateway_model.project.name, + configuration=configuration, + backend_id=backend_model.id, + ) + return _SubmittedResult( + update_map={"status": GatewayStatus.PROVISIONING}, + gateway_compute_model=gateway_compute_model, + ) + except BackendError as e: + status_message = f"Backend error: {repr(e)}" + if len(e.args) > 0: + status_message = str(e.args[0]) + return _SubmittedResult( + update_map={ + "status": GatewayStatus.FAILED, + "status_message": status_message, + } + ) + except Exception as e: + logger.exception("%s: got exception when creating gateway compute", fmt(gateway_model)) + return _SubmittedResult( + update_map={ + "status": GatewayStatus.FAILED, + "status_message": f"Unexpected error: {repr(e)}", + } + ) + + +async def _process_provisioning_item(item: GatewayPipelineItem): + async with get_session_ctx() as session: + res 
= await session.execute( + select(GatewayModel) + .where( + GatewayModel.id == item.id, + GatewayModel.lock_token == item.lock_token, + ) + .options(joinedload(GatewayModel.gateway_compute)) + ) + gateway_model = res.unique().scalar_one_or_none() + if gateway_model is None: + logger.warning( + "Failed to process %s item %s: lock_token mismatch." + " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + return + + result = await _process_provisioning_gateway(gateway_model) + update_map = result.gateway_update_map | get_processed_update_map() | get_unlock_update_map() + async with get_session_ctx() as session: + res = await session.execute( + update(GatewayModel) + .where( + GatewayModel.id == gateway_model.id, + GatewayModel.lock_token == gateway_model.lock_token, + ) + .values(**update_map) + .returning(GatewayModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + logger.warning( + "Failed to update %s item %s after processing: lock_token changed." + " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + return + emit_gateway_status_change_event( + session=session, + gateway_model=gateway_model, + old_status=gateway_model.status, + new_status=update_map.get("status", gateway_model.status), + status_message=update_map.get("status_message", gateway_model.status_message), + ) + if result.gateway_compute_update_map: + res = await session.execute( + update(GatewayComputeModel) + .where(GatewayComputeModel.id == gateway_model.gateway_compute_id) + .values(**result.gateway_compute_update_map) + .returning(GatewayComputeModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + logger.error( + "Failed to update compute model %s for gateway %s." 
+ " This is unexpected and may happen only if the compute model was manually deleted.", + gateway_model.gateway_compute_id, + item.id, + ) + + +@dataclass +class _ProvisioningResult: + gateway_update_map: UpdateMap = field(default_factory=dict) + gateway_compute_update_map: UpdateMap = field(default_factory=dict) + + +async def _process_provisioning_gateway(gateway_model: GatewayModel) -> _ProvisioningResult: + # Provisioning gateways must have compute. + assert gateway_model.gateway_compute is not None + + # FIXME: problems caused by blocking on connect_to_gateway_with_retry and configure_gateway: + # - cannot delete the gateway before it is provisioned because the DB model is locked + # - connection retry counter is reset on server restart + # - only one server replica is processing the gateway + # Easy to fix by doing only one connection/configuration attempt per processing iteration. The + # main challenge is applying the same provisioning model to the dstack Sky gateway to avoid + # maintaining a different model for Sky. 
+ connection = await gateways_services.connect_to_gateway_with_retry( + gateway_model.gateway_compute + ) + if connection is None: + return _ProvisioningResult( + gateway_update_map={ + "status": GatewayStatus.FAILED, + "status_message": "Failed to connect to gateway", + }, + gateway_compute_update_map={"active": False}, + ) + try: + await gateways_services.configure_gateway(connection) + except Exception: + logger.exception("%s: failed to configure gateway", fmt(gateway_model)) + await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address) + return _ProvisioningResult( + gateway_update_map={ + "status": GatewayStatus.FAILED, + "status_message": "Failed to configure gateway", + }, + gateway_compute_update_map={"active": False}, + ) + return _ProvisioningResult( + gateway_update_map={"status": GatewayStatus.RUNNING}, + ) + + +async def _process_to_be_deleted_item(item: GatewayPipelineItem): + async with get_session_ctx() as session: + res = await session.execute( + select(GatewayModel) + .where( + GatewayModel.id == item.id, + GatewayModel.lock_token == item.lock_token, + ) + .options(joinedload(GatewayModel.project).joinedload(ProjectModel.backends)) + .options(joinedload(GatewayModel.gateway_compute)) + .options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) + ) + gateway_model = res.unique().scalar_one_or_none() + if gateway_model is None: + logger.warning( + "Failed to process %s item %s: lock_token mismatch." 
+ " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + return + + result = await _process_to_be_deleted_gateway(gateway_model) + async with get_session_ctx() as session: + if result.delete_gateway: + res = await session.execute( + delete(GatewayModel) + .where( + GatewayModel.id == gateway_model.id, + GatewayModel.lock_token == gateway_model.lock_token, + ) + .returning(GatewayModel.id) + ) + deleted_ids = list(res.scalars().all()) + if len(deleted_ids) == 0: + logger.warning( + "Failed to delete %s item %s after processing: lock_token changed." + " The item is expected to be processed and deleted on another fetch iteration.", + item.__tablename__, + item.id, + ) + return + events.emit( + session, + "Gateway deleted", + actor=events.SystemActor(), + targets=[events.Target.from_model(gateway_model)], + ) + else: + res = await session.execute( + update(GatewayModel) + .where( + GatewayModel.id == gateway_model.id, + GatewayModel.lock_token == gateway_model.lock_token, + ) + .values(**get_processed_update_map()) + .returning(GatewayModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + logger.warning( + "Failed to update %s item %s after processing: lock_token changed." + " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + return + + if result.gateway_compute_update_map: + res = await session.execute( + update(GatewayComputeModel) + .where(GatewayComputeModel.id == gateway_model.gateway_compute_id) + .values(**result.gateway_compute_update_map) + .returning(GatewayComputeModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + logger.error( + "Failed to update compute model %s for gateway %s." 
+ " This is unexpected and may happen only if the compute model was manually deleted.", + gateway_model.gateway_compute_id, + item.id, + ) + return + + +@dataclass +class _DeletedResult: + delete_gateway: bool + gateway_compute_update_map: UpdateMap = field(default_factory=dict) + + +async def _process_to_be_deleted_gateway(gateway_model: GatewayModel) -> _DeletedResult: + assert gateway_model.backend.type != BackendType.DSTACK + backend = await backends_services.get_project_backend_by_type_or_error( + project=gateway_model.project, backend_type=gateway_model.backend.type + ) + compute = backend.compute() + assert isinstance(compute, ComputeWithGatewaySupport) + gateway_compute_configuration = gateways_services.get_gateway_compute_configuration( + gateway_model + ) + if gateway_model.gateway_compute is not None and gateway_compute_configuration is not None: + logger.info("Deleting gateway compute for %s...", gateway_model.name) + try: + await run_async( + compute.terminate_gateway, + gateway_model.gateway_compute.instance_id, + gateway_compute_configuration, + gateway_model.gateway_compute.backend_data, + ) + except Exception: + logger.exception( + "Error when deleting gateway compute for %s", + gateway_model.name, + ) + return _DeletedResult(delete_gateway=False) + logger.info("Deleted gateway compute for %s", gateway_model.name) + result = _DeletedResult(delete_gateway=True) + if gateway_model.gateway_compute is not None: + await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address) + result.gateway_compute_update_map = {"active": False, "deleted": True} + return result diff --git a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py index 9fac5665a..a184379c3 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py @@ -32,7 +32,7 @@ logger = 
get_logger(__name__) -class PlacementGroupPipeline(Pipeline): +class PlacementGroupPipeline(Pipeline[PipelineItem]): def __init__( self, workers_num: int = 10, @@ -50,7 +50,7 @@ def __init__( lock_timeout=lock_timeout, heartbeat_trigger=heartbeat_trigger, ) - self.__heartbeater = Heartbeater[PlacementGroupModel]( + self.__heartbeater = Heartbeater[PipelineItem]( model_type=PlacementGroupModel, lock_timeout=self._lock_timeout, heartbeat_trigger=self._heartbeat_trigger, @@ -72,11 +72,11 @@ def hint_fetch_model_name(self) -> str: return PlacementGroupModel.__name__ @property - def _heartbeater(self) -> Heartbeater: + def _heartbeater(self) -> Heartbeater[PipelineItem]: return self.__heartbeater @property - def _fetcher(self) -> Fetcher: + def _fetcher(self) -> Fetcher[PipelineItem]: return self.__fetcher @property @@ -84,14 +84,14 @@ def _workers(self) -> Sequence["PlacementGroupWorker"]: return self.__workers -class PlacementGroupFetcher(Fetcher): +class PlacementGroupFetcher(Fetcher[PipelineItem]): def __init__( self, queue: asyncio.Queue[PipelineItem], queue_desired_minsize: int, min_processing_interval: timedelta, lock_timeout: timedelta, - heartbeater: Heartbeater[PlacementGroupModel], + heartbeater: Heartbeater[PipelineItem], queue_check_delay: float = 1.0, ) -> None: super().__init__( @@ -159,11 +159,11 @@ async def fetch(self, limit: int) -> list[PipelineItem]: return items -class PlacementGroupWorker(Worker): +class PlacementGroupWorker(Worker[PipelineItem]): def __init__( self, queue: asyncio.Queue[PipelineItem], - heartbeater: Heartbeater[PlacementGroupModel], + heartbeater: Heartbeater[PipelineItem], ) -> None: super().__init__( queue=queue, diff --git a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py index c4baf96c5..6067d9d4d 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py +++ 
b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py @@ -99,21 +99,23 @@ def start_scheduled_tasks() -> AsyncIOScheduler: ) _scheduler.add_job(delete_prometheus_metrics, IntervalTrigger(minutes=5), max_instances=1) _scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15)) - _scheduler.add_job(process_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5) _scheduler.add_job( process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5 ) _scheduler.add_job( process_idle_volumes, IntervalTrigger(seconds=60, jitter=10), max_instances=1 ) - if not FeatureFlags.PIPELINE_PROCESSING_ENABLED: - _scheduler.add_job(process_placement_groups, IntervalTrigger(seconds=30, jitter=5)) _scheduler.add_job( process_fleets, IntervalTrigger(seconds=10, jitter=2), max_instances=1, ) _scheduler.add_job(delete_instance_health_checks, IntervalTrigger(minutes=5), max_instances=1) + if not FeatureFlags.PIPELINE_PROCESSING_ENABLED: + _scheduler.add_job( + process_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5 + ) + _scheduler.add_job(process_placement_groups, IntervalTrigger(seconds=30, jitter=5)) for replica in range(settings.SERVER_BACKGROUND_PROCESSING_FACTOR): # Add multiple copies of tasks if requested. # max_instances=1 for additional copies to avoid running too many tasks. 
diff --git a/src/dstack/_internal/server/background/scheduled_tasks/gateways.py b/src/dstack/_internal/server/background/scheduled_tasks/gateways.py index 2566a4f4d..3b6bee012 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/gateways.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/gateways.py @@ -7,7 +7,12 @@ from dstack._internal.core.errors import BackendError, BackendNotAvailable, SSHError from dstack._internal.core.models.gateways import GatewayStatus from dstack._internal.server.db import get_db, get_session_ctx -from dstack._internal.server.models import GatewayComputeModel, GatewayModel, ProjectModel +from dstack._internal.server.models import ( + BackendModel, + GatewayComputeModel, + GatewayModel, + ProjectModel, +) from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services import gateways as gateways_services from dstack._internal.server.services.gateways import ( @@ -109,6 +114,7 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew select(GatewayModel) .where(GatewayModel.id == gateway_model.id) .options(joinedload(GatewayModel.project).joinedload(ProjectModel.backends)) + .options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) .execution_options(populate_existing=True) ) gateway_model = res.unique().scalar_one() @@ -153,6 +159,7 @@ async def _process_provisioning_gateway( res = await session.execute( select(GatewayModel) .where(GatewayModel.id == gateway_model.id) + .options(joinedload(GatewayModel.gateway_compute)) .execution_options(populate_existing=True) ) gateway_model = res.unique().scalar_one() diff --git a/src/dstack/_internal/server/migrations/versions/2026/02_23_0548_140331002ece_add_gatewaymodel_pipeline_and_to_be_.py b/src/dstack/_internal/server/migrations/versions/2026/02_23_0548_140331002ece_add_gatewaymodel_pipeline_and_to_be_.py new file mode 100644 index 000000000..fa3c8ce30 --- /dev/null +++ 
b/src/dstack/_internal/server/migrations/versions/2026/02_23_0548_140331002ece_add_gatewaymodel_pipeline_and_to_be_.py @@ -0,0 +1,51 @@ +"""Add GatewayModel pipeline and to_be_deleted columns + +Revision ID: 140331002ece +Revises: a8ed24fd7f90 +Create Date: 2026-02-23 05:48:55.948838+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "140331002ece" +down_revision = "a8ed24fd7f90" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("gateways", schema=None) as batch_op: + batch_op.add_column( + sa.Column("to_be_deleted", sa.Boolean(), server_default=sa.false(), nullable=False) + ) + batch_op.add_column( + sa.Column( + "lock_expires_at", dstack._internal.server.models.NaiveDateTime(), nullable=True + ) + ) + batch_op.add_column( + sa.Column( + "lock_token", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True + ) + ) + batch_op.add_column(sa.Column("lock_owner", sa.String(length=100), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("gateways", schema=None) as batch_op: + batch_op.drop_column("lock_owner") + batch_op.drop_column("lock_token") + batch_op.drop_column("lock_expires_at") + batch_op.drop_column("to_be_deleted") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index a837137a1..df9cf8607 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -492,7 +492,7 @@ class JobModel(BaseModel): waiting_master_job: Mapped[Optional[bool]] = mapped_column(Boolean) -class GatewayModel(BaseModel): +class GatewayModel(PipelineModelMixin, BaseModel): __tablename__ = "gateways" id: Mapped[uuid.UUID] = mapped_column( @@ -508,21 +508,24 @@ class GatewayModel(BaseModel): status: Mapped[GatewayStatus] = mapped_column(EnumAsString(GatewayStatus, 100)) status_message: Mapped[Optional[str]] = mapped_column(Text) last_processed_at: Mapped[datetime] = mapped_column(NaiveDateTime) + to_be_deleted: Mapped[bool] = mapped_column(Boolean, server_default=false()) project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE")) project: Mapped["ProjectModel"] = relationship(foreign_keys=[project_id]) backend_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("backends.id", ondelete="CASCADE")) - backend: Mapped["BackendModel"] = relationship(lazy="selectin") + backend: Mapped["BackendModel"] = relationship() gateway_compute_id: Mapped[Optional[uuid.UUID]] = mapped_column( ForeignKey("gateway_computes.id", ondelete="CASCADE") ) - gateway_compute: Mapped[Optional["GatewayComputeModel"]] = relationship(lazy="joined") + gateway_compute: Mapped[Optional["GatewayComputeModel"]] = relationship() runs: Mapped[List["RunModel"]] = relationship(back_populates="gateway") __table_args__ = (UniqueConstraint("project_id", "name", name="uq_gateways_project_id_name"),) + # TODO: Add pipeline index ("ix_gateways_pipeline_fetch_q") if gateways become 
soft-deleted. + class GatewayComputeModel(BaseModel): __tablename__ = "gateway_computes" diff --git a/src/dstack/_internal/server/routers/gateways.py b/src/dstack/_internal/server/routers/gateways.py index 0f89e5db4..af4557a44 100644 --- a/src/dstack/_internal/server/routers/gateways.py +++ b/src/dstack/_internal/server/routers/gateways.py @@ -13,6 +13,7 @@ ProjectAdmin, ProjectMemberOrPublicAccess, ) +from dstack._internal.server.services.pipelines import PipelineHinterProtocol, get_pipeline_hinter from dstack._internal.server.utils.routers import ( CustomORJSONResponse, get_base_api_additional_responses, @@ -54,6 +55,7 @@ async def create_gateway( body: schemas.CreateGatewayRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), + pipeline_hinter: PipelineHinterProtocol = Depends(get_pipeline_hinter), ): user, project = user_project return CustomORJSONResponse( @@ -62,6 +64,7 @@ async def create_gateway( user=user, project=project, configuration=body.configuration, + pipeline_hinter=pipeline_hinter, ) ) diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index ab89c2a7c..762af8bef 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -10,7 +10,7 @@ import httpx from sqlalchemy import func, select, update from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload +from sqlalchemy.orm import joinedload import dstack._internal.utils.random_names as random_names from dstack._internal.core.backends.base.compute import ( @@ -42,6 +42,7 @@ from dstack._internal.server import settings from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite from dstack._internal.server.models import ( + BackendModel, GatewayComputeModel, GatewayModel, ProjectModel, @@ -60,8 +61,10 @@ get_locker, 
string_to_lock_id, ) +from dstack._internal.server.services.pipelines import PipelineHinterProtocol from dstack._internal.server.services.plugins import apply_plugin_policies from dstack._internal.server.utils.common import gather_map_async +from dstack._internal.settings import FeatureFlags from dstack._internal.utils.common import get_current_datetime, run_async from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes from dstack._internal.utils.logging import get_logger @@ -80,13 +83,43 @@ def switch_gateway_status( return gateway_model.status = new_status + emit_gateway_status_change_event( + session=session, + gateway_model=gateway_model, + old_status=old_status, + new_status=new_status, + status_message=gateway_model.status_message, + actor=actor, + ) - msg = f"Gateway status changed {old_status.upper()} -> {new_status.upper()}" - if gateway_model.status_message is not None: - msg += f" ({gateway_model.status_message})" + +def emit_gateway_status_change_event( + session: AsyncSession, + gateway_model: GatewayModel, + old_status: GatewayStatus, + new_status: GatewayStatus, + status_message: Optional[str], + actor: events.AnyActor = events.SystemActor(), +) -> None: + if old_status == new_status: + return + msg = get_gateway_status_change_message( + old_status=old_status, + new_status=new_status, + status_message=status_message, + ) events.emit(session, msg, actor=actor, targets=[events.Target.from_model(gateway_model)]) +def get_gateway_status_change_message( + old_status: GatewayStatus, new_status: GatewayStatus, status_message: Optional[str] +) -> str: + msg = f"Gateway status changed {old_status.upper()} -> {new_status.upper()}" + if status_message is not None: + msg += f" ({status_message})" + return msg + + GATEWAY_CONNECT_ATTEMPTS = 30 GATEWAY_CONNECT_DELAY = 10 GATEWAY_CONFIGURE_ATTEMPTS = 50 @@ -94,14 +127,25 @@ def switch_gateway_status( async def list_project_gateways(session: AsyncSession, project: ProjectModel) -> List[Gateway]: - 
gateways = await list_project_gateway_models(session=session, project=project) + gateways = await list_project_gateway_models( + session=session, + project=project, + load_gateway_compute=True, + load_backend_type=True, + ) return [gateway_model_to_gateway(g) for g in gateways] async def get_gateway_by_name( session: AsyncSession, project: ProjectModel, name: str ) -> Optional[Gateway]: - gateway = await get_project_gateway_model_by_name(session=session, project=project, name=name) + gateway = await get_project_gateway_model_by_name( + session=session, + project=project, + name=name, + load_gateway_compute=True, + load_backend_type=True, + ) if gateway is None: return None return gateway_model_to_gateway(gateway) @@ -156,6 +200,7 @@ async def create_gateway( user: UserModel, project: ProjectModel, configuration: GatewayConfiguration, + pipeline_hinter: PipelineHinterProtocol, ) -> Gateway: spec = await apply_plugin_policies( user=user.name, @@ -183,6 +228,7 @@ async def create_gateway( if configuration.name is None: configuration.name = await generate_gateway_name(session=session, project=project) + now = get_current_datetime() gateway = GatewayModel( id=uuid.uuid4(), name=configuration.name, @@ -192,7 +238,8 @@ async def create_gateway( wildcard_domain=configuration.domain, configuration=configuration.json(), status=GatewayStatus.SUBMITTED, - last_processed_at=get_current_datetime(), + created_at=now, + last_processed_at=now, ) session.add(gateway) events.emit( @@ -208,6 +255,15 @@ async def create_gateway( await set_default_gateway( session=session, project=project, name=configuration.name, user=user ) + pipeline_hinter.hint_fetch(GatewayModel.__name__) + gateway = await get_project_gateway_model_by_name( + session=session, + project=project, + name=configuration.name, + load_gateway_compute=True, + load_backend_type=True, + ) + assert gateway is not None return gateway_model_to_gateway(gateway) @@ -245,6 +301,86 @@ async def delete_gateways( project: 
ProjectModel, gateways_names: List[str], user: UserModel, +): + # Keep both delete code paths while pipeline processing is behind a feature flag: + # - pipeline path marks gateways for async deletion by GatewayPipeline + # - sync path deletes gateway resources inline for non-pipeline processing + # TODO: Drop sync path after pipeline processing is enabled by default. + if FeatureFlags.PIPELINE_PROCESSING_ENABLED: + await _delete_gateways_pipeline( + session=session, + project=project, + gateways_names=gateways_names, + user=user, + ) + else: + await _delete_gateways_sync( + session=session, + project=project, + gateways_names=gateways_names, + user=user, + ) + + +async def _delete_gateways_pipeline( + session: AsyncSession, + project: ProjectModel, + gateways_names: List[str], + user: UserModel, +): + res = await session.execute( + select(GatewayModel).where( + GatewayModel.project_id == project.id, + GatewayModel.name.in_(gateways_names), + ) + ) + gateway_models = res.scalars().all() + gateways_ids = sorted([g.id for g in gateway_models]) + await session.commit() + logger.info("Deleting gateways: %s", [g.name for g in gateway_models]) + async with get_locker(get_db().dialect_name).lock_ctx( + GatewayModel.__tablename__, gateways_ids + ): + # Refetch after lock + res = await session.execute( + select(GatewayModel) + .where( + GatewayModel.id.in_(gateways_ids), + GatewayModel.project_id == project.id, + GatewayModel.lock_expires_at.is_(None), + ) + .options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) + .order_by(GatewayModel.id) # take locks in order + .with_for_update(key_share=True, nowait=True, of=GatewayModel) + .execution_options(populate_existing=True) + ) + gateway_models = res.scalars().all() + if len(gateway_models) != len(gateways_ids): + # TODO: Make the delete endpoint fully async so we don't need to lock and error: + # put the request in queue and process in the background. 
+ raise ServerClientError( + "Failed to delete gateways: gateways are being processed currently. Try again later." + ) + for gateway_model in gateway_models: + if gateway_model.backend.type == BackendType.DSTACK: + raise ServerClientError("Cannot delete dstack Sky gateway") + for gateway_model in gateway_models: + if not gateway_model.to_be_deleted: + gateway_model.to_be_deleted = True + events.emit( + session, + "Gateway marked for deletion", + actor=events.UserActor.from_user(user), + targets=[events.Target.from_model(gateway_model)], + ) + await session.commit() + + +async def _delete_gateways_sync( + session: AsyncSession, + project: ProjectModel, + gateways_names: List[str], + user: UserModel, ): res = await session.execute( select(GatewayModel).where( @@ -266,10 +402,11 @@ async def delete_gateways( GatewayModel.project_id == project.id, GatewayModel.name.in_(gateways_names), ) - .options(selectinload(GatewayModel.gateway_compute)) + .options(joinedload(GatewayModel.gateway_compute)) + .options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) .execution_options(populate_existing=True) .order_by(GatewayModel.id) # take locks in order - .with_for_update(key_share=True) + .with_for_update(key_share=True, of=GatewayModel) ) gateway_models = res.scalars().all() for gateway_model in gateway_models: @@ -346,6 +483,8 @@ async def set_default_gateway( gateway = await get_project_gateway_model_by_name(session=session, project=project, name=name) if gateway is None: raise ResourceNotExistsError() + if gateway.to_be_deleted: + raise ServerClientError("Cannot set gateway marked for deletion as default") if project.default_gateway_id == gateway.id: return previous_gateway = await get_project_default_gateway_model(session, project) @@ -375,20 +514,36 @@ async def set_default_gateway( async def list_project_gateway_models( - session: AsyncSession, project: ProjectModel + session: AsyncSession, + project: ProjectModel, + load_gateway_compute: bool = False, + 
load_backend_type: bool = False, ) -> Sequence[GatewayModel]: - res = await session.execute(select(GatewayModel).where(GatewayModel.project_id == project.id)) + stmt = select(GatewayModel).where(GatewayModel.project_id == project.id) + if load_gateway_compute: + stmt = stmt.options(joinedload(GatewayModel.gateway_compute)) + if load_backend_type: + stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) + res = await session.execute(stmt) return res.scalars().all() async def get_project_gateway_model_by_name( - session: AsyncSession, project: ProjectModel, name: str + session: AsyncSession, + project: ProjectModel, + name: str, + load_gateway_compute: bool = False, + load_backend_type: bool = False, ) -> Optional[GatewayModel]: - res = await session.execute( - select(GatewayModel).where( - GatewayModel.project_id == project.id, GatewayModel.name == name - ) + stmt = select(GatewayModel).where( + GatewayModel.project_id == project.id, + GatewayModel.name == name, ) + if load_gateway_compute: + stmt = stmt.options(joinedload(GatewayModel.gateway_compute)) + if load_backend_type: + stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) + res = await session.execute(stmt) return res.scalar() @@ -419,17 +574,28 @@ async def get_project_gateway_model_by_name_for_update( res = await session.execute( select(GatewayModel) .where(GatewayModel.id.in_([gateway_id]), *filters) + .options(joinedload(GatewayModel.gateway_compute)) + .options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) .with_for_update(key_share=True, of=GatewayModel) ) yield res.scalar_one_or_none() async def get_project_default_gateway_model( - session: AsyncSession, project: ProjectModel + session: AsyncSession, + project: ProjectModel, + load_gateway_compute: bool = False, + load_backend_type: bool = False, ) -> Optional[GatewayModel]: - res = await session.execute( - select(GatewayModel).where(GatewayModel.id == 
project.default_gateway_id) + stmt = select(GatewayModel).where( + GatewayModel.id == project.default_gateway_id, + GatewayModel.to_be_deleted == False, ) + if load_gateway_compute: + stmt = stmt.options(joinedload(GatewayModel.gateway_compute)) + if load_backend_type: + stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) + res = await session.execute(stmt) return res.scalar_one_or_none() @@ -445,7 +611,12 @@ async def generate_gateway_name(session: AsyncSession, project: ProjectModel) -> async def get_or_add_gateway_connection( session: AsyncSession, gateway_id: uuid.UUID ) -> tuple[GatewayModel, GatewayConnection]: - gateway = await session.get(GatewayModel, gateway_id) + gateway = await session.get( + GatewayModel, + gateway_id, + options=[joinedload(GatewayModel.gateway_compute)], + populate_existing=True, + ) if gateway is None: raise GatewayError("Gateway not found") if gateway.gateway_compute is None: diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py index 8506ad273..f37e1c968 100644 --- a/src/dstack/_internal/server/services/instances.py +++ b/src/dstack/_internal/server/services/instances.py @@ -82,9 +82,36 @@ def switch_instance_status( old_status = instance_model.status if old_status == new_status: return - instance_model.status = new_status + emit_instance_status_change_event( + session=session, + instance_model=instance_model, + old_status=old_status, + new_status=new_status, + actor=actor, + ) + +def emit_instance_status_change_event( + session: AsyncSession, + instance_model: InstanceModel, + old_status: InstanceStatus, + new_status: InstanceStatus, + actor: events.AnyActor = events.SystemActor(), +) -> None: + if old_status == new_status: + return + msg = get_instance_status_change_message( + instance_model=instance_model, + old_status=old_status, + new_status=new_status, + ) + events.emit(session, msg, actor=actor, 
targets=[events.Target.from_model(instance_model)]) + + +def get_instance_status_change_message( + instance_model: InstanceModel, old_status: InstanceStatus, new_status: InstanceStatus +) -> str: msg = f"Instance status changed {old_status.upper()} -> {new_status.upper()}" if ( new_status == InstanceStatus.TERMINATING @@ -105,7 +132,7 @@ def switch_instance_status( msg += f". Termination reason: {instance_model.termination_reason.upper()}" if instance_model.termination_reason_message: msg += f" ({instance_model.termination_reason_message})" - events.emit(session, msg, actor=actor, targets=[events.Target.from_model(instance_model)]) + return msg def format_instance_blocks_for_event(instance_model: InstanceModel) -> str: diff --git a/src/dstack/_internal/server/services/pipelines.py b/src/dstack/_internal/server/services/pipelines.py index 19f4df902..cbe2a2874 100644 --- a/src/dstack/_internal/server/services/pipelines.py +++ b/src/dstack/_internal/server/services/pipelines.py @@ -5,8 +5,26 @@ class PipelineHinterProtocol(Protocol): def hint_fetch(self, model_name: str) -> None: + """ + Pass `Model.__name__` to hint replica's pipelines to fetch the model's items ASAP. + """ pass +class _NoopPipelineHinter: + def hint_fetch(self, model_name: str) -> None: + pass + + +_noop_pipeline_hinter = _NoopPipelineHinter() + + def get_pipeline_hinter(request: Request) -> PipelineHinterProtocol: - return request.app.state.pipeline_manager.hinter + """ + Returns pipeline hinter that allows hinting replica's pipelines that there are new items for processing. + This can reduce processing latency if the processing happens rarely. 
+ """ + pipeline_manager = getattr(request.app.state, "pipeline_manager", None) + if pipeline_manager is None: + return _noop_pipeline_hinter + return pipeline_manager.hinter diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index b701b822b..8ba424c56 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -90,17 +90,28 @@ async def register_service(session: AsyncSession, run_model: RunModel, run_spec: if isinstance(run_spec.configuration.gateway, str): gateway = await get_project_gateway_model_by_name( - session=session, project=run_model.project, name=run_spec.configuration.gateway + session=session, + project=run_model.project, + name=run_spec.configuration.gateway, + load_gateway_compute=True, + load_backend_type=True, ) if gateway is None: raise ResourceNotExistsError( f"Gateway {run_spec.configuration.gateway} does not exist" ) + if gateway.to_be_deleted: + raise ResourceNotExistsError( + f"Gateway {run_spec.configuration.gateway} was marked for deletion" + ) elif run_spec.configuration.gateway == False: gateway = None else: gateway = await get_project_default_gateway_model( - session=session, project=run_model.project + session=session, + project=run_model.project, + load_gateway_compute=True, + load_backend_type=True, ) if gateway is None and run_spec.configuration.gateway == True: raise ResourceNotExistsError( diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_base.py b/src/tests/_internal/server/background/pipeline_tasks/test_base.py index 7e84d9f80..303fb0854 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_base.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_base.py @@ -21,7 +21,7 @@ def now() -> datetime: @pytest.fixture -def heartbeater() -> Heartbeater[PlacementGroupModel]: +def heartbeater() -> Heartbeater[PipelineItem]: return 
Heartbeater( model_type=PlacementGroupModel, lock_timeout=timedelta(seconds=30), @@ -63,7 +63,7 @@ def _placement_group_to_pipeline_item(placement_group: PlacementGroupModel) -> P class TestHeartbeater: @pytest.mark.asyncio async def test_untrack_preserves_item_when_lock_token_mismatches( - self, heartbeater: Heartbeater[PlacementGroupModel], now: datetime + self, heartbeater: Heartbeater[PipelineItem], now: datetime ): item = PipelineItem( __tablename__=PlacementGroupModel.__tablename__, @@ -93,7 +93,7 @@ async def test_heartbeat_extends_locks_close_to_expiration( self, test_db, session: AsyncSession, - heartbeater: Heartbeater[PlacementGroupModel], + heartbeater: Heartbeater[PipelineItem], now: datetime, ): placement_group = await _create_locked_placement_group( @@ -122,7 +122,7 @@ async def test_heartbeat_untracks_expired_items_without_db_update( self, test_db, session: AsyncSession, - heartbeater: Heartbeater[PlacementGroupModel], + heartbeater: Heartbeater[PipelineItem], now: datetime, ): original_lock_expires_at = now - timedelta(seconds=1) @@ -150,7 +150,7 @@ async def test_heartbeat_untracks_item_when_lock_token_changed_in_db( self, test_db, session: AsyncSession, - heartbeater: Heartbeater[PlacementGroupModel], + heartbeater: Heartbeater[PipelineItem], now: datetime, ): original_lock_expires_at = now + timedelta(seconds=2) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py b/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py new file mode 100644 index 000000000..9628451bd --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py @@ -0,0 +1,292 @@ +import uuid +from datetime import datetime, timezone +from unittest.mock import MagicMock, Mock, patch + +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload + +from dstack._internal.core.errors import BackendError +from 
dstack._internal.core.models.gateways import GatewayProvisioningData, GatewayStatus +from dstack._internal.server.background.pipeline_tasks.gateways import ( + GatewayPipelineItem, + GatewayWorker, +) +from dstack._internal.server.models import GatewayModel +from dstack._internal.server.testing.common import ( + AsyncContextManager, + ComputeMockSpec, + create_backend, + create_gateway, + create_gateway_compute, + create_project, + list_events, +) + + +@pytest.fixture +def worker() -> GatewayWorker: + return GatewayWorker(queue=Mock(), heartbeater=Mock()) + + +def _gateway_to_pipeline_item(gateway_model: GatewayModel) -> GatewayPipelineItem: + assert gateway_model.lock_token is not None + assert gateway_model.lock_expires_at is not None + return GatewayPipelineItem( + __tablename__=gateway_model.__tablename__, + id=gateway_model.id, + lock_token=gateway_model.lock_token, + lock_expires_at=gateway_model.lock_expires_at, + prev_lock_expired=False, + status=gateway_model.status, + to_be_deleted=gateway_model.to_be_deleted, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestGatewayWorkerSubmitted: + async def test_submitted_to_provisioning( + self, test_db, session: AsyncSession, worker: GatewayWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.SUBMITTED, + ) + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + with patch( + "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error" + ) as m: + aws = Mock() + m.return_value = (backend, aws) + aws.compute.return_value = Mock(spec=ComputeMockSpec) + aws.compute.return_value.create_gateway.return_value = GatewayProvisioningData( + 
instance_id="i-1234567890", + ip_address="2.2.2.2", + region="us", + ) + await worker.process(_gateway_to_pipeline_item(gateway)) + m.assert_called_once() + aws.compute.return_value.create_gateway.assert_called_once() + + await session.refresh(gateway) + res = await session.execute( + select(GatewayModel) + .where(GatewayModel.id == gateway.id) + .options(joinedload(GatewayModel.gateway_compute)) + ) + gateway = res.unique().scalar_one() + assert gateway.status == GatewayStatus.PROVISIONING + assert gateway.gateway_compute is not None + assert gateway.gateway_compute.ip_address == "2.2.2.2" + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway status changed SUBMITTED -> PROVISIONING" + + async def test_marks_gateway_as_failed_if_gateway_creation_errors( + self, test_db, session: AsyncSession, worker: GatewayWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.SUBMITTED, + ) + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + with patch( + "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error" + ) as m: + aws = Mock() + m.return_value = (backend, aws) + aws.compute.return_value = Mock(spec=ComputeMockSpec) + aws.compute.return_value.create_gateway.side_effect = BackendError("Some error") + await worker.process(_gateway_to_pipeline_item(gateway)) + m.assert_called_once() + aws.compute.return_value.create_gateway.assert_called_once() + + await session.refresh(gateway) + assert gateway.status == GatewayStatus.FAILED + assert gateway.status_message == "Some error" + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway status changed SUBMITTED -> 
FAILED (Some error)" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestGatewayWorkerProvisioning: + async def test_provisioning_to_running( + self, test_db, session: AsyncSession, worker: GatewayWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway_compute = await create_gateway_compute(session) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.PROVISIONING, + ) + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + with patch( + "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add" + ) as pool_add: + pool_add.return_value = MagicMock() + pool_add.return_value.client.return_value = MagicMock(AsyncContextManager()) + await worker.process(_gateway_to_pipeline_item(gateway)) + pool_add.assert_called_once() + + await session.refresh(gateway) + assert gateway.status == GatewayStatus.RUNNING + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway status changed PROVISIONING -> RUNNING" + + async def test_marks_gateway_as_failed_if_fails_to_connect( + self, test_db, session: AsyncSession, worker: GatewayWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway_compute = await create_gateway_compute(session) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.PROVISIONING, + ) + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + with patch( + 
"dstack._internal.server.services.gateways.connect_to_gateway_with_retry" + ) as connect_to_gateway_with_retry_mock: + connect_to_gateway_with_retry_mock.return_value = None + await worker.process(_gateway_to_pipeline_item(gateway)) + connect_to_gateway_with_retry_mock.assert_called_once() + + await session.refresh(gateway) + assert gateway.status == GatewayStatus.FAILED + assert gateway.status_message == "Failed to connect to gateway" + events = await list_events(session) + assert len(events) == 1 + assert ( + events[0].message + == "Gateway status changed PROVISIONING -> FAILED (Failed to connect to gateway)" + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestGatewayWorkerDeleted: + async def test_deletes_gateway_and_marks_compute_deleted( + self, test_db, session: AsyncSession, worker: GatewayWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.RUNNING, + ) + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + gateway.to_be_deleted = True + await session.commit() + + with ( + patch( + "dstack._internal.server.services.backends.get_project_backend_by_type_or_error" + ) as get_backend_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.gateways.gateway_connections_pool.remove" + ) as remove_connection_mock, + ): + backend_mock = Mock() + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + get_backend_mock.return_value = backend_mock + + await worker.process(_gateway_to_pipeline_item(gateway)) + + get_backend_mock.assert_called_once() + 
backend_mock.compute.return_value.terminate_gateway.assert_called_once() + remove_connection_mock.assert_called_once_with(gateway_compute.ip_address) + + await session.refresh(gateway_compute) + res = await session.execute(select(GatewayModel.id).where(GatewayModel.id == gateway.id)) + assert res.scalar_one_or_none() is None + assert gateway_compute.active is False + assert gateway_compute.deleted is True + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway deleted" + + async def test_keeps_gateway_if_terminate_fails( + self, test_db, session: AsyncSession, worker: GatewayWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.RUNNING, + ) + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + gateway.to_be_deleted = True + original_last_processed_at = gateway.last_processed_at + await session.commit() + + with ( + patch( + "dstack._internal.server.services.backends.get_project_backend_by_type_or_error" + ) as get_backend_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.gateways.gateway_connections_pool.remove" + ) as remove_connection_mock, + ): + backend_mock = Mock() + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.terminate_gateway.side_effect = BackendError( + "Terminate failed" + ) + get_backend_mock.return_value = backend_mock + + await worker.process(_gateway_to_pipeline_item(gateway)) + + get_backend_mock.assert_called_once() + backend_mock.compute.return_value.terminate_gateway.assert_called_once() + remove_connection_mock.assert_not_called() + + 
await session.refresh(gateway) + await session.refresh(gateway_compute) + assert gateway.to_be_deleted is True + assert gateway.last_processed_at > original_last_processed_at + assert gateway_compute.active is True + assert gateway_compute.deleted is False + events = await list_events(session) + assert len(events) == 0 diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py b/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py index 87cab83e1..7baed58b6 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py @@ -7,6 +7,7 @@ from dstack._internal.server.background.pipeline_tasks.base import PipelineItem from dstack._internal.server.background.pipeline_tasks.placement_groups import PlacementGroupWorker +from dstack._internal.server.models import PlacementGroupModel from dstack._internal.server.testing.common import ( ComputeMockSpec, create_fleet, @@ -20,7 +21,7 @@ def worker() -> PlacementGroupWorker: return PlacementGroupWorker(queue=Mock(), heartbeater=Mock()) -def _placement_group_to_pipeline_item(placement_group) -> PipelineItem: +def _placement_group_to_pipeline_item(placement_group: PlacementGroupModel) -> PipelineItem: assert placement_group.lock_token is not None assert placement_group.lock_expires_at is not None return PipelineItem( diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_gateways.py b/src/tests/_internal/server/background/scheduled_tasks/test_gateways.py index 5f19d2cfc..b97abe914 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_gateways.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_gateways.py @@ -1,11 +1,14 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload from 
dstack._internal.core.errors import BackendError from dstack._internal.core.models.gateways import GatewayProvisioningData, GatewayStatus from dstack._internal.server.background.scheduled_tasks.gateways import process_gateways +from dstack._internal.server.models import GatewayModel from dstack._internal.server.testing.common import ( AsyncContextManager, ComputeMockSpec, @@ -44,6 +47,12 @@ async def test_submitted_to_provisioning(self, test_db, session: AsyncSession): m.assert_called_once() aws.compute.return_value.create_gateway.assert_called_once() await session.refresh(gateway) + res = await session.execute( + select(GatewayModel) + .where(GatewayModel.id == gateway.id) + .options(joinedload(GatewayModel.gateway_compute)) + ) + gateway = res.unique().scalar_one() assert gateway.status == GatewayStatus.PROVISIONING assert gateway.gateway_compute is not None assert gateway.gateway_compute.ip_address == "2.2.2.2" diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index f80537a1b..a0f7566bf 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -20,6 +20,15 @@ list_events, ) from dstack._internal.server.testing.matchers import SomeUUID4Str +from dstack._internal.settings import FeatureFlags + + +@pytest.fixture +def patch_pipeline_processing_flag(monkeypatch: pytest.MonkeyPatch): + def _apply(enabled: bool): + monkeypatch.setattr(FeatureFlags, "PIPELINE_PROCESSING_ENABLED", enabled) + + return _apply class TestListAndGetGateways: @@ -455,9 +464,88 @@ async def test_only_admin_can_delete( ) assert response.status_code == 403 + +class TestDeleteGatewayPipelineEnabled: + @pytest.fixture(autouse=True) + def _pipeline_processing_enabled(self, patch_pipeline_processing_flag): + patch_pipeline_processing_flag(True) + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def 
test_delete_gateway(self, test_db, session: AsyncSession, client: AsyncClient): + async def test_marks_gateways_to_be_deleted( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + backend_aws = await create_backend(session, project.id) + backend_gcp = await create_backend(session, project.id, backend_type=BackendType.GCP) + gateway_compute_aws = await create_gateway_compute( + session=session, + backend_id=backend_aws.id, + ) + gateway_aws = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend_aws.id, + name="gateway-aws", + gateway_compute_id=gateway_compute_aws.id, + ) + gateway_compute_gcp = await create_gateway_compute( + session=session, + backend_id=backend_gcp.id, + ) + gateway_gcp = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend_gcp.id, + name="gateway-gcp", + gateway_compute_id=gateway_compute_gcp.id, + ) + response = await client.post( + f"/api/project/{project.name}/gateways/delete", + json={"names": [gateway_aws.name, gateway_gcp.name]}, + headers=get_auth_headers(user.token), + ) + assert response.status_code == 200 + + await session.refresh(gateway_aws) + await session.refresh(gateway_gcp) + await session.refresh(gateway_compute_aws) + await session.refresh(gateway_compute_gcp) + assert gateway_aws.to_be_deleted is True + assert gateway_gcp.to_be_deleted is True + assert gateway_compute_aws.active is True + assert gateway_compute_aws.deleted is False + assert gateway_compute_gcp.active is True + assert gateway_compute_gcp.deleted is False + + response = await client.post( + f"/api/project/{project.name}/gateways/list", + headers=get_auth_headers(user.token), + ) + assert response.status_code == 200 + assert {g["name"] for g in response.json()} == 
{"gateway-aws", "gateway-gcp"} + + events = await list_events(session) + assert len(events) == 2 + assert all(e.message == "Gateway marked for deletion" for e in events) + assert {e.targets[0].entity_name for e in events} == {"gateway-aws", "gateway-gcp"} + assert all(e.actor_user_id == user.id for e in events) + + +class TestDeleteGatewayPipelineDisabled: + @pytest.fixture(autouse=True) + def _pipeline_processing_disabled(self, patch_pipeline_processing_flag): + patch_pipeline_processing_flag(False) + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_deletes_gateways_synchronously( + self, test_db, session: AsyncSession, client: AsyncClient + ): user = await create_user(session, global_role=GlobalRole.USER) project = await create_project(session) await add_project_member( @@ -545,6 +633,7 @@ def get_backend(project, backend_type): }, } ] + events = await list_events(session) assert len(events) == 1 assert events[0].message == "Gateway deleted"