diff --git a/src/dstack/_internal/server/routers/instances.py b/src/dstack/_internal/server/routers/instances.py index 67f22d9b6..b241d7e76 100644 --- a/src/dstack/_internal/server/routers/instances.py +++ b/src/dstack/_internal/server/routers/instances.py @@ -1,15 +1,17 @@ -from typing import List +from typing import Annotated, List from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession import dstack._internal.server.services.instances as instances_services +from dstack._internal.core.errors import ResourceNotExistsError from dstack._internal.core.models.instances import Instance from dstack._internal.server.db import get_session from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.schemas.instances import ( GetInstanceHealthChecksRequest, GetInstanceHealthChecksResponse, + GetInstanceRequest, ListInstancesRequest, ) from dstack._internal.server.security.permissions import Authenticated, ProjectMember @@ -75,3 +77,21 @@ async def get_instance_health_checks( limit=body.limit, ) return CustomORJSONResponse(GetInstanceHealthChecksResponse(health_checks=health_checks)) + + +@project_router.post("/get", response_model=Instance) +async def get_instance( + body: GetInstanceRequest, + session: Annotated[AsyncSession, Depends(get_session)], + user_project: Annotated[tuple[UserModel, ProjectModel], Depends(ProjectMember())], +): + """ + Returns an instance given its ID. + """ + _, project = user_project + instance = await instances_services.get_instance( + session=session, project=project, instance_id=body.id + ) + if instance is None: + raise ResourceNotExistsError() + return CustomORJSONResponse(instance) diff --git a/src/dstack/_internal/server/schemas/instances.py b/src/dstack/_internal/server/schemas/instances.py index 60843c629..120ff161d 100644 --- a/src/dstack/_internal/server/schemas/instances.py +++ b/src/dstack/_internal/server/schemas/instances.py @@ -7,6 +7,10 @@ from dstack._internal.server.schemas.runner import InstanceHealthResponse +class GetInstanceRequest(CoreModel): + id: UUID + + class ListInstancesRequest(CoreModel): project_names: Optional[list[str]] = None fleet_ids: Optional[list[UUID]] = None diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py index 14f26cc3f..c311df7db 100644 --- a/src/dstack/_internal/server/services/instances.py +++ b/src/dstack/_internal/server/services/instances.py @@ -163,6 +163,28 @@ async def get_instance_health_checks( return health_checks +async def get_instance( + session: AsyncSession, + project: ProjectModel, + instance_id: uuid.UUID, +) -> Optional[Instance]: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.id == instance_id, + InstanceModel.project_id == project.id, + ) + .options( + joinedload(InstanceModel.fleet).load_only(FleetModel.name), + joinedload(InstanceModel.project).load_only(ProjectModel.name), + ) + ) + instance_model = res.scalar_one_or_none() + if instance_model is None: + return None + return instance_model_to_instance(instance_model) + + def instance_model_to_instance(instance_model: InstanceModel) -> Instance: instance = Instance( id=instance_model.id, diff --git a/src/tests/_internal/server/routers/test_instances.py b/src/tests/_internal/server/routers/test_instances.py index 8aee09e6d..5f9e41df3 100644 --- a/src/tests/_internal/server/routers/test_instances.py +++ b/src/tests/_internal/server/routers/test_instances.py @@ -395,3 +395,87 @@ async def test_converts_legacy_termination_reason_string( ) # Must convert legacy "Fleet has too many instances" to "max_instances_limit" assert resp.json()[0]["termination_reason"] == "max_instances_limit" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestGetInstance: + async def test_returns_instance_by_id( + self, session: AsyncSession, client: AsyncClient + ) -> None: + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session, owner=user) + await add_project_member( + session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + fleet = await create_fleet(session, project) + instance = await create_instance(session=session, project=project, fleet=fleet) + + resp = await client.post( + f"/api/project/{project.name}/instances/get", + headers=get_auth_headers(user.token), + json={"id": str(instance.id)}, + ) + assert resp.status_code == 200 + resp_data = resp.json() + assert resp_data["id"] == str(instance.id) + assert resp_data["project_name"] == project.name + assert resp_data["fleet_name"] == fleet.name + + async def test_returns_400_if_instance_not_found( + self, session: AsyncSession, client: AsyncClient + ) -> None: + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session, owner=user) + await add_project_member( + session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + + resp = await client.post( + f"/api/project/{project.name}/instances/get", + headers=get_auth_headers(user.token), + json={"id": str(uuid.uuid4())}, + ) + assert resp.status_code == 400 + assert resp.json()["detail"][0]["code"] == "resource_not_exists" + + async def test_returns_400_if_instance_exists_in_different_project( + self, session: AsyncSession, client: AsyncClient + ) -> None: + user = await create_user(session, global_role=GlobalRole.USER) + + project1 = await create_project(session, owner=user, name="p1") + project2 = await create_project(session, owner=user, name="p2") + + await add_project_member( + session, project=project1, user=user, project_role=ProjectRole.ADMIN + ) + await add_project_member( + session, project=project2, user=user, project_role=ProjectRole.ADMIN + ) + + fleet = await create_fleet(session, project2) + instance = await create_instance(session=session, project=project2, fleet=fleet) + + resp = await client.post( + f"/api/project/{project1.name}/instances/get", + headers=get_auth_headers(user.token), + json={"id": str(instance.id)}, + ) + assert resp.status_code == 400 + assert resp.json()["detail"][0]["code"] == "resource_not_exists" + + async def test_returns_403_if_not_project_member( + self, session: AsyncSession, client: AsyncClient + ) -> None: + user = await create_user(session, name="non_member", global_role=GlobalRole.USER) + project = await create_project(session) + fleet = await create_fleet(session, project) + instance = await create_instance(session=session, project=project, fleet=fleet) + + resp = await client.post( + f"/api/project/{project.name}/instances/get", + headers=get_auth_headers(user.token), + json={"id": str(instance.id)}, + ) + assert resp.status_code == 403