Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
## Coding Style & Naming Conventions
- Python targets 3.9+ with 4-space indentation and max line length of 99 (see `ruff.toml`; `E501` is ignored but keep lines readable).
- Imports are sorted via Ruff’s isort settings (`dstack` treated as first-party).
- Keep primary/public functions before local helper functions in a module section.
- Prefer pydantic-style models in `core/models`.
- Tests use `test_*.py` modules and `test_*` functions; fixtures live near usage.

Expand Down
5 changes: 3 additions & 2 deletions src/dstack/_internal/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,9 @@ async def lifespan(app: FastAPI):
pipeline_manager = None
if settings.SERVER_BACKGROUND_PROCESSING_ENABLED:
scheduler = start_scheduled_tasks()
pipeline_manager = start_pipeline_tasks()
app.state.pipeline_manager = pipeline_manager
if core_settings.FeatureFlags.PIPELINE_PROCESSING_ENABLED:
pipeline_manager = start_pipeline_tasks()
app.state.pipeline_manager = pipeline_manager
else:
logger.info("Background processing is disabled")
PROBES_SCHEDULER.start()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,22 @@

from dstack._internal.server.background.pipeline_tasks.base import Pipeline
from dstack._internal.server.background.pipeline_tasks.compute_groups import ComputeGroupPipeline
from dstack._internal.server.background.pipeline_tasks.gateways import GatewayPipeline
from dstack._internal.server.background.pipeline_tasks.placement_groups import (
PlacementGroupPipeline,
)
from dstack._internal.settings import FeatureFlags
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)


class PipelineManager:
def __init__(self) -> None:
self._pipelines: list[Pipeline] = []
if FeatureFlags.PIPELINE_PROCESSING_ENABLED:
self._pipelines += [
ComputeGroupPipeline(),
PlacementGroupPipeline(),
]
self._pipelines: list[Pipeline] = [
ComputeGroupPipeline(),
GatewayPipeline(),
PlacementGroupPipeline(),
]
self._hinter = PipelineHinter(self._pipelines)

def start(self):
Expand Down
52 changes: 30 additions & 22 deletions src/dstack/_internal/server/background/pipeline_tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,25 @@

@dataclass
class PipelineItem:
"""
Pipelines can work with this class or its subclass if the worker needs to access extra attributes.
"""

__tablename__: str
id: uuid.UUID
lock_expires_at: datetime
lock_token: uuid.UUID
prev_lock_expired: bool


ItemT = TypeVar("ItemT", bound=PipelineItem)


class PipelineModel(Protocol):
"""
Heartbeater can work with any DB model implementing this protocol.
"""

__tablename__: str
__mapper__: ClassVar[Any]
__table__: ClassVar[Any]
Expand All @@ -39,7 +50,7 @@ class PipelineError(Exception):
pass


class Pipeline(ABC):
class Pipeline(Generic[ItemT], ABC):
def __init__(
self,
workers_num: int,
Expand All @@ -57,7 +68,7 @@ def __init__(
self._min_processing_interval = min_processing_interval
self._lock_timeout = lock_timeout
self._heartbeat_trigger = heartbeat_trigger
self._queue = asyncio.Queue[PipelineItem](maxsize=self._queue_maxsize)
self._queue = asyncio.Queue[ItemT](maxsize=self._queue_maxsize)
self._tasks: list[asyncio.Task] = []
self._running = False
self._shutdown = False
Expand Down Expand Up @@ -119,35 +130,32 @@ def hint_fetch_model_name(self) -> str:

@property
@abstractmethod
def _heartbeater(self) -> "Heartbeater":
def _heartbeater(self) -> "Heartbeater[ItemT]":
pass

@property
@abstractmethod
def _fetcher(self) -> "Fetcher":
def _fetcher(self) -> "Fetcher[ItemT]":
pass

@property
@abstractmethod
def _workers(self) -> Sequence["Worker"]:
def _workers(self) -> Sequence["Worker[ItemT]"]:
pass


ModelT = TypeVar("ModelT", bound=PipelineModel)


class Heartbeater(Generic[ModelT]):
class Heartbeater(Generic[ItemT]):
def __init__(
self,
model_type: type[ModelT],
model_type: type[PipelineModel],
lock_timeout: timedelta,
heartbeat_trigger: timedelta,
heartbeat_delay: float = 1.0,
) -> None:
self._model_type = model_type
self._lock_timeout = lock_timeout
self._hearbeat_margin = heartbeat_trigger
self._items: dict[uuid.UUID, PipelineItem] = {}
self._items: dict[uuid.UUID, ItemT] = {}
self._untrack_lock = asyncio.Lock()
self._heartbeat_delay = heartbeat_delay
self._running = False
Expand All @@ -164,18 +172,18 @@ async def start(self):
def stop(self):
self._running = False

async def track(self, item: PipelineItem):
async def track(self, item: ItemT):
self._items[item.id] = item

async def untrack(self, item: PipelineItem):
async def untrack(self, item: ItemT):
async with self._untrack_lock:
tracked = self._items.get(item.id)
# Prevent expired fetch iteration to unlock item processed by new iteration.
if tracked is not None and tracked.lock_token == item.lock_token:
del self._items[item.id]

async def heartbeat(self):
items_to_update: list[PipelineItem] = []
items_to_update: list[ItemT] = []
now = get_current_datetime()
items = list(self._items.values())
failed_to_heartbeat_count = 0
Expand Down Expand Up @@ -227,16 +235,16 @@ async def heartbeat(self):
)


class Fetcher(ABC):
class Fetcher(Generic[ItemT], ABC):
_DEFAULT_FETCH_DELAYS = [0.5, 1, 2, 5]

def __init__(
self,
queue: asyncio.Queue[PipelineItem],
queue: asyncio.Queue[ItemT],
queue_desired_minsize: int,
min_processing_interval: timedelta,
lock_timeout: timedelta,
heartbeater: Heartbeater,
heartbeater: Heartbeater[ItemT],
queue_check_delay: float = 1.0,
fetch_delays: Optional[list[float]] = None,
) -> None:
Expand Down Expand Up @@ -289,7 +297,7 @@ def hint(self):
self._fetch_event.set()

@abstractmethod
async def fetch(self, limit: int) -> list[PipelineItem]:
async def fetch(self, limit: int) -> list[ItemT]:
pass

def _next_fetch_delay(self, empty_fetch_count: int) -> float:
Expand All @@ -298,11 +306,11 @@ def _next_fetch_delay(self, empty_fetch_count: int) -> float:
return next_delay * (1 + jitter)


class Worker(ABC):
class Worker(Generic[ItemT], ABC):
def __init__(
self,
queue: asyncio.Queue[PipelineItem],
heartbeater: Heartbeater,
queue: asyncio.Queue[ItemT],
heartbeater: Heartbeater[ItemT],
) -> None:
self._queue = queue
self._heartbeater = heartbeater
Expand All @@ -325,7 +333,7 @@ def stop(self):
self._running = False

@abstractmethod
async def process(self, item: PipelineItem):
async def process(self, item: ItemT):
pass


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from dstack._internal.server.models import ComputeGroupModel, InstanceModel, ProjectModel
from dstack._internal.server.services import backends as backends_services
from dstack._internal.server.services.compute_groups import compute_group_model_to_compute_group
from dstack._internal.server.services.instances import switch_instance_status
from dstack._internal.server.services.instances import emit_instance_status_change_event
from dstack._internal.server.services.locking import get_locker
from dstack._internal.utils.common import get_current_datetime, run_async
from dstack._internal.utils.logging import get_logger
Expand All @@ -36,7 +36,7 @@
TERMINATION_RETRY_MAX_DURATION = timedelta(minutes=15)


class ComputeGroupPipeline(Pipeline):
class ComputeGroupPipeline(Pipeline[PipelineItem]):
def __init__(
self,
workers_num: int = 10,
Expand All @@ -54,7 +54,7 @@ def __init__(
lock_timeout=lock_timeout,
heartbeat_trigger=heartbeat_trigger,
)
self.__heartbeater = Heartbeater[ComputeGroupModel](
self.__heartbeater = Heartbeater[PipelineItem](
model_type=ComputeGroupModel,
lock_timeout=self._lock_timeout,
heartbeat_trigger=self._heartbeat_trigger,
Expand All @@ -76,26 +76,26 @@ def hint_fetch_model_name(self) -> str:
return ComputeGroupModel.__name__

@property
def _heartbeater(self) -> Heartbeater:
def _heartbeater(self) -> Heartbeater[PipelineItem]:
return self.__heartbeater

@property
def _fetcher(self) -> Fetcher:
def _fetcher(self) -> Fetcher[PipelineItem]:
return self.__fetcher

@property
def _workers(self) -> Sequence["ComputeGroupWorker"]:
return self.__workers


class ComputeGroupFetcher(Fetcher):
class ComputeGroupFetcher(Fetcher[PipelineItem]):
def __init__(
self,
queue: asyncio.Queue[PipelineItem],
queue_desired_minsize: int,
min_processing_interval: timedelta,
lock_timeout: timedelta,
heartbeater: Heartbeater[ComputeGroupModel],
heartbeater: Heartbeater[PipelineItem],
queue_check_delay: float = 1.0,
) -> None:
super().__init__(
Expand Down Expand Up @@ -161,11 +161,11 @@ async def fetch(self, limit: int) -> list[PipelineItem]:
return items


class ComputeGroupWorker(Worker):
class ComputeGroupWorker(Worker[PipelineItem]):
def __init__(
self,
queue: asyncio.Queue[PipelineItem],
heartbeater: Heartbeater[ComputeGroupModel],
heartbeater: Heartbeater[PipelineItem],
) -> None:
super().__init__(
queue=queue,
Expand Down Expand Up @@ -235,7 +235,12 @@ async def process(self, item: PipelineItem):
.values(**terminate_result.instances_update_map)
)
for instance_model in compute_group_model.instances:
switch_instance_status(session, instance_model, InstanceStatus.TERMINATED)
emit_instance_status_change_event(
session=session,
instance_model=instance_model,
old_status=instance_model.status,
new_status=InstanceStatus.TERMINATED,
)


@dataclass
Expand Down
Loading