Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/dstack/_internal/server/routers/instances.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from typing import List
from typing import Annotated, List

from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession

import dstack._internal.server.services.instances as instances_services
from dstack._internal.core.errors import ResourceNotExistsError
from dstack._internal.core.models.instances import Instance
from dstack._internal.server.db import get_session
from dstack._internal.server.models import ProjectModel, UserModel
from dstack._internal.server.schemas.instances import (
GetInstanceHealthChecksRequest,
GetInstanceHealthChecksResponse,
GetInstanceRequest,
ListInstancesRequest,
)
from dstack._internal.server.security.permissions import Authenticated, ProjectMember
Expand Down Expand Up @@ -75,3 +77,21 @@ async def get_instance_health_checks(
limit=body.limit,
)
return CustomORJSONResponse(GetInstanceHealthChecksResponse(health_checks=health_checks))


@project_router.post("/get", response_model=Instance)
async def get_instance(
body: GetInstanceRequest,
session: Annotated[AsyncSession, Depends(get_session)],
user_project: Annotated[tuple[UserModel, ProjectModel], Depends(ProjectMember())],
):
"""
Returns an instance given its ID.
"""
_, project = user_project
instance = await instances_services.get_instance(
session=session, project=project, instance_id=body.id
)
if instance is None:
raise ResourceNotExistsError()
return CustomORJSONResponse(instance)
4 changes: 4 additions & 0 deletions src/dstack/_internal/server/schemas/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from dstack._internal.server.schemas.runner import InstanceHealthResponse


class GetInstanceRequest(CoreModel):
id: UUID


class ListInstancesRequest(CoreModel):
project_names: Optional[list[str]] = None
fleet_ids: Optional[list[UUID]] = None
Expand Down
22 changes: 22 additions & 0 deletions src/dstack/_internal/server/services/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,28 @@ async def get_instance_health_checks(
return health_checks


async def get_instance(
session: AsyncSession,
project: ProjectModel,
instance_id: uuid.UUID,
) -> Optional[Instance]:
res = await session.execute(
select(InstanceModel)
.where(
InstanceModel.id == instance_id,
InstanceModel.project_id == project.id,
)
.options(
joinedload(InstanceModel.fleet).load_only(FleetModel.name),
joinedload(InstanceModel.project).load_only(ProjectModel.name),
)
)
instance_model = res.scalar_one_or_none()
if instance_model is None:
return None
return instance_model_to_instance(instance_model)


def instance_model_to_instance(instance_model: InstanceModel) -> Instance:
instance = Instance(
id=instance_model.id,
Expand Down
84 changes: 84 additions & 0 deletions src/tests/_internal/server/routers/test_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,87 @@ async def test_converts_legacy_termination_reason_string(
)
# Must convert legacy "Fleet has too many instances" to "max_instances_limit"
assert resp.json()[0]["termination_reason"] == "max_instances_limit"


@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
class TestGetInstance:
async def test_returns_instance_by_id(
self, session: AsyncSession, client: AsyncClient
) -> None:
user = await create_user(session, global_role=GlobalRole.USER)
project = await create_project(session, owner=user)
await add_project_member(
session, project=project, user=user, project_role=ProjectRole.ADMIN
)
fleet = await create_fleet(session, project)
instance = await create_instance(session=session, project=project, fleet=fleet)

resp = await client.post(
f"/api/project/{project.name}/instances/get",
headers=get_auth_headers(user.token),
json={"id": str(instance.id)},
)
assert resp.status_code == 200
resp_data = resp.json()
assert resp_data["id"] == str(instance.id)
assert resp_data["project_name"] == project.name
assert resp_data["fleet_name"] == fleet.name

async def test_returns_400_if_instance_not_found(
self, session: AsyncSession, client: AsyncClient
) -> None:
user = await create_user(session, global_role=GlobalRole.USER)
project = await create_project(session, owner=user)
await add_project_member(
session, project=project, user=user, project_role=ProjectRole.ADMIN
)

resp = await client.post(
f"/api/project/{project.name}/instances/get",
headers=get_auth_headers(user.token),
json={"id": str(uuid.uuid4())},
)
assert resp.status_code == 400
assert resp.json()["detail"][0]["code"] == "resource_not_exists"

async def test_returns_400_if_instance_exists_in_different_project(
self, session: AsyncSession, client: AsyncClient
) -> None:
user = await create_user(session, global_role=GlobalRole.USER)

project1 = await create_project(session, owner=user, name="p1")
project2 = await create_project(session, owner=user, name="p2")

await add_project_member(
session, project=project1, user=user, project_role=ProjectRole.ADMIN
)
await add_project_member(
session, project=project2, user=user, project_role=ProjectRole.ADMIN
)

fleet = await create_fleet(session, project2)
instance = await create_instance(session=session, project=project2, fleet=fleet)

resp = await client.post(
f"/api/project/{project1.name}/instances/get",
headers=get_auth_headers(user.token),
json={"id": str(instance.id)},
)
assert resp.status_code == 400
assert resp.json()["detail"][0]["code"] == "resource_not_exists"

async def test_returns_403_if_not_project_member(
self, session: AsyncSession, client: AsyncClient
) -> None:
user = await create_user(session, name="non_member", global_role=GlobalRole.USER)
project = await create_project(session)
fleet = await create_fleet(session, project)
instance = await create_instance(session=session, project=project, fleet=fleet)

resp = await client.post(
f"/api/project/{project.name}/instances/get",
headers=get_auth_headers(user.token),
json={"id": str(instance.id)},
)
assert resp.status_code == 403