diff --git a/backend/dataall/api/Objects/DatasetProfiling/mutations.py b/backend/dataall/api/Objects/DatasetProfiling/mutations.py index 5876c81a7..559129dc8 100644 --- a/backend/dataall/api/Objects/DatasetProfiling/mutations.py +++ b/backend/dataall/api/Objects/DatasetProfiling/mutations.py @@ -7,13 +7,3 @@ type=gql.Ref('DatasetProfilingRun'), resolver=start_profiling_run, ) - -updateDatasetProfilingRunResults = gql.MutationField( - name='updateDatasetProfilingRunResults', - args=[ - gql.Argument(name='profilingRunUri', type=gql.NonNullableType(gql.String)), - gql.Argument(name='results', type=gql.NonNullableType(gql.String)), - ], - type=gql.Ref('DatasetProfilingRun'), - resolver=update_profiling_run_results, -) diff --git a/backend/dataall/api/Objects/DatasetProfiling/queries.py b/backend/dataall/api/Objects/DatasetProfiling/queries.py index 9ab3eb2bb..1cbe06764 100644 --- a/backend/dataall/api/Objects/DatasetProfiling/queries.py +++ b/backend/dataall/api/Objects/DatasetProfiling/queries.py @@ -2,24 +2,6 @@ from .resolvers import * -getDatasetProfilingRun = gql.QueryField( - name='getDatasetProfilingRun', - args=[gql.Argument(name='profilingRunUri', type=gql.NonNullableType(gql.String))], - type=gql.Ref('DatasetProfilingRun'), - resolver=get_profiling_run, -) - - -listDatasetProfilingRuns = gql.QueryField( - name='listDatasetProfilingRuns', - args=[ - gql.Argument(name='datasetUri', type=gql.NonNullableType(gql.String)), - gql.Argument(name='filter', type=gql.Ref('DatasetProfilingRunFilter')), - ], - type=gql.Ref('DatasetProfilingRunSearchResults'), - resolver=list_profiling_runs, -) - listDatasetTableProfilingRuns = gql.QueryField( name='listDatasetTableProfilingRuns', args=[gql.Argument(name='tableUri', type=gql.NonNullableType(gql.String))], @@ -31,5 +13,5 @@ name='getDatasetTableProfilingRun', args=[gql.Argument(name='tableUri', type=gql.NonNullableType(gql.String))], type=gql.Ref('DatasetProfilingRun'), - resolver=get_last_table_profiling_run, + resolver=get_dataset_table_profiling_run, ) diff --git a/backend/dataall/api/Objects/DatasetProfiling/resolvers.py b/backend/dataall/api/Objects/DatasetProfiling/resolvers.py index 678a8cba6..11c19b888 100644 --- a/backend/dataall/api/Objects/DatasetProfiling/resolvers.py +++ b/backend/dataall/api/Objects/DatasetProfiling/resolvers.py @@ -1,6 +1,7 @@ import json import logging +from .... import db from ....api.context import Context from ....aws.handlers.service_handlers import Worker from ....aws.handlers.sts import SessionHelper @@ -19,7 +20,30 @@ def resolve_dataset(context, source: models.DatasetProfilingRun): ) +def resolve_profiling_run_status(context: Context, source: models.DatasetProfilingRun): + if not source: + return None + with context.engine.scoped_session() as session: + task = models.Task( + targetUri=source.profilingRunUri, action='glue.job.profiling_run_status' + ) + session.add(task) + Worker.queue(engine=context.engine, task_ids=[task.taskUri]) + return source.status + + +def resolve_profiling_results(context: Context, source: models.DatasetProfilingRun): + if not source or source.results == {}: + return None + else: + return json.dumps(source.results) + + def start_profiling_run(context: Context, source, input: dict = None): + """ + Triggers profiling jobs on a Table. + Only Dataset owners with PROFILE_DATASET_TABLE can perform this action + """ with context.engine.scoped_session() as session: ResourcePolicy.check_user_resource_permission( @@ -48,47 +72,14 @@ def start_profiling_run(context: Context, source, input: dict = None): return run -def get_profiling_run_status(context: Context, source: models.DatasetProfilingRun): - if not source: - return None - with context.engine.scoped_session() as session: - task = models.Task( - targetUri=source.profilingRunUri, action='glue.job.profiling_run_status' - ) - session.add(task) - Worker.queue(engine=context.engine, task_ids=[task.taskUri]) - return source.status - - -def get_profiling_results(context: Context, source: models.DatasetProfilingRun): - if not source or source.results == {}: - return None - else: - return json.dumps(source.results) - - -def update_profiling_run_results(context: Context, source, profilingRunUri, results): - with context.engine.scoped_session() as session: - run = api.DatasetProfilingRun.update_run( - session=session, profilingRunUri=profilingRunUri, results=results - ) - return run - - -def list_profiling_runs(context: Context, source, datasetUri=None): - with context.engine.scoped_session() as session: - return api.DatasetProfilingRun.list_profiling_runs(session, datasetUri) - - -def get_profiling_run(context: Context, source, profilingRunUri=None): - with context.engine.scoped_session() as session: - return api.DatasetProfilingRun.get_profiling_run( - session=session, profilingRunUri=profilingRunUri - ) - - -def get_last_table_profiling_run(context: Context, source, tableUri=None): +def get_dataset_table_profiling_run(context: Context, source, tableUri=None): + """ + Shows the results of the last profiling job on a Table. + For datasets "Unclassified" all users can perform this action. + For datasets "Secret" or "Official", only users with PREVIEW_DATASET_TABLE permissions can perform this action. + """ with context.engine.scoped_session() as session: + _check_preview_permissions_if_needed(context=context, session=session, tableUri=tableUri) run: models.DatasetProfilingRun = ( api.DatasetProfilingRun.get_table_last_profiling_run( session=session, tableUri=tableUri @@ -102,7 +93,7 @@ def get_last_table_profiling_run(context: Context, source, tableUri=None): environment = api.Environment.get_environment_by_uri( session, dataset.environmentUri ) - content = get_profiling_results_from_s3( + content = _get_profiling_results_from_s3( environment, dataset, table, run ) if content: @@ -121,7 +112,7 @@ def get_last_table_profiling_run(context: Context, source, tableUri=None): return run -def get_profiling_results_from_s3(environment, dataset, table, run): +def _get_profiling_results_from_s3(environment, dataset, table, run): s3 = SessionHelper.remote_session(environment.AwsAccountId).client( 's3', region_name=environment.region ) @@ -141,7 +132,32 @@ def get_profiling_results_from_s3(environment, dataset, table, run): def list_table_profiling_runs(context: Context, source, tableUri=None): + """ + Lists the runs of a profiling job on a Table. + For datasets "Unclassified" all users can perform this action. + For datasets "Secret" or "Official", only users with PREVIEW_DATASET_TABLE permissions can perform this action. + """ with context.engine.scoped_session() as session: + _check_preview_permissions_if_needed(context=context, session=session, tableUri=tableUri) return api.DatasetProfilingRun.list_table_profiling_runs( session=session, tableUri=tableUri, filter={} ) + + +def _check_preview_permissions_if_needed(context, session, tableUri): + table: models.DatasetTable = db.api.DatasetTable.get_dataset_table_by_uri( + session, tableUri + ) + dataset = db.api.Dataset.get_dataset_by_uri(session, table.datasetUri) + if ( + dataset.confidentiality + != models.ConfidentialityClassification.Unclassified.value + ): + ResourcePolicy.check_user_resource_permission( + session=session, + username=context.username, + groups=context.groups, + resource_uri=table.tableUri, + permission_name=permissions.PREVIEW_DATASET_TABLE, + ) + return True diff --git a/backend/dataall/api/Objects/DatasetProfiling/schema.py b/backend/dataall/api/Objects/DatasetProfiling/schema.py index f6fe9c575..88edbc403 100644 --- a/backend/dataall/api/Objects/DatasetProfiling/schema.py +++ b/backend/dataall/api/Objects/DatasetProfiling/schema.py @@ -1,8 +1,8 @@ from ... import gql from .resolvers import ( resolve_dataset, - get_profiling_run_status, - get_profiling_results, + resolve_profiling_run_status, + resolve_profiling_results, ) DatasetProfilingRun = gql.ObjectType( @@ -16,11 +16,11 @@ gql.Field(name='GlueTriggerName', type=gql.String), gql.Field(name='GlueTableName', type=gql.String), gql.Field(name='AwsAccountId', type=gql.String), - gql.Field(name='results', type=gql.String, resolver=get_profiling_results), + gql.Field(name='results', type=gql.String, resolver=resolve_profiling_results), gql.Field(name='created', type=gql.String), gql.Field(name='updated', type=gql.String), gql.Field(name='owner', type=gql.String), - gql.Field('status', type=gql.String, resolver=get_profiling_run_status), + gql.Field('status', type=gql.String, resolver=resolve_profiling_run_status), gql.Field(name='dataset', type=gql.Ref('Dataset'), resolver=resolve_dataset), ], ) diff --git a/tests/api/conftest.py b/tests/api/conftest.py index f3666d850..ef18094e9 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -187,6 +187,7 @@ def factory( name: str, owner: str, group: str, + confidentiality: str = None ) -> models.Dataset: key = f'{org.organizationUri}-{env.environmentUri}-{name}-{group}' if cache.get(key): @@ -290,6 +291,7 @@ def factory( 'environmentUri': env.environmentUri, 'SamlAdminGroupName': group or random_group(), 'organizationUri': org.organizationUri, + 'confidentiality': confidentiality or dataall.api.constants.ConfidentialityClassification.Unclassified.value }, ) print('==>', response) @@ -566,6 +568,49 @@ def factory(dataset: models.Dataset, name, username) -> models.DatasetTable: yield factory +@pytest.fixture(scope='module', autouse=True) +def table_with_permission(client, patch_es): + cache = {} + + def factory( + dataset: models.Dataset, + name: str, + owner: str, + group: str, + ) -> models.DatasetTable: + key = f'{dataset.datasetUri}-{name}' + if cache.get(key): + print('found in cache ', cache[key]) + return cache.get(key) + response = client.query( + """ + mutation CreateDatasetTable( + $datasetUri: String + $input: NewDatasetTableInput + ) { + createDatasetTable(datasetUri: $datasetUri, input: $input) { + tableUri + name + } + } + """, + username=owner, + groups=[group], + datasetUri=dataset.datasetUri, + input={ + 'label': f'{name}', + 'name': name, + 'description': f'test table {name}', + 'tags': random_tags(), + 'region': dataset.region + }, + ) + print('==>', response) + return response.data.createDatasetTable + + yield factory + + @pytest.fixture(scope='module', autouse=True) def org(client): cache = {} diff --git a/tests/api/test_dataset_profiling.py b/tests/api/test_dataset_profiling.py index ece463008..bcab1deb5 100644 --- a/tests/api/test_dataset_profiling.py +++ b/tests/api/test_dataset_profiling.py @@ -15,34 +15,32 @@ def env1(env, org1, user, group, tenant): env1 = env(org1, 'dev', user.userName, group.name, '111111111111', 'eu-west-1') yield env1 +@pytest.fixture(scope='module', autouse=True) +def org2(org, user2, group2, tenant): + org2 = org('testorg2', user2.userName, group2.name) + yield org2 -@pytest.fixture(scope='module') -def dataset1(env1, org1, dataset, group, user) -> dataall.db.models.Dataset: - yield dataset( - org=org1, env=env1, name='dataset1', owner=user.userName, group=group.name - ) +@pytest.fixture(scope='module', autouse=True) +def env2(env, org2, user2, group2, tenant): + env2 = env(org2, 'dev2', user2.userName, group2.name, '2222222222', 'eu-west-1') + yield env2 -def test_add_tables(table, dataset1, db): - for i in range(0, 10): - table(dataset=dataset1, name=f'table{i+1}', username=dataset1.owner) - - with db.scoped_session() as session: - nb = session.query(dataall.db.models.DatasetTable).count() - assert nb == 10 +@pytest.fixture(scope='module') +def dataset1(env1, org1, dataset, group, user) -> dataall.db.models.Dataset: + dataset1 = dataset( + org=org1, env=env1, name='dataset1', owner=user.userName, group=group.name, + confidentiality=dataall.api.constants.ConfidentialityClassification.Secret.value + ) + yield dataset1 -def update_runs(db, runs): - with db.scoped_session() as session: - for run in runs: - run = session.query(dataall.db.models.DatasetProfilingRun).get( - run['profilingRunUri'] - ) - run.status = 'SUCCEEDED' - session.commit() +@pytest.fixture(scope='module') +def table1(dataset1, table_with_permission, group, user): + yield table_with_permission(dataset=dataset1, name="table1", owner=user.userName, group=group.name) -def test_start_profiling(org1, env1, dataset1, client, module_mocker, db, user, group): +def test_start_profiling_run_authorized(org1, env1, dataset1, table1, client, module_mocker, db, user, group): module_mocker.patch('requests.post', return_value=True) module_mocker.patch( 'dataall.aws.handlers.service_handlers.Worker.process', return_value=True @@ -60,7 +58,7 @@ def test_start_profiling(org1, env1, dataset1, client, module_mocker, db, user, } """, username=user.userName, - input={'datasetUri': dataset1.datasetUri, 'GlueTableName': 'table1'}, + input={'datasetUri': dataset1.datasetUri, 'GlueTableName': table1.name}, groups=[group.name], ) profiling = response.data.startDatasetProfilingRun @@ -73,73 +71,64 @@ def test_start_profiling(org1, env1, dataset1, client, module_mocker, db, user, session.commit() -def test_list_runs(client, dataset1, env1, group): - runs = list_profiling_runs(client, dataset1, group) - assert len(runs) == 1 - - -def list_profiling_runs(client, dataset1, group): +def test_start_profiling_run_unauthorized(org2, env2, dataset1, table1, client, module_mocker, db, user2, group2): + module_mocker.patch('requests.post', return_value=True) + module_mocker.patch( + 'dataall.aws.handlers.service_handlers.Worker.process', return_value=True + ) + dataset1.GlueProfilingJobName = ('profile-job',) + dataset1.GlueProfilingTriggerSchedule = ('cron(* 2 * * ? *)',) + dataset1.GlueProfilingTriggerName = ('profile-job',) response = client.query( """ - query listDatasetProfilingRuns($datasetUri:String!){ - listDatasetProfilingRuns(datasetUri:$datasetUri){ - count - nodes{ + mutation startDatasetProfilingRun($input:StartDatasetProfilingRunInput){ + startDatasetProfilingRun(input:$input) + { profilingRunUri } } - } """, - datasetUri=dataset1.datasetUri, - groups=[group.name], + username=user2.userName, + input={'datasetUri': dataset1.datasetUri, 'GlueTableName': table1.name}, + groups=[group2.name], ) - return response.data.listDatasetProfilingRuns['nodes'] + assert 'UnauthorizedOperation' in response.errors[0].message -def test_get_profiling_run(client, dataset1, env1, module_mocker, db, group): - runs = list_profiling_runs(client, dataset1, group) +def test_get_table_profiling_run_authorized( + client, dataset1, table1, module_mocker, db, user, group +): module_mocker.patch( - 'dataall.aws.handlers.service_handlers.Worker.queue', - return_value=update_runs(db, runs), + 'dataall.api.Objects.DatasetProfiling.resolvers._get_profiling_results_from_s3', + return_value='{"results": "yes"}', ) + response = client.query( """ - query getDatasetProfilingRun($profilingRunUri:String!){ - getDatasetProfilingRun(profilingRunUri:$profilingRunUri){ + query getDatasetTableProfilingRun($tableUri:String!){ + getDatasetTableProfilingRun(tableUri:$tableUri){ profilingRunUri status + GlueTableName } } """, - profilingRunUri=runs[0]['profilingRunUri'], + tableUri=table1.tableUri, groups=[group.name], + username=user.userName, ) - assert ( - response.data.getDatasetProfilingRun['profilingRunUri'] - == runs[0]['profilingRunUri'] - ) - assert response.data.getDatasetProfilingRun['status'] == 'SUCCEEDED' - + assert response.data.getDatasetTableProfilingRun['profilingRunUri'] + assert response.data.getDatasetTableProfilingRun['status'] == 'RUNNING' + assert response.data.getDatasetTableProfilingRun['GlueTableName'] == 'table1' -def test_get_table_profiling_run( - client, dataset1, env1, module_mocker, table, db, group +def test_get_table_profiling_run_unauthorized( + client, dataset1, module_mocker, table1, db, user2, group2 ): module_mocker.patch( - 'dataall.api.Objects.DatasetProfiling.resolvers.get_profiling_results_from_s3', + 'dataall.api.Objects.DatasetProfiling.resolvers._get_profiling_results_from_s3', return_value='{"results": "yes"}', ) - runs = list_profiling_runs(client, dataset1, group) - module_mocker.patch( - 'dataall.aws.handlers.service_handlers.Worker.queue', - return_value=update_runs(db, runs), - ) - table = table(dataset=dataset1, name='table1', username=dataset1.owner) - with db.scoped_session() as session: - table = ( - session.query(dataall.db.models.DatasetTable) - .filter(dataall.db.models.DatasetTable.GlueTableName == 'table1') - .first() - ) + response = client.query( """ query getDatasetTableProfilingRun($tableUri:String!){ @@ -150,37 +139,22 @@ def test_get_table_profiling_run( } } """, - tableUri=table.tableUri, - groups=[group.name], - ) - assert ( - response.data.getDatasetTableProfilingRun['profilingRunUri'] - == runs[0]['profilingRunUri'] + tableUri=table1.tableUri, + groups=[group2.name], + username=user2.userName, ) - assert response.data.getDatasetTableProfilingRun['status'] == 'SUCCEEDED' - assert response.data.getDatasetTableProfilingRun['GlueTableName'] == 'table1' + assert 'UnauthorizedOperation' in response.errors[0].message -def test_list_table_profiling_runs( - client, dataset1, env1, module_mocker, table, db, group +def test_list_table_profiling_runs_authorized( + client, dataset1, module_mocker, table1, db, user, group ): module_mocker.patch( - 'dataall.api.Objects.DatasetProfiling.resolvers.get_profiling_results_from_s3', + 'dataall.api.Objects.DatasetProfiling.resolvers._get_profiling_results_from_s3', return_value='{"results": "yes"}', ) module_mocker.patch('requests.post', return_value=True) - runs = list_profiling_runs(client, dataset1, group) - table1000 = table(dataset=dataset1, name='table1000', username=dataset1.owner) - with db.scoped_session() as session: - table = ( - session.query(dataall.db.models.DatasetTable) - .filter(dataall.db.models.DatasetTable.GlueTableName == 'table1') - .first() - ) - module_mocker.patch( - 'dataall.aws.handlers.service_handlers.Worker.queue', - return_value=update_runs(db, runs), - ) + response = client.query( """ query listDatasetTableProfilingRuns($tableUri:String!){ @@ -195,25 +169,29 @@ def test_list_table_profiling_runs( } } """, - tableUri=table.tableUri, + tableUri=table1.tableUri, groups=[group.name], + username=user.userName, ) + assert response.data.listDatasetTableProfilingRuns['count'] == 1 + assert response.data.listDatasetTableProfilingRuns['nodes'][0]['profilingRunUri'] assert ( - response.data.listDatasetTableProfilingRuns['nodes'][0]['profilingRunUri'] - == runs[0]['profilingRunUri'] - ) - assert ( - response.data.listDatasetTableProfilingRuns['nodes'][0]['status'] == 'SUCCEEDED' + response.data.listDatasetTableProfilingRuns['nodes'][0]['status'] == 'RUNNING' ) assert ( response.data.listDatasetTableProfilingRuns['nodes'][0]['GlueTableName'] == 'table1' ) +def test_list_table_profiling_runs_unauthorized( + client, dataset1, module_mocker, table1, db, user2, group2 +): module_mocker.patch( - 'dataall.aws.handlers.service_handlers.Worker.queue', - return_value=update_runs(db, runs), + 'dataall.api.Objects.DatasetProfiling.resolvers._get_profiling_results_from_s3', + return_value='{"results": "yes"}', ) + module_mocker.patch('requests.post', return_value=True) + response = client.query( """ query listDatasetTableProfilingRuns($tableUri:String!){ @@ -228,39 +206,8 @@ def test_list_table_profiling_runs( } } """, - tableUri=table1000.tableUri, - groups=[group.name], - ) - assert response.data.listDatasetTableProfilingRuns['count'] == 0 - - response = client.query( - """ - query getDatasetTableProfilingRun($tableUri:String!){ - getDatasetTableProfilingRun(tableUri:$tableUri){ - profilingRunUri - status - GlueTableName - } - } - """, - tableUri=table.tableUri, - groups=[group.name], - ) - assert ( - response.data.getDatasetTableProfilingRun['profilingRunUri'] - == runs[0]['profilingRunUri'] - ) - - response = client.query( - """ - query getDatasetTableProfilingRun($tableUri:String!){ - getDatasetTableProfilingRun(tableUri:$tableUri){ - profilingRunUri - status - GlueTableName - } - } - """, - tableUri=table1000.tableUri, + tableUri=table1.tableUri, + groups=[group2.name], + username=user2.userName, ) - assert not response.data.getDatasetTableProfilingRun + assert 'UnauthorizedOperation' in response.errors[0].message