Skip to content

Commit

Permalink
feat: Disabling profiling results from "secret" and "official" datase…
Browse files Browse the repository at this point in the history
…ts (#482)

### Feature or Bugfix
- Feature

### Detail
- For datasets that are classified as "Secret", data preview is
disabled. In the same way, data profiling results should also be denied.
- Added tests for profiling
- removed unused methods

![image](https://github.com/awslabs/aws-dataall/assets/71252798/cc860476-b8ad-429f-ab2b-76c6f09f0010)

### Relates
- #478 
By submitting this pull request, I confirm that my contribution is made
under the terms of the Apache 2.0 license.
  • Loading branch information
dlpzx authored Jun 2, 2023
1 parent 25bab60 commit 856a721
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 205 deletions.
10 changes: 0 additions & 10 deletions backend/dataall/api/Objects/DatasetProfiling/mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,3 @@
type=gql.Ref('DatasetProfilingRun'),
resolver=start_profiling_run,
)

updateDatasetProfilingRunResults = gql.MutationField(
name='updateDatasetProfilingRunResults',
args=[
gql.Argument(name='profilingRunUri', type=gql.NonNullableType(gql.String)),
gql.Argument(name='results', type=gql.NonNullableType(gql.String)),
],
type=gql.Ref('DatasetProfilingRun'),
resolver=update_profiling_run_results,
)
20 changes: 1 addition & 19 deletions backend/dataall/api/Objects/DatasetProfiling/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,6 @@
from .resolvers import *


getDatasetProfilingRun = gql.QueryField(
name='getDatasetProfilingRun',
args=[gql.Argument(name='profilingRunUri', type=gql.NonNullableType(gql.String))],
type=gql.Ref('DatasetProfilingRun'),
resolver=get_profiling_run,
)


listDatasetProfilingRuns = gql.QueryField(
name='listDatasetProfilingRuns',
args=[
gql.Argument(name='datasetUri', type=gql.NonNullableType(gql.String)),
gql.Argument(name='filter', type=gql.Ref('DatasetProfilingRunFilter')),
],
type=gql.Ref('DatasetProfilingRunSearchResults'),
resolver=list_profiling_runs,
)

listDatasetTableProfilingRuns = gql.QueryField(
name='listDatasetTableProfilingRuns',
args=[gql.Argument(name='tableUri', type=gql.NonNullableType(gql.String))],
Expand All @@ -31,5 +13,5 @@
name='getDatasetTableProfilingRun',
args=[gql.Argument(name='tableUri', type=gql.NonNullableType(gql.String))],
type=gql.Ref('DatasetProfilingRun'),
resolver=get_last_table_profiling_run,
resolver=get_dataset_table_profiling_run,
)
100 changes: 58 additions & 42 deletions backend/dataall/api/Objects/DatasetProfiling/resolvers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging

from .... import db
from ....api.context import Context
from ....aws.handlers.service_handlers import Worker
from ....aws.handlers.sts import SessionHelper
Expand All @@ -19,7 +20,30 @@ def resolve_dataset(context, source: models.DatasetProfilingRun):
)


def resolve_profiling_run_status(context: Context, source: models.DatasetProfilingRun):
    """Resolve the ``status`` field of a profiling run.

    Enqueues an asynchronous 'glue.job.profiling_run_status' worker task so
    the status gets refreshed in the background, then returns the status
    currently persisted on the run. Returns None when no run is given.
    """
    if not source:
        return None
    with context.engine.scoped_session() as session:
        refresh_task = models.Task(
            targetUri=source.profilingRunUri,
            action='glue.job.profiling_run_status',
        )
        session.add(refresh_task)
        Worker.queue(engine=context.engine, task_ids=[refresh_task.taskUri])
    return source.status


def resolve_profiling_results(context: Context, source: models.DatasetProfilingRun):
    """Resolve the ``results`` field of a profiling run as a JSON string.

    Returns None when there is no run or its stored results are an empty
    dict; otherwise serializes the results with ``json.dumps``.
    """
    if source and source.results != {}:
        return json.dumps(source.results)
    return None


def start_profiling_run(context: Context, source, input: dict = None):
"""
Triggers profiling jobs on a Table.
Only Dataset owners with PROFILE_DATASET_TABLE can perform this action
"""
with context.engine.scoped_session() as session:

ResourcePolicy.check_user_resource_permission(
Expand Down Expand Up @@ -48,47 +72,14 @@ def start_profiling_run(context: Context, source, input: dict = None):
return run


def get_profiling_run_status(context: Context, source: models.DatasetProfilingRun):
    """Return the stored status of a profiling run, scheduling an async refresh.

    A 'glue.job.profiling_run_status' task is queued for the worker before
    returning the persisted status; None when no run is supplied.
    """
    if not source:
        return None
    engine = context.engine
    with engine.scoped_session() as session:
        poll_task = models.Task(
            action='glue.job.profiling_run_status',
            targetUri=source.profilingRunUri,
        )
        session.add(poll_task)
        Worker.queue(engine=engine, task_ids=[poll_task.taskUri])
    return source.status


def get_profiling_results(context: Context, source: models.DatasetProfilingRun):
    """Serialize a profiling run's results to JSON; None if absent or empty."""
    has_results = bool(source) and source.results != {}
    return json.dumps(source.results) if has_results else None


def update_profiling_run_results(context: Context, source, profilingRunUri, results):
    """Persist profiling results onto the run identified by ``profilingRunUri``."""
    with context.engine.scoped_session() as session:
        return api.DatasetProfilingRun.update_run(
            session=session,
            profilingRunUri=profilingRunUri,
            results=results,
        )


def list_profiling_runs(context: Context, source, datasetUri=None):
    """List every profiling run recorded for the given dataset."""
    with context.engine.scoped_session() as session:
        runs = api.DatasetProfilingRun.list_profiling_runs(session, datasetUri)
        return runs


def get_profiling_run(context: Context, source, profilingRunUri=None):
    """Fetch a single profiling run by its URI."""
    with context.engine.scoped_session() as session:
        run = api.DatasetProfilingRun.get_profiling_run(
            session=session,
            profilingRunUri=profilingRunUri,
        )
        return run


def get_last_table_profiling_run(context: Context, source, tableUri=None):
def get_dataset_table_profiling_run(context: Context, source, tableUri=None):
"""
Shows the results of the last profiling job on a Table.
For datasets "Unclassified" all users can perform this action.
For datasets "Secret" or "Official", only users with PREVIEW_DATASET_TABLE permissions can perform this action.
"""
with context.engine.scoped_session() as session:
_check_preview_permissions_if_needed(context=context, session=session, tableUri=tableUri)
run: models.DatasetProfilingRun = (
api.DatasetProfilingRun.get_table_last_profiling_run(
session=session, tableUri=tableUri
Expand All @@ -102,7 +93,7 @@ def get_last_table_profiling_run(context: Context, source, tableUri=None):
environment = api.Environment.get_environment_by_uri(
session, dataset.environmentUri
)
content = get_profiling_results_from_s3(
content = _get_profiling_results_from_s3(
environment, dataset, table, run
)
if content:
Expand All @@ -121,7 +112,7 @@ def get_last_table_profiling_run(context: Context, source, tableUri=None):
return run


def get_profiling_results_from_s3(environment, dataset, table, run):
def _get_profiling_results_from_s3(environment, dataset, table, run):
s3 = SessionHelper.remote_session(environment.AwsAccountId).client(
's3', region_name=environment.region
)
Expand All @@ -141,7 +132,32 @@ def get_profiling_results_from_s3(environment, dataset, table, run):


def list_table_profiling_runs(context: Context, source, tableUri=None):
    """
    Return the profiling runs recorded for a table.

    Access mirrors data preview: tables in "Unclassified" datasets are open
    to all users, while "Secret"/"Official" datasets require the caller to
    hold the PREVIEW_DATASET_TABLE permission.
    """
    with context.engine.scoped_session() as session:
        # Enforce confidentiality-based access before exposing any results.
        _check_preview_permissions_if_needed(
            context=context, session=session, tableUri=tableUri
        )
        runs = api.DatasetProfilingRun.list_table_profiling_runs(
            session=session, tableUri=tableUri, filter={}
        )
        return runs


def _check_preview_permissions_if_needed(context, session, tableUri):
    """Enforce PREVIEW_DATASET_TABLE on tables of non-"Unclassified" datasets.

    Looks up the table and its owning dataset; when the dataset's
    confidentiality is anything other than Unclassified, requires the caller
    to hold the PREVIEW_DATASET_TABLE permission on the table (the policy
    check raises otherwise). Returns True when access is allowed.
    """
    table: models.DatasetTable = db.api.DatasetTable.get_dataset_table_by_uri(
        session, tableUri
    )
    dataset = db.api.Dataset.get_dataset_by_uri(session, table.datasetUri)
    unclassified = models.ConfidentialityClassification.Unclassified.value
    if dataset.confidentiality != unclassified:
        ResourcePolicy.check_user_resource_permission(
            session=session,
            username=context.username,
            groups=context.groups,
            resource_uri=table.tableUri,
            permission_name=permissions.PREVIEW_DATASET_TABLE,
        )
    return True
8 changes: 4 additions & 4 deletions backend/dataall/api/Objects/DatasetProfiling/schema.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from ... import gql
from .resolvers import (
resolve_dataset,
get_profiling_run_status,
get_profiling_results,
resolve_profiling_run_status,
resolve_profiling_results,
)

DatasetProfilingRun = gql.ObjectType(
Expand All @@ -16,11 +16,11 @@
gql.Field(name='GlueTriggerName', type=gql.String),
gql.Field(name='GlueTableName', type=gql.String),
gql.Field(name='AwsAccountId', type=gql.String),
gql.Field(name='results', type=gql.String, resolver=get_profiling_results),
gql.Field(name='results', type=gql.String, resolver=resolve_profiling_results),
gql.Field(name='created', type=gql.String),
gql.Field(name='updated', type=gql.String),
gql.Field(name='owner', type=gql.String),
gql.Field('status', type=gql.String, resolver=get_profiling_run_status),
gql.Field('status', type=gql.String, resolver=resolve_profiling_run_status),
gql.Field(name='dataset', type=gql.Ref('Dataset'), resolver=resolve_dataset),
],
)
Expand Down
45 changes: 45 additions & 0 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def factory(
name: str,
owner: str,
group: str,
confidentiality: str = None
) -> models.Dataset:
key = f'{org.organizationUri}-{env.environmentUri}-{name}-{group}'
if cache.get(key):
Expand Down Expand Up @@ -290,6 +291,7 @@ def factory(
'environmentUri': env.environmentUri,
'SamlAdminGroupName': group or random_group(),
'organizationUri': org.organizationUri,
'confidentiality': confidentiality or dataall.api.constants.ConfidentialityClassification.Unclassified.value
},
)
print('==>', response)
Expand Down Expand Up @@ -566,6 +568,49 @@ def factory(dataset: models.Dataset, name, username) -> models.DatasetTable:
yield factory


@pytest.fixture(scope='module', autouse=True)
def table_with_permission(client, patch_es):
    """Factory fixture that creates a dataset table through the GraphQL API.

    The returned ``factory(dataset, name, owner, group)`` issues the
    ``createDatasetTable`` mutation as *owner*/*group* and memoizes the
    created table per ``(datasetUri, name)`` so repeated calls within the
    module reuse the same table.
    """
    cache = {}

    def factory(
        dataset: models.Dataset,
        name: str,
        owner: str,
        group: str,
    ) -> models.DatasetTable:
        key = f'{dataset.datasetUri}-{name}'
        # BUG FIX: the cache was checked but never populated, so every call
        # created a fresh table; now the result is stored under `key`.
        if cache.get(key):
            print('found in cache ', cache[key])
            return cache.get(key)
        response = client.query(
            """
            mutation CreateDatasetTable(
                $datasetUri: String
                $input: NewDatasetTableInput
            ) {
                createDatasetTable(datasetUri: $datasetUri, input: $input) {
                    tableUri
                    name
                }
            }
            """,
            username=owner,
            groups=[group],
            datasetUri=dataset.datasetUri,
            input={
                'label': f'{name}',
                'name': name,
                'description': f'test table {name}',
                'tags': random_tags(),
                'region': dataset.region
            },
        )
        print('==>', response)
        cache[key] = response.data.createDatasetTable
        return cache[key]

    yield factory


@pytest.fixture(scope='module', autouse=True)
def org(client):
cache = {}
Expand Down
Loading

0 comments on commit 856a721

Please sign in to comment.