Skip to content

Commit

Permalink
fix: get project (#17666)
Browse files Browse the repository at this point in the history
  • Loading branch information
yurijmikhalevich authored May 19, 2023
1 parent 3a6d0d8 commit 61246c3
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 37 deletions.
1 change: 0 additions & 1 deletion src/lightning/app/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def get_lightning_cloud_url() -> str:
# Project under which the resources need to run in cloud. If this env is not set,
# cloud runner will try to get the default project from the cloud
LIGHTNING_CLOUD_PROJECT_ID = os.getenv("LIGHTNING_CLOUD_PROJECT_ID")
LIGHTNING_CLOUD_ORGANIZATION_ID = os.getenv("LIGHTNING_CLOUD_ORGANIZATION_ID")
LIGHTNING_CLOUD_PRINT_SPECS = os.getenv("LIGHTNING_CLOUD_PRINT_SPECS")
LIGHTNING_DIR = os.getenv("LIGHTNING_DIR", str(Path.home() / ".lightning"))
LIGHTNING_CREDENTIAL_PATH = os.getenv("LIGHTNING_CREDENTIAL_PATH", str(Path(LIGHTNING_DIR) / "credentials.json"))
Expand Down
37 changes: 18 additions & 19 deletions src/lightning/app/utilities/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,35 @@
from lightning_cloud.openapi import V1Membership

import lightning.app
from lightning.app.core import constants
from lightning.app.core.constants import LIGHTNING_CLOUD_PROJECT_ID
from lightning.app.utilities.enum import AppStage
from lightning.app.utilities.network import LightningClient


def _get_project(
client: LightningClient,
organization_id: Optional[str] = None,
project_id: Optional[str] = None,
verbose: bool = True,
) -> V1Membership:
def _get_project(client: LightningClient, project_id: Optional[str] = None, verbose: bool = True) -> V1Membership:
"""Get a project membership for the user from the backend."""
if project_id is None:
project_id = constants.LIGHTNING_CLOUD_PROJECT_ID
if organization_id is None:
organization_id = constants.LIGHTNING_CLOUD_ORGANIZATION_ID
project_id = LIGHTNING_CLOUD_PROJECT_ID

projects = client.projects_service_list_memberships(
**({"organization_id": organization_id} if organization_id is not None else {})
)
if project_id is not None:
for membership in projects.memberships:
if membership.project_id == project_id:
break
else:
project = client.projects_service_get_project(project_id)
if not project:
raise ValueError(
"Environment variable `LIGHTNING_CLOUD_PROJECT_ID` is set but could not find an associated project."
)
return membership

return V1Membership(
name=project.name,
display_name=project.display_name,
description=project.description,
created_at=project.created_at,
project_id=project.id,
owner_id=project.owner_id,
owner_type=project.owner_type,
quotas=project.quotas,
updated_at=project.updated_at,
)

projects = client.projects_service_list_memberships()
if len(projects.memberships) == 0:
raise ValueError("No valid projects found. Please reach out to lightning.ai team to create a project")
if len(projects.memberships) > 1 and verbose:
Expand Down
24 changes: 7 additions & 17 deletions tests/tests_app/utilities/test_cloud.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,19 @@
import os
from unittest import mock

from lightning_cloud.openapi.models import V1ListMembershipsResponse, V1Membership
from lightning_cloud.openapi.models import V1Project

from lightning.app.utilities.cloud import _get_project, is_running_in_cloud


@mock.patch("lightning.app.core.constants.LIGHTNING_CLOUD_ORGANIZATION_ID", "organization_id")
def test_get_project_picks_up_organization_id():
"""Uses organization_id from `LIGHTNING_CLOUD_ORGANIZATION_ID` config var if none passed."""
def test_get_project_queries_by_project_id_directly_if_it_is_passed():
lightning_client = mock.MagicMock()
lightning_client.projects_service_list_memberships = mock.MagicMock(
return_value=V1ListMembershipsResponse(memberships=[V1Membership(project_id="project_id")]),
lightning_client.projects_service_get_project = mock.MagicMock(
return_value=V1Project(id="project_id"),
)
_get_project(lightning_client)
lightning_client.projects_service_list_memberships.assert_called_once_with(organization_id="organization_id")


def test_get_project_doesnt_pass_organization_id_if_its_not_set():
lightning_client = mock.MagicMock()
lightning_client.projects_service_list_memberships = mock.MagicMock(
return_value=V1ListMembershipsResponse(memberships=[V1Membership(project_id="project_id")]),
)
_get_project(lightning_client)
lightning_client.projects_service_list_memberships.assert_called_once_with()
project = _get_project(lightning_client, project_id="project_id")
assert project.project_id == "project_id"
lightning_client.projects_service_get_project.assert_called_once_with("project_id")


def test_is_running_cloud():
Expand Down

0 comments on commit 61246c3

Please sign in to comment.