diff --git a/backend/dataall/core/environment/cdk/environment_stack.py b/backend/dataall/core/environment/cdk/environment_stack.py index 9198bbf89..f0161f14b 100644 --- a/backend/dataall/core/environment/cdk/environment_stack.py +++ b/backend/dataall/core/environment/cdk/environment_stack.py @@ -621,3 +621,34 @@ def create_integration_tests_role(self): resources=['*'], ) ) + + self.test_role.add_to_policy( + iam.PolicyStatement( + actions=['ec2:Describe*', 'ec2:*Vpc*', 'ec2:*Subnet*', 'ec2:*Route*', 'ec2:*Tags*'], + effect=iam.Effect.ALLOW, + resources=[ + 'arn:aws:ec2:*:139956106467:route-table/*', + 'arn:aws:ec2:*:139956106467:security-group/*', + 'arn:aws:ec2:*:139956106467:vpc/*', + 'arn:aws:ec2:*:139956106467:security-group-rule/*', + 'arn:aws:ec2:*:139956106467:subnet/*', + 'arn:aws:ec2:*:139956106467:network-acl/*', + ], + ), + ) + + self.test_role.add_to_policy( + iam.PolicyStatement( + actions=['ec2:Describe*'], + effect=iam.Effect.ALLOW, + resources=['*'], + ), + ) + + self.test_role.add_to_policy( + iam.PolicyStatement( + actions=['cloudformation:Describe*'], + effect=iam.Effect.ALLOW, + resources=['arn:aws:cloudformation:*:139956106467:stack/*/*'], + ), + ) diff --git a/backend/dataall/modules/notebooks/db/notebook_repository.py b/backend/dataall/modules/notebooks/db/notebook_repository.py index c7ad805c9..f5219fdde 100644 --- a/backend/dataall/modules/notebooks/db/notebook_repository.py +++ b/backend/dataall/modules/notebooks/db/notebook_repository.py @@ -46,10 +46,12 @@ def _query_user_notebooks(self, username, groups, filter) -> Query: ) ) if filter and filter.get('term'): + term = filter['term'] query = query.filter( or_( - SagemakerNotebook.description.ilike(filter.get('term') + '%%'), - SagemakerNotebook.label.ilike(filter.get('term') + '%%'), + SagemakerNotebook.description.ilike(term + '%%'), + SagemakerNotebook.label.ilike(term + '%%'), + SagemakerNotebook.tags.contains(f'{{{term}}}'), ) ) return query.order_by(SagemakerNotebook.label) diff 
--git a/deploy/stacks/param_store_stack.py b/deploy/stacks/param_store_stack.py index b7008f7da..5b351d120 100644 --- a/deploy/stacks/param_store_stack.py +++ b/deploy/stacks/param_store_stack.py @@ -125,7 +125,7 @@ def __init__( f'toolingAccountParam{envname}', parameter_name=f'/dataall/{envname}/toolingAccount', string_value=str(tooling_account_id), - description=f'Stores AWS account if for the tooling account that hosts the code for environment {envname}', + description=f'Stores AWS account id for the tooling account that hosts the code for environment {envname}', ) if prod_sizing: diff --git a/tests_new/integration_tests/client.py b/tests_new/integration_tests/client.py index 9006fc264..f92f6dbd5 100644 --- a/tests_new/integration_tests/client.py +++ b/tests_new/integration_tests/client.py @@ -2,7 +2,7 @@ import boto3 import os from munch import DefaultMunch - +from retrying import retry from integration_tests.errors import GqlError ENVNAME = os.getenv('ENVNAME', 'dev') @@ -14,6 +14,17 @@ def __init__(self, username, password): self.password = password self.token = self._get_jwt_token() + @staticmethod + def _retry_if_connection_error(exception): + """Return True if we should retry, False otherwise""" + return isinstance(exception, requests.exceptions.ConnectionError) or isinstance(exception, requests.ReadTimeout) + + @retry( + retry_on_exception=_retry_if_connection_error, + stop_max_attempt_number=3, + wait_random_min=1000, + wait_random_max=3000, + ) def query(self, query: str): graphql_endpoint = os.path.join(os.environ['API_ENDPOINT'], 'graphql', 'api') headers = {'AccessKeyId': 'none', 'SecretKey': 'none', 'authorization': self.token} diff --git a/tests_new/integration_tests/core/environment/global_conftest.py b/tests_new/integration_tests/core/environment/global_conftest.py index 05876ce81..b700afbeb 100644 --- a/tests_new/integration_tests/core/environment/global_conftest.py +++ b/tests_new/integration_tests/core/environment/global_conftest.py @@ -51,19 
+51,13 @@ def session_env1(client1, group1, org1, session_id, testdata): delete_env(client1, env) -@pytest.fixture(scope='session') -def session_env1_integration_role_arn(session_env1): - yield f'arn:aws:iam::{session_env1.AwsAccountId}:role/dataall-integration-tests-role-{session_env1.region}' - - -@pytest.fixture(scope='session') -def session_env1_aws_client(session_env1, session_id, session_env1_integration_role_arn): +def get_environment_aws_session(role_arn, env): try: base_session = boto3.Session() - response = base_session.client('sts', region_name=session_env1.region).assume_role( - RoleArn=session_env1_integration_role_arn, RoleSessionName=session_env1_integration_role_arn.split('/')[1] + response = base_session.client('sts', region_name=env.region).assume_role( + RoleArn=role_arn, RoleSessionName=role_arn.split('/')[1] ) - yield boto3.Session( + return boto3.Session( aws_access_key_id=response['Credentials']['AccessKeyId'], aws_secret_access_key=response['Credentials']['SecretAccessKey'], aws_session_token=response['Credentials']['SessionToken'], @@ -73,29 +67,24 @@ def session_env1_aws_client(session_env1, session_id, session_env1_integration_r raise +@pytest.fixture(scope='session') +def session_env1_integration_role_arn(session_env1): + return f'arn:aws:iam::{session_env1.AwsAccountId}:role/dataall-integration-tests-role-{session_env1.region}' + + +@pytest.fixture(scope='session') +def session_env1_aws_client(session_env1, session_env1_integration_role_arn): + return get_environment_aws_session(session_env1_integration_role_arn, session_env1) + + @pytest.fixture(scope='session') def persistent_env1_integration_role_arn(persistent_env1): - yield f'arn:aws:iam::{persistent_env1.AwsAccountId}:role/dataall-integration-tests-role-{persistent_env1.region}' + return f'arn:aws:iam::{persistent_env1.AwsAccountId}:role/dataall-integration-tests-role-{persistent_env1.region}' @pytest.fixture(scope='session') -def persistent_env1_aws_client(persistent_env1, 
session_id): - try: - base_session = boto3.Session() - role_arn = ( - f'arn:aws:iam::{persistent_env1.AwsAccountId}:role/dataall-integration-tests-role-{persistent_env1.region}' - ) - response = base_session.client('sts', region_name=persistent_env1.region).assume_role( - RoleArn=role_arn, RoleSessionName=role_arn.split('/')[1] - ) - yield boto3.Session( - aws_access_key_id=response['Credentials']['AccessKeyId'], - aws_secret_access_key=response['Credentials']['SecretAccessKey'], - aws_session_token=response['Credentials']['SessionToken'], - ) - except: - log.exception('Failed to assume environment integration test role') - raise +def persistent_env1_aws_client(persistent_env1, persistent_env1_integration_role_arn): + return get_environment_aws_session(persistent_env1_integration_role_arn, persistent_env1) @pytest.fixture(scope='session') diff --git a/tests_new/integration_tests/core/environment/queries.py b/tests_new/integration_tests/core/environment/queries.py index 7682a6e78..a965d1588 100644 --- a/tests_new/integration_tests/core/environment/queries.py +++ b/tests_new/integration_tests/core/environment/queries.py @@ -61,6 +61,9 @@ def create_environment(client, name, group, organizationUri, awsAccountId, regio 'region': region, 'description': 'Created for integration testing', 'tags': tags, + 'parameters': [ + {'key': 'notebooksEnabled', 'value': 'true'}, + ], 'type': 'IntegrationTesting', } }, diff --git a/tests_new/integration_tests/core/stack/utils.py b/tests_new/integration_tests/core/stack/utils.py index 3608244ae..f04e9aecc 100644 --- a/tests_new/integration_tests/core/stack/utils.py +++ b/tests_new/integration_tests/core/stack/utils.py @@ -16,3 +16,15 @@ def check_stack_in_progress(client, env_uri, stack_uri, target_uri=None, target_ @poller(check_success=lambda stack: not is_stack_in_progress(stack), timeout=600) def check_stack_ready(client, env_uri, stack_uri, target_uri=None, target_type='environment'): return get_stack(client, env_uri, stack_uri, 
target_uri or env_uri, target_type) + + +def wait_stack_delete_complete(cf_client, stack_name): + # Wait for the stack to be deleted + waiter = cf_client.get_waiter('stack_delete_complete') + waiter.wait( + StackName=stack_name, + WaiterConfig={ + 'Delay': 20, # Delay between each poll request (in seconds) + 'MaxAttempts': 60, # Maximum number of attempts before giving up + }, + ) diff --git a/tests_new/integration_tests/modules/notebooks/__init__.py b/tests_new/integration_tests/modules/notebooks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_new/integration_tests/modules/notebooks/aws_clients.py b/tests_new/integration_tests/modules/notebooks/aws_clients.py new file mode 100644 index 000000000..c1c429758 --- /dev/null +++ b/tests_new/integration_tests/modules/notebooks/aws_clients.py @@ -0,0 +1,97 @@ +import logging +import json +from typing import Any, Dict, Optional +from botocore.exceptions import ClientError + +log = logging.getLogger(__name__) + + +class VpcClient: + def __init__(self, session, region): + self._client = session.client('ec2', region_name=region) + self._region = region + + def create_vpc(self, vpc_name: str, cidr: str = '10.0.0.0/28') -> str: + log.info('Creating VPC..') + response = self._client.create_vpc( + CidrBlock=cidr, + TagSpecifications=[ + { + 'ResourceType': 'vpc', + 'Tags': [ + {'Key': 'Name', 'Value': vpc_name}, + ], + }, + ], + ) + vpc_id = response['Vpc']['VpcId'] + log.info(f'VPC created with ID: {vpc_id=}') + return vpc_id + + def delete_vpc(self, vpc_id: str) -> Dict[str, Any]: + log.info('Deleting VPC..') + response = self._client.delete_vpc(VpcId=vpc_id) + log.info(f'VPC deleted with ID: {vpc_id=}') + return response + + def get_vpc_id_by_name(self, vpc_name: str) -> Optional[str]: + log.info('Getting VPC ID by name..') + response = self._client.describe_vpcs(Filters=[{'Name': 'tag:Name', 'Values': [vpc_name]}]) + if len(response['Vpcs']) == 0: + log.info(f'VPC with name {vpc_name} not 
found') + return None + vpc_id = response['Vpcs'][0]['VpcId'] + log.info(f'VPC ID found: {vpc_id=}') + return vpc_id + + def delete_vpc_by_name(self, vpc_name: str): + try: + vpc_id = self.get_vpc_id_by_name(vpc_name) + if vpc_id: + self.delete_vpc(vpc_id) + return True + except Exception as e: + log.error(f'Error deleting vpc {vpc_name=}. Error Message: {e}') + + def create_subnet(self, vpc_id: str, subnet_name: str, cidr: str) -> str: + log.info('Creating subnet..') + response = self._client.create_subnet( + VpcId=vpc_id, + CidrBlock=cidr, + TagSpecifications=[ + { + 'ResourceType': 'subnet', + 'Tags': [ + {'Key': 'Name', 'Value': subnet_name}, + ], + }, + ], + ) + subnet_id = response['Subnet']['SubnetId'] + log.info(f'Subnet created with ID: {subnet_id=}') + return subnet_id + + def get_subnet_id_by_name(self, subnet_name: str) -> Optional[str]: + log.info('Getting subnet ID by name..') + response = self._client.describe_subnets(Filters=[{'Name': 'tag:Name', 'Values': [subnet_name]}]) + if len(response['Subnets']) == 0: + log.info(f'Subnet with name {subnet_name} not found') + return None + subnet_id = response['Subnets'][0]['SubnetId'] + log.info(f'Subnet ID found: {subnet_id=}') + return subnet_id + + def delete_subnet(self, subnet_id: str) -> Dict[str, Any]: + log.info('Deleting subnet..') + response = self._client.delete_subnet(SubnetId=subnet_id) + log.info(f'Subnet deleted with ID: {subnet_id=}') + return response + + def delete_subnet_by_name(self, subnet_name: str): + try: + subnet_id = self.get_subnet_id_by_name(subnet_name) + if subnet_id: + self.delete_subnet(subnet_id) + return True + except Exception as e: + log.error(f'Error deleting subnet {subnet_name=}. 
Error Message: {e}') diff --git a/tests_new/integration_tests/modules/notebooks/conftest.py b/tests_new/integration_tests/modules/notebooks/conftest.py new file mode 100644 index 000000000..975f621f0 --- /dev/null +++ b/tests_new/integration_tests/modules/notebooks/conftest.py @@ -0,0 +1,167 @@ +import logging + +import pytest + +from integration_tests.client import GqlError +from integration_tests.modules.notebooks.queries import ( + create_sagemaker_notebook, + get_sagemaker_notebook, + delete_sagemaker_notebook, + list_sagemaker_notebooks, +) +from integration_tests.core.stack.utils import check_stack_ready, wait_stack_delete_complete + +from integration_tests.modules.notebooks.aws_clients import VpcClient + +log = logging.getLogger(__name__) + + +def create_notebook(client, group, env_uri, vpc_id, subnet_id, tags=[], name='TestNotebook'): + notebook = create_sagemaker_notebook( + client=client, + name=name, + group=group, + environmentUri=env_uri, + VpcId=vpc_id, + SubnetId=subnet_id, + tags=tags, + ) + check_stack_ready( + client=client, + env_uri=env_uri, + stack_uri=notebook.stack.stackUri, + target_uri=notebook.notebookUri, + target_type='notebook', + ) + return get_sagemaker_notebook(client, notebook.notebookUri) + + +def delete_notebook(client, env_uri, notebook): + check_stack_ready( + client=client, + env_uri=env_uri, + stack_uri=notebook.stack.stackUri, + target_uri=notebook.notebookUri, + target_type='notebook', + ) + try: + return delete_sagemaker_notebook(client, notebook.notebookUri) + except GqlError: + log.exception('unexpected error when deleting notebook') + return False + + +""" +Session envs persist across the duration of the whole integ test suite and are meant to make the test suite run faster (env creation takes ~2 mins). +For this reason they must stay immutable as changes to them will affect the rest of the tests. 
+""" + + +@pytest.fixture(scope='session') +def session_notebook1(client1, group1, session_env1, session_id, session_env1_aws_client): + resource_name = 'sessionnotebook1' + notebook = None + try: + vpc_client = VpcClient(session=session_env1_aws_client, region=session_env1['region']) + vpc_id = vpc_client.create_vpc(vpc_name=resource_name, cidr='172.31.0.0/26') + subnet_id = vpc_client.create_subnet(vpc_id=vpc_id, subnet_name=resource_name, cidr='172.31.0.0/28') + notebook = create_notebook( + client1, + group=group1, + env_uri=session_env1['environmentUri'], + tags=[session_id], + vpc_id=vpc_id, + subnet_id=subnet_id, + ) + yield notebook + finally: + if notebook: + delete_notebook(client1, session_env1['environmentUri'], notebook) + wait_stack_delete_complete( + session_env1_aws_client.client('cloudformation', region_name=session_env1['region']), + notebook.stack.name, + ) + + vpc_client = VpcClient(session=session_env1_aws_client, region=session_env1['region']) + vpc_client.delete_subnet_by_name(resource_name) + vpc_client.delete_vpc_by_name(resource_name) + + +""" +Temp envs will be created and deleted per test, use with caution as they might increase the runtime of the test suite. +They are suitable to test env mutations. 
+""" + + +@pytest.fixture(scope='function') +def temp_notebook1(client1, group1, session_env1, session_id, session_env1_aws_client): + resource_name = 'tempnotebook1' + notebook = None + try: + vpc_client = VpcClient(session=session_env1_aws_client, region=session_env1['region']) + vpc_id = vpc_client.create_vpc(vpc_name=resource_name, cidr='172.31.0.0/26') + subnet_id = vpc_client.create_subnet(vpc_id=vpc_id, subnet_name=resource_name, cidr='172.31.0.0/28') + + notebook = create_notebook( + client1, + group=group1, + env_uri=session_env1['environmentUri'], + tags=[session_id], + vpc_id=vpc_id, + subnet_id=subnet_id, + ) + yield notebook + finally: + if notebook: + delete_notebook(client1, session_env1['environmentUri'], notebook) + vpc_client = VpcClient(session=session_env1_aws_client, region=session_env1['region']) + vpc_client.delete_subnet_by_name(resource_name) + vpc_client.delete_vpc_by_name(resource_name) + + +""" +Persistent environments must always be present (if not i.e first run they will be created but won't be removed). +They are suitable for testing backwards compatibility. 
+""" + + +def get_or_create_persistent_notebook(resource_name, client, group, env, session): + notebooks = list_sagemaker_notebooks(client, term=resource_name).nodes + if notebooks: + return notebooks[0] + else: + vpc_client = VpcClient(session=session, region=env['region']) + + vpc_id = ( + vpc_client.get_vpc_id_by_name(resource_name) + if vpc_client.get_vpc_id_by_name(resource_name) + else vpc_client.create_vpc(vpc_name=resource_name, cidr='172.31.1.0/26') + ) + + subnet_id = ( + vpc_client.get_subnet_id_by_name(resource_name) + if vpc_client.get_subnet_id_by_name(resource_name) + else vpc_client.create_subnet(vpc_id=vpc_id, subnet_name=resource_name, cidr='172.31.1.0/28') + ) + + notebook = create_notebook( + client, + group=group, + env_uri=env['environmentUri'], + tags=[resource_name], + vpc_id=vpc_id, + subnet_id=subnet_id, + name=resource_name, + ) + if notebook.stack.status in ['CREATE_COMPLETE', 'UPDATE_COMPLETE']: + return notebook + else: + delete_notebook(client, env['environmentUri'], notebook) + raise RuntimeError(f'failed to create {resource_name=} {notebook=}') + + +@pytest.fixture(scope='session') +def persistent_notebook1(client1, group1, persistent_env1, persistent_env1_aws_client): + return get_or_create_persistent_notebook( + 'persistent_notebook1', client1, group1, persistent_env1, persistent_env1_aws_client + ) diff --git a/tests_new/integration_tests/modules/notebooks/queries.py b/tests_new/integration_tests/modules/notebooks/queries.py new file mode 100644 index 000000000..e8ee95a68 --- /dev/null +++ b/tests_new/integration_tests/modules/notebooks/queries.py @@ -0,0 +1,179 @@ +NOTEBOOK_TYPE = """ +notebookUri +name +owner +description +label +created +tags +NotebookInstanceStatus +SamlAdminGroupName +RoleArn +VpcId +SubnetId +VolumeSizeInGB +InstanceType +environment { + label + name + environmentUri + AwsAccountId + region +} +organization { + label + name + organizationUri +} +stack { + stack + name + status + stackUri + targetUri + 
accountid + region + stackid + link + outputs + resources +} +""" + + +def create_sagemaker_notebook( + client, + name, + group, + environmentUri, + tags, + VpcId, + SubnetId, + VolumeSizeInGB=32, + InstanceType='ml.t3.medium', +): + query = { + 'operationName': 'CreateSagemakerNotebook', + 'variables': { + 'input': { + 'label': name, + 'SamlAdminGroupName': group, + 'environmentUri': environmentUri, + 'VpcId': VpcId, + 'SubnetId': SubnetId, + 'description': 'Created for integration testing', + 'tags': tags, + 'VolumeSizeInGB': VolumeSizeInGB, + 'InstanceType': InstanceType, + } + }, + 'query': f""" + mutation CreateSagemakerNotebook($input: NewSagemakerNotebookInput) {{ + createSagemakerNotebook(input: $input) {{ + {NOTEBOOK_TYPE} + }} + }} + """, + } + + response = client.query(query=query) + return response.data.createSagemakerNotebook + + +def get_sagemaker_notebook(client, notebookUri): + query = { + 'operationName': 'getSagemakerNotebook', + 'variables': {'notebookUri': notebookUri}, + 'query': f""" + query getSagemakerNotebook($notebookUri: String!) {{ + getSagemakerNotebook(notebookUri: $notebookUri) {{ + {NOTEBOOK_TYPE} + }} + }} + """, + } + response = client.query(query=query) + return response.data.getSagemakerNotebook + + +def delete_sagemaker_notebook(client, notebookUri, deleteFromAWS=True): + query = { + 'operationName': 'deleteSagemakerNotebook', + 'variables': { + 'notebookUri': notebookUri, + 'deleteFromAWS': deleteFromAWS, + }, + 'query': """ + mutation deleteSagemakerNotebook( + $notebookUri: String! 
+ $deleteFromAWS: Boolean + ) { + deleteSagemakerNotebook( + notebookUri: $notebookUri + deleteFromAWS: $deleteFromAWS + ) + } + """, + } + response = client.query(query=query) + return response + + +def list_sagemaker_notebooks(client, term=''): + query = { + 'operationName': 'ListSagemakerNotebooks', + 'variables': { + 'filter': {'term': term}, + }, + 'query': f""" + query ListSagemakerNotebooks($filter: SagemakerNotebookFilter) {{ + listSagemakerNotebooks(filter: $filter) {{ + count + page + pages + hasNext + hasPrevious + nodes {{ + {NOTEBOOK_TYPE} + }} + }} + }} + """, + } + + response = client.query(query=query) + return response.data.listSagemakerNotebooks + + +def stop_sagemaker_notebook(client, notebookUri): + query = { + 'operationName': 'StopSagemakerNotebook', + 'variables': { + 'notebookUri': notebookUri, + }, + 'query': f""" + mutation StopSagemakerNotebook($notebookUri: String!) {{ + stopSagemakerNotebook(notebookUri: $notebookUri) + }} + """, + } + + response = client.query(query=query) + return response.data.stopSagemakerNotebook + + +def start_sagemaker_notebook(client, notebookUri): + query = { + 'operationName': 'StartSagemakerNotebook', + 'variables': { + 'notebookUri': notebookUri, + }, + 'query': f""" + mutation StartSagemakerNotebook($notebookUri: String!) 
{{ + startSagemakerNotebook(notebookUri: $notebookUri) + }} + """, + } + + response = client.query(query=query) + return response.data.startSagemakerNotebook diff --git a/tests_new/integration_tests/modules/notebooks/test_notebooks.py b/tests_new/integration_tests/modules/notebooks/test_notebooks.py new file mode 100644 index 000000000..14a9b3c4c --- /dev/null +++ b/tests_new/integration_tests/modules/notebooks/test_notebooks.py @@ -0,0 +1,91 @@ +import logging +import time +from assertpy import assert_that + +import re +from integration_tests.utils import poller + +from integration_tests.core.stack.queries import update_stack +from integration_tests.core.stack.utils import check_stack_in_progress, check_stack_ready +from integration_tests.errors import GqlError +from tests_new.integration_tests.modules.notebooks.queries import ( + get_sagemaker_notebook, + list_sagemaker_notebooks, + start_sagemaker_notebook, + stop_sagemaker_notebook, +) + +log = logging.getLogger(__name__) + + +def is_notebook_ready(notebook): + return re.match(r'Stopping|Pending|Deleting|Updating', notebook.NotebookInstanceStatus, re.IGNORECASE) + + +@poller(check_success=lambda notebook: not is_notebook_ready(notebook), timeout=600) +def check_notebook_ready(client, notebook_uri): + return get_sagemaker_notebook(client, notebook_uri) + + +def test_create_notebook(session_notebook1): + assert_that(session_notebook1.stack.status).is_in('CREATE_COMPLETE', 'UPDATE_COMPLETE') + + +def test_list_notebooks_authorized(client1, session_notebook1, session_id): + assert_that(list_sagemaker_notebooks(client1, term=session_id).nodes).is_length(1) + + +def test_list_notebooks_unauthorized(client2, session_notebook1, session_id): + assert_that(list_sagemaker_notebooks(client2, term=session_id).nodes).is_length(0) + + +def test_stop_notebook_unauthorized(client1, client2, session_notebook1): + assert_that(stop_sagemaker_notebook).raises(GqlError).when_called_with( + client2, session_notebook1.notebookUri + 
).contains('UnauthorizedOperation', session_notebook1.notebookUri) + notebook = get_sagemaker_notebook(client1, session_notebook1.notebookUri) + assert_that(notebook.NotebookInstanceStatus).is_equal_to('InService') + + +def test_stop_notebook_authorized(client1, session_notebook1): + assert_that(stop_sagemaker_notebook(client1, notebookUri=session_notebook1.notebookUri)).is_equal_to('Stopping') + notebook = get_sagemaker_notebook(client1, session_notebook1.notebookUri) + assert_that(notebook.NotebookInstanceStatus).matches(r'Stopping|Stopped') + check_notebook_ready(client1, session_notebook1.notebookUri) + + +def test_start_notebook_unauthorized(client1, client2, session_notebook1): + assert_that(start_sagemaker_notebook).raises(GqlError).when_called_with( + client2, session_notebook1.notebookUri + ).contains('UnauthorizedOperation', session_notebook1.notebookUri) + notebook = get_sagemaker_notebook(client1, session_notebook1.notebookUri) + assert_that(notebook.NotebookInstanceStatus).is_equal_to('Stopped') + + +def test_start_notebook_authorized(client1, session_notebook1): + assert_that(start_sagemaker_notebook(client1, notebookUri=session_notebook1.notebookUri)).is_equal_to('Starting') + notebook = get_sagemaker_notebook(client1, session_notebook1.notebookUri) + assert_that(notebook.NotebookInstanceStatus).matches(r'Pending|InService') + check_notebook_ready(client1, session_notebook1.notebookUri) + + +def test_persistent_notebook_update(client1, persistent_notebook1): + # wait for stack to get to a final state before triggering an update + stack_uri = persistent_notebook1.stack.stackUri + env_uri = persistent_notebook1.environment.environmentUri + notebook_uri = persistent_notebook1.notebookUri + target_type = 'notebook' + check_stack_ready( + client=client1, env_uri=env_uri, stack_uri=stack_uri, target_uri=notebook_uri, target_type=target_type + ) + update_stack(client1, notebook_uri, target_type) + # wait for stack to move to "in_progress" state + # TODO: 
Come up with better way to handle wait in progress if applicable + # Use time.sleep() instead of poller b/c of case where no changes founds (i.e. no update required) + # check_stack_in_progress(client1, env_uri, stack_uri) + time.sleep(120) + + stack = check_stack_ready( + client1, env_uri=env_uri, stack_uri=stack_uri, target_uri=notebook_uri, target_type=target_type + ) + assert_that(stack.status).is_in('CREATE_COMPLETE', 'UPDATE_COMPLETE') diff --git a/tests_new/integration_tests/requirements.txt b/tests_new/integration_tests/requirements.txt index 69d909121..35da0738e 100644 --- a/tests_new/integration_tests/requirements.txt +++ b/tests_new/integration_tests/requirements.txt @@ -8,3 +8,4 @@ pytest-dependency==0.5.1 requests==2.32.2 dataclasses-json==0.6.6 werkzeug==3.0.3 +retrying==1.3.4