Skip to content

Commit

Permalink
Optimize ProjectModel loading (#1199)
Browse files Browse the repository at this point in the history
  • Loading branch information
r4victor authored May 6, 2024
1 parent 1da1809 commit 00aec78
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 30 deletions.
8 changes: 4 additions & 4 deletions src/dstack/_internal/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
from dstack._internal.server.services.users import get_or_create_admin_user
from dstack._internal.server.settings import (
DEFAULT_PROJECT_NAME,
DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT,
DSTACK_UPDATE_DEFAULT_PROJECT,
DO_NOT_UPDATE_DEFAULT_PROJECT,
SERVER_CONFIG_FILE_PATH,
SERVER_URL,
UPDATE_DEFAULT_PROJECT,
)
from dstack._internal.server.utils.logging import configure_logging
from dstack._internal.server.utils.routers import (
Expand Down Expand Up @@ -109,8 +109,8 @@ async def lifespan(app: FastAPI):
project_name=DEFAULT_PROJECT_NAME,
url=SERVER_URL,
token=admin.token,
default=DSTACK_UPDATE_DEFAULT_PROJECT,
no_default=DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT,
default=UPDATE_DEFAULT_PROJECT,
no_default=DO_NOT_UPDATE_DEFAULT_PROJECT,
)
if settings.SERVER_BUCKET is not None:
init_default_storage()
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/server/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class Database:
def __init__(self, url: str):
self.url = url
self.engine = create_async_engine(self.url, echo=False)
self.engine = create_async_engine(self.url, echo=settings.SQL_ECHO_ENABLED)
self.session_maker = sessionmaker(
bind=self.engine, expire_on_commit=False, class_=AsyncSession
)
Expand Down
6 changes: 2 additions & 4 deletions src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ class ProjectModel(BaseModel):
default_pool_id: Mapped[Optional[UUIDType]] = mapped_column(
ForeignKey("pools.id", use_alter=True, ondelete="SET NULL"), nullable=True
)
default_pool: Mapped[Optional["PoolModel"]] = relationship(
foreign_keys=[default_pool_id], lazy="selectin"
)
default_pool: Mapped[Optional["PoolModel"]] = relationship(foreign_keys=[default_pool_id])


class MemberModel(BaseModel):
Expand Down Expand Up @@ -350,5 +348,5 @@ class InstanceModel(BaseModel):

# current job
job_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("jobs.id"))
job: Mapped[Optional["JobModel"]] = relationship(back_populates="instance", lazy="immediate")
job: Mapped[Optional["JobModel"]] = relationship(back_populates="instance", lazy="joined")
last_job_processed_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
19 changes: 8 additions & 11 deletions src/dstack/_internal/server/services/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,12 @@ async def get_pool(
return res.one_or_none()


async def get_named_or_default_pool(
session: AsyncSession, project: ProjectModel, pool_name: Optional[str]
) -> Optional[PoolModel]:
if pool_name is not None:
return await get_pool(session, project, pool_name)
return project.default_pool


async def get_or_create_pool_by_name(
session: AsyncSession, project: ProjectModel, pool_name: Optional[str]
) -> PoolModel:
if pool_name is None:
if project.default_pool is not None:
return project.default_pool
if project.default_pool_id is not None:
return await get_default_pool_or_error(session, project)
default_pool = await get_pool(session, project, DEFAULT_POOL_NAME)
if default_pool is not None:
await set_default_pool(session, project, DEFAULT_POOL_NAME)
Expand All @@ -87,6 +79,11 @@ async def get_or_create_pool_by_name(
return await create_pool(session, project, pool_name)


async def get_default_pool_or_error(session: AsyncSession, project: ProjectModel) -> PoolModel:
res = await session.execute(select(PoolModel).where(PoolModel.id == project.default_pool_id))
return res.scalar_one()


async def create_pool(session: AsyncSession, project: ProjectModel, name: str) -> PoolModel:
pool = await get_pool(session, project, name)
if pool is not None:
Expand All @@ -98,7 +95,7 @@ async def create_pool(session: AsyncSession, project: ProjectModel, name: str) -
session.add(pool)
await session.commit()
await session.refresh(pool)
if project.default_pool is None:
if project.default_pool_id is None:
await set_default_pool(session, project, pool.name)
return pool

Expand Down
18 changes: 11 additions & 7 deletions src/dstack/_internal/server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,12 @@

SERVER_CONFIG_DISABLED = os.getenv("DSTACK_SERVER_CONFIG_DISABLED") is not None
SERVER_CONFIG_ENABLED = not SERVER_CONFIG_DISABLED
LOCAL_BACKEND_ENABLED = os.getenv("DSTACK_LOCAL_BACKEND_ENABLED") is not None

SERVER_BUCKET = os.getenv("DSTACK_SERVER_BUCKET")
SERVER_BUCKET_REGION = os.getenv("DSTACK_SERVER_BUCKET_REGION", "eu-west-1")

DEFAULT_PROJECT_NAME = "main"

DSTACK_UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_UPDATE_DEFAULT_PROJECT") is not None
DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT = (
os.getenv("DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT") is not None
)
SKIP_GATEWAY_UPDATE = bool(os.getenv("DSTACK_SKIP_GATEWAY_UPDATE", None))

SENTRY_DSN = os.getenv("DSTACK_SENTRY_DSN")
SENTRY_TRACES_SAMPLE_RATE = float(os.getenv("DSTACK_SENTRY_TRACES_SAMPLE_RATE", 0.1))

Expand All @@ -51,3 +44,14 @@
ACME_EAB_HMAC_KEY = os.getenv("DSTACK_ACME_EAB_HMAC_KEY")

USER_PROJECT_DEFAULT_QUOTA = int(os.getenv("DSTACK_USER_PROJECT_DEFAULT_QUOTA", 10))


# Development settings

SQL_ECHO_ENABLED = os.getenv("DSTACK_SQL_ECHO_ENABLED") is not None

LOCAL_BACKEND_ENABLED = os.getenv("DSTACK_LOCAL_BACKEND_ENABLED") is not None

UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_UPDATE_DEFAULT_PROJECT") is not None
DO_NOT_UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT") is not None
SKIP_GATEWAY_UPDATE = bool(os.getenv("DSTACK_SKIP_GATEWAY_UPDATE", None))
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
JobTerminationReason,
)
from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs
from dstack._internal.server.models import JobModel
from dstack._internal.server.models import JobModel, ProjectModel
from dstack._internal.server.services.pools import (
get_or_create_pool_by_name,
)
Expand Down Expand Up @@ -117,7 +117,12 @@ async def test_provisiones_job(self, test_db, session: AsyncSession):
assert job is not None
assert job.status == JobStatus.PROVISIONING

await session.refresh(project)
res = await session.execute(
select(ProjectModel)
.where(ProjectModel.id == project.id)
.options(joinedload(ProjectModel.default_pool))
)
project = res.scalar_one()
assert project.default_pool.name == DEFAULT_POOL_NAME

instance_offer = InstanceOfferWithAvailability.parse_raw(
Expand Down Expand Up @@ -165,7 +170,12 @@ async def test_fails_job_when_no_capacity(self, test_db, session: AsyncSession):
assert job.status == JobStatus.TERMINATING
assert job.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY

await session.refresh(project)
res = await session.execute(
select(ProjectModel)
.where(ProjectModel.id == project.id)
.options(joinedload(ProjectModel.default_pool))
)
project = res.scalar_one()
assert not project.default_pool.instances

@pytest.mark.asyncio
Expand Down

0 comments on commit 00aec78

Please sign in to comment.