Skip to content

Commit

Permalink
🐛 fix the issues with experiment listing (#64)
Browse files Browse the repository at this point in the history
* add fixes

* fix experiment listing
  • Loading branch information
renardeinside authored Aug 2, 2023
1 parent c6df11f commit 7a62801
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 35 deletions.
1 change: 1 addition & 0 deletions src/uc_migration_toolkit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class WorkspaceAuthConfig:
host: str | None = None
client_id: str | None = None
client_secret: str | None = None
profile: str | None = None


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion src/uc_migration_toolkit/managers/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _get_clean_group_info(group: Group, cleanup_keys: list[str] | None = None) -
def _get_group(group_name, level: GroupLevel) -> Group | None:
method = provider.ws.groups.list if level == GroupLevel.WORKSPACE else provider.ws.list_account_level_groups
query_filter = f"displayName eq '{group_name}'"
attributes = ",".join(["id", "displayName", "meta", "entitlements", "roles"])
attributes = ",".join(["id", "displayName", "meta", "entitlements", "roles", "members"])

group = next(
iter(method(filter=query_filter, attributes=attributes)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def provide(migration_provider: MigrationGroupsProvider):
StandardInventorizer(
logical_object_type=LogicalObjectType.EXPERIMENT,
request_object_type=RequestObjectType.EXPERIMENTS,
listing_function=provider.ws.experiments.list_experiments,
listing_function=CustomListing.list_experiments,
id_attribute="experiment_id",
),
StandardInventorizer(
Expand Down
7 changes: 7 additions & 0 deletions src/uc_migration_toolkit/managers/inventory/listing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ def list_models() -> Iterator[ModelDatabricks]:
model_with_id = provider.ws.model_registry.get_model(model.name).registered_model_databricks
yield model_with_id

@staticmethod
def list_experiments() -> Iterator[ModelDatabricks]:
for experiment in provider.ws.experiments.list_experiments():
nb_tag = [t for t in experiment.tags if t.key == "mlflow.experimentType" and t.value == "NOTEBOOK"]
if not nb_tag:
yield experiment


class WorkspaceListing:
def __init__(
Expand Down
33 changes: 16 additions & 17 deletions src/uc_migration_toolkit/providers/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,34 +82,33 @@ def update_permissions(

def apply_roles_and_entitlements(self, group_id: str, roles: list, entitlements: list):
op_schema = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
schemas = [op_schema, op_schema]
schemas = []
operations = []

entitlements_payload = (
{
if entitlements:
schemas.append(op_schema)
entitlements_payload = {
"op": "add",
"path": "entitlements",
"value": entitlements,
}
if entitlements
else {}
)
operations.append(entitlements_payload)

roles_payload = (
{
if roles:
schemas.append(op_schema)
roles_payload = {
"op": "add",
"path": "roles",
"value": roles,
}
if roles
else {}
)
operations.append(roles_payload)

operations = [entitlements_payload, roles_payload]
request = {
"schemas": schemas,
"Operations": operations,
}
self.patch_workspace_group(group_id, request)
if operations:
request = {
"schemas": schemas,
"Operations": operations,
}
self.patch_workspace_group(group_id, request)


class ClientProvider:
Expand Down
27 changes: 16 additions & 11 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,17 +178,19 @@ def instance_profiles(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list
profiles.append(InstanceProfile(instance_profile_arn=profile_arn, iam_role_arn=iam_role_arn))

for ws_group, _ in env.groups:
roles = {
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [
{
"op": "add",
"path": "roles",
"value": [{"value": p.instance_profile_arn} for p in random.choices(profiles, k=2)],
}
],
}
provider.ws.api_client.do("PATCH", f"/api/2.0/preview/scim/v2/Groups/{ws_group.id}", data=json.dumps(roles))
if random.choice([True, False]):
# randomize to apply roles randomly
roles = {
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [
{
"op": "add",
"path": "roles",
"value": [{"value": p.instance_profile_arn} for p in random.choices(profiles, k=2)],
}
],
}
provider.ws.api_client.do("PATCH", f"/api/2.0/preview/scim/v2/Groups/{ws_group.id}", data=json.dumps(roles))

yield profiles

Expand Down Expand Up @@ -517,6 +519,9 @@ def workspace_objects(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> Work
random_group = random.choice([g[0] for g in env.groups])
_nb_path = f"/{env.test_uid}/{random_group.display_name}/nb-{nb_idx}.py"
ws.workspace.upload(path=_nb_path, content=io.BytesIO(b"print(1)"))
# TODO: add a proper test for this
# if random.choice([True, False]):
# ws.experiments.create_experiment(name=_nb_path) # create experiment to test nb-based experiments
_nb_obj = ws.workspace.get_status(_nb_path)
notebooks.append(_nb_obj)
ws.permissions.set(
Expand Down
12 changes: 7 additions & 5 deletions tests/integration/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,13 @@ def test_e2e(
toolkit.group_manager.migration_groups_provider.groups
)

for _info in toolkit.group_manager.migration_groups_provider.groups:
_ws = ws.groups.get(id=_info.workspace.id)
_backup = ws.groups.get(id=_info.backup.id)
_ws_members = sorted([m.value for m in _ws.members])
_backup_members = sorted([m.value for m in _backup.members])
assert _ws_members == _backup_members

logger.debug("Verifying that the groups were created - done")

toolkit.cleanup_inventory_table()
Expand Down Expand Up @@ -231,8 +238,3 @@ def test_e2e(
assert len(backup_groups) == 0

toolkit.cleanup_inventory_table()


def test_fixtures():
# fake test to verify the fixtures
assert 1 == 1 # noqa: PLR0133

0 comments on commit 7a62801

Please sign in to comment.