Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize ProjectModel loading #1199

Merged
merged 1 commit into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/dstack/_internal/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
from dstack._internal.server.services.users import get_or_create_admin_user
from dstack._internal.server.settings import (
DEFAULT_PROJECT_NAME,
DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT,
DSTACK_UPDATE_DEFAULT_PROJECT,
DO_NOT_UPDATE_DEFAULT_PROJECT,
SERVER_CONFIG_FILE_PATH,
SERVER_URL,
UPDATE_DEFAULT_PROJECT,
)
from dstack._internal.server.utils.logging import configure_logging
from dstack._internal.server.utils.routers import (
Expand Down Expand Up @@ -109,8 +109,8 @@ async def lifespan(app: FastAPI):
project_name=DEFAULT_PROJECT_NAME,
url=SERVER_URL,
token=admin.token,
default=DSTACK_UPDATE_DEFAULT_PROJECT,
no_default=DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT,
default=UPDATE_DEFAULT_PROJECT,
no_default=DO_NOT_UPDATE_DEFAULT_PROJECT,
)
if settings.SERVER_BUCKET is not None:
init_default_storage()
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/server/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class Database:
def __init__(self, url: str):
self.url = url
self.engine = create_async_engine(self.url, echo=False)
self.engine = create_async_engine(self.url, echo=settings.SQL_ECHO_ENABLED)
self.session_maker = sessionmaker(
bind=self.engine, expire_on_commit=False, class_=AsyncSession
)
Expand Down
6 changes: 2 additions & 4 deletions src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ class ProjectModel(BaseModel):
default_pool_id: Mapped[Optional[UUIDType]] = mapped_column(
ForeignKey("pools.id", use_alter=True, ondelete="SET NULL"), nullable=True
)
default_pool: Mapped[Optional["PoolModel"]] = relationship(
foreign_keys=[default_pool_id], lazy="selectin"
)
default_pool: Mapped[Optional["PoolModel"]] = relationship(foreign_keys=[default_pool_id])


class MemberModel(BaseModel):
Expand Down Expand Up @@ -350,5 +348,5 @@ class InstanceModel(BaseModel):

# current job
job_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("jobs.id"))
job: Mapped[Optional["JobModel"]] = relationship(back_populates="instance", lazy="immediate")
job: Mapped[Optional["JobModel"]] = relationship(back_populates="instance", lazy="joined")
last_job_processed_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
19 changes: 8 additions & 11 deletions src/dstack/_internal/server/services/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,12 @@ async def get_pool(
return res.one_or_none()


async def get_named_or_default_pool(
session: AsyncSession, project: ProjectModel, pool_name: Optional[str]
) -> Optional[PoolModel]:
if pool_name is not None:
return await get_pool(session, project, pool_name)
return project.default_pool


async def get_or_create_pool_by_name(
session: AsyncSession, project: ProjectModel, pool_name: Optional[str]
) -> PoolModel:
if pool_name is None:
if project.default_pool is not None:
return project.default_pool
if project.default_pool_id is not None:
return await get_default_pool_or_error(session, project)
default_pool = await get_pool(session, project, DEFAULT_POOL_NAME)
if default_pool is not None:
await set_default_pool(session, project, DEFAULT_POOL_NAME)
Expand All @@ -87,6 +79,11 @@ async def get_or_create_pool_by_name(
return await create_pool(session, project, pool_name)


async def get_default_pool_or_error(session: AsyncSession, project: ProjectModel) -> PoolModel:
res = await session.execute(select(PoolModel).where(PoolModel.id == project.default_pool_id))
return res.scalar_one()


async def create_pool(session: AsyncSession, project: ProjectModel, name: str) -> PoolModel:
pool = await get_pool(session, project, name)
if pool is not None:
Expand All @@ -98,7 +95,7 @@ async def create_pool(session: AsyncSession, project: ProjectModel, name: str) -
session.add(pool)
await session.commit()
await session.refresh(pool)
if project.default_pool is None:
if project.default_pool_id is None:
await set_default_pool(session, project, pool.name)
return pool

Expand Down
18 changes: 11 additions & 7 deletions src/dstack/_internal/server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,12 @@

SERVER_CONFIG_DISABLED = os.getenv("DSTACK_SERVER_CONFIG_DISABLED") is not None
SERVER_CONFIG_ENABLED = not SERVER_CONFIG_DISABLED
LOCAL_BACKEND_ENABLED = os.getenv("DSTACK_LOCAL_BACKEND_ENABLED") is not None

SERVER_BUCKET = os.getenv("DSTACK_SERVER_BUCKET")
SERVER_BUCKET_REGION = os.getenv("DSTACK_SERVER_BUCKET_REGION", "eu-west-1")

DEFAULT_PROJECT_NAME = "main"

DSTACK_UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_UPDATE_DEFAULT_PROJECT") is not None
DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT = (
os.getenv("DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT") is not None
)
SKIP_GATEWAY_UPDATE = bool(os.getenv("DSTACK_SKIP_GATEWAY_UPDATE", None))

SENTRY_DSN = os.getenv("DSTACK_SENTRY_DSN")
SENTRY_TRACES_SAMPLE_RATE = float(os.getenv("DSTACK_SENTRY_TRACES_SAMPLE_RATE", 0.1))

Expand All @@ -51,3 +44,14 @@
ACME_EAB_HMAC_KEY = os.getenv("DSTACK_ACME_EAB_HMAC_KEY")

USER_PROJECT_DEFAULT_QUOTA = int(os.getenv("DSTACK_USER_PROJECT_DEFAULT_QUOTA", 10))


# Development settings

SQL_ECHO_ENABLED = os.getenv("DSTACK_SQL_ECHO_ENABLED") is not None

LOCAL_BACKEND_ENABLED = os.getenv("DSTACK_LOCAL_BACKEND_ENABLED") is not None

UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_UPDATE_DEFAULT_PROJECT") is not None
DO_NOT_UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT") is not None
SKIP_GATEWAY_UPDATE = bool(os.getenv("DSTACK_SKIP_GATEWAY_UPDATE", None))
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
JobTerminationReason,
)
from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs
from dstack._internal.server.models import JobModel
from dstack._internal.server.models import JobModel, ProjectModel
from dstack._internal.server.services.pools import (
get_or_create_pool_by_name,
)
Expand Down Expand Up @@ -117,7 +117,12 @@ async def test_provisiones_job(self, test_db, session: AsyncSession):
assert job is not None
assert job.status == JobStatus.PROVISIONING

await session.refresh(project)
res = await session.execute(
select(ProjectModel)
.where(ProjectModel.id == project.id)
.options(joinedload(ProjectModel.default_pool))
)
project = res.scalar_one()
assert project.default_pool.name == DEFAULT_POOL_NAME

instance_offer = InstanceOfferWithAvailability.parse_raw(
Expand Down Expand Up @@ -165,7 +170,12 @@ async def test_fails_job_when_no_capacity(self, test_db, session: AsyncSession):
assert job.status == JobStatus.TERMINATING
assert job.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY

await session.refresh(project)
res = await session.execute(
select(ProjectModel)
.where(ProjectModel.id == project.id)
.options(joinedload(ProjectModel.default_pool))
)
project = res.scalar_one()
assert not project.default_pool.instances

@pytest.mark.asyncio
Expand Down
Loading