Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SSH fleet hosts validation #1955

Merged
merged 1 commit into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions src/dstack/_internal/server/services/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,16 @@ async def get_plan(
spec: FleetSpec,
) -> FleetPlan:
# TODO: refactor offers logic into a separate module to avoid depending on runs
await _check_ssh_hosts_not_yet_added(session, spec)
current_fleet: Optional[Fleet] = None
current_fleet_id: Optional[uuid.UUID] = None
if spec.configuration.name is not None:
current_fleet_model = await get_project_fleet_model_by_name(
session=session, project=project, name=spec.configuration.name
)
if current_fleet_model is not None:
current_fleet = fleet_model_to_fleet(current_fleet_model)
current_fleet_id = current_fleet_model.id
await _check_ssh_hosts_not_yet_added(session, spec, current_fleet_id)

offers = []
if spec.configuration.ssh_config is None:
Expand All @@ -148,13 +157,6 @@ async def get_plan(
requirements=_get_fleet_requirements(spec),
)
offers = [offer for _, offer in offers_with_backends]
current_fleet = None
if spec.configuration.name is not None:
current_fleet = await get_fleet_by_name(
session=session,
project=project,
name=spec.configuration.name,
)
plan = FleetPlan(
project_name=project.name,
user=user.name,
Expand Down Expand Up @@ -540,13 +542,18 @@ def _check_can_manage_ssh_fleets(user: UserModel, project: ProjectModel):
raise ForbiddenError()


async def _check_ssh_hosts_not_yet_added(session: AsyncSession, spec: FleetSpec):
async def _check_ssh_hosts_not_yet_added(
session: AsyncSession, spec: FleetSpec, current_fleet_id: Optional[uuid.UUID] = None
):
if spec.configuration.ssh_config and spec.configuration.ssh_config.hosts:
# there are manually listed hosts, need to check them for existence
active_instances = await list_active_remote_instances(session=session)

existing_hosts = set()
for instance in active_instances:
# ignore instances belonging to the same fleet -- in-place update/recreate
if current_fleet_id is not None and instance.fleet_id == current_fleet_id:
continue
instance_conn_info = RemoteConnectionInfo.parse_raw(
cast(str, instance.remote_connection_info)
)
Expand Down
5 changes: 5 additions & 0 deletions src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
InstanceConfiguration,
InstanceStatus,
InstanceType,
RemoteConnectionInfo,
Resources,
)
from dstack._internal.core.models.placement import (
Expand Down Expand Up @@ -418,6 +419,7 @@ async def create_instance(
session: AsyncSession,
project: ProjectModel,
pool: PoolModel,
fleet: Optional[FleetModel] = None,
status: InstanceStatus = InstanceStatus.IDLE,
created_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
finished_at: Optional[datetime] = None,
Expand All @@ -430,6 +432,7 @@ async def create_instance(
instance_num: int = 0,
backend: BackendType = BackendType.DATACRUNCH,
region: str = "eu-west",
remote_connection_info: Optional[RemoteConnectionInfo] = None,
) -> InstanceModel:
if instance_id is None:
instance_id = uuid.uuid4()
Expand Down Expand Up @@ -495,6 +498,7 @@ async def create_instance(
name="test_instance",
instance_num=instance_num,
pool=pool,
fleet=fleet,
project=project,
status=status,
unreachable=False,
Expand All @@ -510,6 +514,7 @@ async def create_instance(
profile=profile.json(),
requirements=requirements.json(),
instance_configuration=instance_configuration.json(),
remote_connection_info=remote_connection_info.json() if remote_connection_info else None,
job=job,
)
session.add(im)
Expand Down
170 changes: 170 additions & 0 deletions src/tests/_internal/server/services/test_fleets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from typing import Optional, Union
from unittest.mock import Mock

import pytest
from sqlalchemy.ext.asyncio import AsyncSession

from dstack._internal.core.backends.base import Backend
from dstack._internal.core.errors import ServerClientError
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.fleets import (
FleetConfiguration,
FleetSpec,
SSHHostParams,
SSHParams,
)
from dstack._internal.core.models.instances import RemoteConnectionInfo
from dstack._internal.server.models import FleetModel, ProjectModel
from dstack._internal.server.services.backends import get_project_backends
from dstack._internal.server.services.fleets import get_plan
from dstack._internal.server.testing.common import (
create_fleet,
create_instance,
create_pool,
create_project,
create_user,
get_fleet_spec,
)


class TestGetPlanSSHFleetHostsValidation:
    """Exercise the SSH-host checks performed by ``get_plan`` for SSH fleet specs.

    Every test seeds the database with an existing SSH fleet (one REMOTE instance per
    host), then calls ``get_plan`` with a new spec and asserts either on the returned
    plan or on the ``ServerClientError`` that is raised.
    """

    @pytest.fixture
    def get_project_backends_mock(self, monkeypatch: pytest.MonkeyPatch) -> list[Backend]:
        """Replace ``get_project_backends`` with a mock that returns an empty list."""
        # NOTE(review): the fixture returns the Mock object itself, not a list of
        # backends, so the return annotation looks inaccurate -- confirm and adjust.
        backends_mock = Mock(spec_set=get_project_backends, return_value=[])
        monkeypatch.setattr(
            "dstack._internal.server.services.backends.get_project_backends", backends_mock
        )
        return backends_mock

    def get_ssh_fleet_spec(
        self, name: Optional[str], hosts: list[Union[SSHHostParams, str]]
    ) -> FleetSpec:
        """Build a fleet spec whose configuration has the given name and SSH hosts."""
        return get_fleet_spec(
            conf=FleetConfiguration(
                name=name,
                ssh_config=SSHParams(hosts=hosts, network=None),
            )
        )

    async def create_fleet(
        self, session: AsyncSession, project: ProjectModel, spec: FleetSpec
    ) -> FleetModel:
        """Create a fleet for ``spec`` together with one REMOTE instance per listed host.

        The spec must carry an SSH config; each host (plain string or ``SSHHostParams``)
        becomes an instance attached to the new fleet with matching connection info.
        """
        ssh_config = spec.configuration.ssh_config
        assert ssh_config is not None, spec.configuration
        pool = await create_pool(session=session, project=project)
        # Inside the method body this name resolves to the module-level ``create_fleet``
        # testing helper, not to this method.
        fleet_model = await create_fleet(session=session, project=project, spec=spec)
        for entry in ssh_config.hosts:
            hostname = entry.hostname if isinstance(entry, SSHHostParams) else entry
            await create_instance(
                session=session,
                project=project,
                pool=pool,
                fleet=fleet_model,
                backend=BackendType.REMOTE,
                remote_connection_info=RemoteConnectionInfo(
                    host=hostname, port=22, ssh_user="admin", ssh_keys=[]
                ),
            )
        return fleet_model

    @pytest.mark.asyncio
    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
    @pytest.mark.usefixtures("test_db", "get_project_backends_mock")
    async def test_ok_same_fleet_update(self, session: AsyncSession):
        """Re-planning a fleet under the same name may list a host it already owns."""
        user = await create_user(session=session)
        project = await create_project(session=session, owner=user)
        await self.create_fleet(
            session,
            project,
            self.get_ssh_fleet_spec(name="my-fleet", hosts=["192.168.100.201"]),
        )
        updated_spec = self.get_ssh_fleet_spec(
            name="my-fleet", hosts=["192.168.100.201", "192.168.100.202"]
        )
        plan = await get_plan(session=session, project=project, user=user, spec=updated_spec)
        assert plan.current_resource is not None

    @pytest.mark.asyncio
    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
    @pytest.mark.usefixtures("test_db", "get_project_backends_mock")
    async def test_ok_deleted_instances_ignored(self, session: AsyncSession):
        """A host held only by a deleted fleet/instance does not fail the plan."""
        user = await create_user(session=session)
        project = await create_project(session=session, owner=user)
        stale_fleet = await self.create_fleet(
            session,
            project,
            self.get_ssh_fleet_spec(name="my-fleet", hosts=["192.168.100.201"]),
        )
        # Mark both the instances and the fleet itself as deleted before planning.
        for stale_instance in stale_fleet.instances:
            stale_instance.deleted = True
        stale_fleet.deleted = True
        await session.commit()
        fresh_spec = self.get_ssh_fleet_spec(
            name="my-fleet", hosts=["192.168.100.201", "192.168.100.202"]
        )
        plan = await get_plan(session=session, project=project, user=user, spec=fresh_spec)
        assert plan.current_resource is None

    @pytest.mark.asyncio
    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
    @pytest.mark.usefixtures("test_db", "get_project_backends_mock")
    async def test_ok_no_common_hosts_with_another_fleet(self, session: AsyncSession):
        """A new fleet whose hosts are disjoint from another fleet's hosts plans fine."""
        user = await create_user(session=session)
        project = await create_project(session=session, owner=user)
        await self.create_fleet(
            session,
            project,
            self.get_ssh_fleet_spec(name="another-fleet", hosts=["192.168.100.201"]),
        )
        new_spec = self.get_ssh_fleet_spec(name="new-fleet", hosts=["192.168.100.202"])
        plan = await get_plan(session=session, project=project, user=user, spec=new_spec)
        assert plan.current_resource is None

    @pytest.mark.asyncio
    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
    @pytest.mark.usefixtures("test_db", "get_project_backends_mock")
    async def test_error_another_fleet_same_project(self, session: AsyncSession):
        """A host owned by a differently named fleet in the same project is rejected."""
        user = await create_user(session=session)
        project = await create_project(session=session, owner=user)
        await self.create_fleet(
            session,
            project,
            self.get_ssh_fleet_spec(name="another-fleet", hosts=["192.168.100.201"]),
        )
        new_spec = self.get_ssh_fleet_spec(
            name="new-fleet", hosts=["192.168.100.201", "192.168.100.202"]
        )
        expected_error = r"Instances \[192\.168\.100\.201\] are already assigned"
        with pytest.raises(ServerClientError, match=expected_error):
            await get_plan(session=session, project=project, user=user, spec=new_spec)

    @pytest.mark.asyncio
    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
    @pytest.mark.usefixtures("test_db", "get_project_backends_mock")
    async def test_error_another_fleet_another_project(self, session: AsyncSession):
        """A host owned by a fleet in a different project is rejected as well."""
        another_user = await create_user(session=session, name="another-user")
        another_project = await create_project(
            session=session, owner=another_user, name="another-project"
        )
        await self.create_fleet(
            session,
            another_project,
            self.get_ssh_fleet_spec(name="another-fleet", hosts=["192.168.100.201"]),
        )
        user = await create_user(session=session, name="my-user")
        project = await create_project(session=session, owner=user, name="my-project")
        my_spec = self.get_ssh_fleet_spec(
            name="my-fleet", hosts=["192.168.100.201", "192.168.100.202"]
        )
        expected_error = r"Instances \[192\.168\.100\.201\] are already assigned"
        with pytest.raises(ServerClientError, match=expected_error):
            await get_plan(session=session, project=project, user=user, spec=my_spec)

    @pytest.mark.asyncio
    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
    @pytest.mark.usefixtures("test_db", "get_project_backends_mock")
    async def test_error_fleet_spec_without_name(self, session: AsyncSession):
        """A nameless spec listing an already-used host is rejected."""
        # Even if the user applies the same configuration again, we cannot be sure whether
        # it is the same fleet or a brand new one, as we identify fleets by name.
        user = await create_user(session=session)
        project = await create_project(session=session, owner=user)
        await self.create_fleet(
            session,
            project,
            self.get_ssh_fleet_spec(name="autogenerated-fleet-name", hosts=["192.168.100.201"]),
        )
        nameless_spec = self.get_ssh_fleet_spec(name=None, hosts=["192.168.100.201"])
        expected_error = r"Instances \[192\.168\.100\.201\] are already assigned"
        with pytest.raises(ServerClientError, match=expected_error):
            await get_plan(session=session, project=project, user=user, spec=nameless_spec)