
Automated metadata generation using genAI MVP (#1598)
### Feature
- Feature

### Detail
- Automated metadata generation using generative AI (MVP phase)

### Related
#1599 

By submitting this pull request, I confirm that my contribution is made
under the terms of the Apache 2.0 license.

---------

Co-authored-by: dlpzx <dlpzx@amazon.com>
pelinKuran and dlpzx authored Oct 1, 2024
1 parent 4b67986 commit 11e082f
Showing 34 changed files with 1,850 additions and 13 deletions.
9 changes: 9 additions & 0 deletions backend/dataall/modules/s3_datasets/api/dataset/enums.py
@@ -0,0 +1,9 @@
from dataall.base.api.constants import GraphQLEnumMapper


class MetadataGenerationTargets(GraphQLEnumMapper):
"""Describes the s3_datasets metadata generation targets"""

Table = 'Table'
Folder = 'Folder'
S3_Dataset = 'S3_Dataset'
17 changes: 17 additions & 0 deletions backend/dataall/modules/s3_datasets/api/dataset/input_types.py
@@ -46,6 +46,15 @@
gql.Argument(name='expiryMaxDuration', type=gql.Integer),
],
)
DatasetMetadataInput = gql.InputType(
name='DatasetMetadataInput',
arguments=[
gql.Argument('label', gql.String),
gql.Argument('description', gql.String),
gql.Argument('tags', gql.ArrayType(gql.String)),
gql.Argument('topics', gql.ArrayType(gql.Ref('Topic'))),
],
)

DatasetPresignedUrlInput = gql.InputType(
name='DatasetPresignedUrlInput',
@@ -58,6 +67,14 @@

CrawlerInput = gql.InputType(name='CrawlerInput', arguments=[gql.Argument(name='prefix', type=gql.String)])

SampleDataInput = gql.InputType(
name='SampleDataInput',
arguments=[
gql.Field(name='fields', type=gql.ArrayType(gql.String)),
gql.Field(name='rows', type=gql.ArrayType(gql.String)),
],
)

ImportDatasetInput = gql.InputType(
name='ImportDatasetInput',
arguments=[
20 changes: 15 additions & 5 deletions backend/dataall/modules/s3_datasets/api/dataset/mutations.py
@@ -1,17 +1,15 @@
from dataall.base.api import gql
from dataall.modules.s3_datasets.api.dataset.input_types import (
ModifyDatasetInput,
NewDatasetInput,
ImportDatasetInput,
)
from dataall.modules.s3_datasets.api.dataset.input_types import ModifyDatasetInput, NewDatasetInput, ImportDatasetInput
from dataall.modules.s3_datasets.api.dataset.resolvers import (
create_dataset,
update_dataset,
generate_dataset_access_token,
delete_dataset,
import_dataset,
start_crawler,
generate_metadata,
)
from dataall.modules.s3_datasets.api.dataset.enums import MetadataGenerationTargets

createDataset = gql.MutationField(
name='createDataset',
@@ -68,3 +66,15 @@
resolver=start_crawler,
type=gql.Ref('GlueCrawler'),
)
generateMetadata = gql.MutationField(
name='generateMetadata',
args=[
gql.Argument(name='resourceUri', type=gql.NonNullableType(gql.String)),
gql.Argument(name='targetType', type=gql.NonNullableType(MetadataGenerationTargets.toGraphQLEnum())),
gql.Argument(name='version', type=gql.Integer),
gql.Argument(name='metadataTypes', type=gql.NonNullableType(gql.ArrayType(gql.String))),
gql.Argument(name='sampleData', type=gql.Ref('SampleDataInput')),
],
type=gql.Ref('GeneratedMetadata'),
resolver=generate_metadata,
)
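For context, a minimal sketch of how a client might invoke the new mutation over the GraphQL API. The endpoint URL, auth token, resource URI, and the metadata type values are illustrative assumptions (the MetadataGenerationTypes enum lives elsewhere in this PR), not values taken from this diff.

```python
# Hypothetical client-side call to the new generateMetadata mutation.
# Endpoint, token, resourceUri and metadataTypes values are illustrative assumptions.
import requests

GRAPHQL_URL = 'https://example.com/graphql/api'  # placeholder endpoint

MUTATION = """
mutation GenerateMetadata(
  $resourceUri: String!
  $targetType: MetadataGenerationTargets!
  $version: Int
  $metadataTypes: [String]!
  $sampleData: SampleDataInput
) {
  generateMetadata(
    resourceUri: $resourceUri
    targetType: $targetType
    version: $version
    metadataTypes: $metadataTypes
    sampleData: $sampleData
  ) {
    label
    description
    tags
    topics
  }
}
"""

response = requests.post(
    GRAPHQL_URL,
    json={
        'query': MUTATION,
        'variables': {
            'resourceUri': 'abcd1234',  # must match the 8-char lowercase alphanumeric URI pattern
            'targetType': 'Table',
            'version': 1,
            'metadataTypes': ['label', 'description', 'tags'],  # assumed MetadataGenerationTypes values
            'sampleData': {'fields': ['customer_id'], 'rows': ['1001']},
        },
    },
    headers={'Authorization': 'Bearer <token>'},  # placeholder auth header
    timeout=30,
)
print(response.json()['data']['generateMetadata'])
```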
17 changes: 17 additions & 0 deletions backend/dataall/modules/s3_datasets/api/dataset/queries.py
@@ -4,6 +4,8 @@
get_dataset_assume_role_url,
get_file_upload_presigned_url,
list_datasets_owned_by_env_group,
list_dataset_tables_folders,
read_sample_data,
)

getDataset = gql.QueryField(
@@ -45,3 +47,18 @@
resolver=list_datasets_owned_by_env_group,
test_scope='Dataset',
)
listDatasetTablesFolders = gql.QueryField(
name='listDatasetTablesFolders',
args=[
gql.Argument(name='datasetUri', type=gql.NonNullableType(gql.String)),
gql.Argument(name='filter', type=gql.Ref('DatasetFilter')),
],
type=gql.Ref('DatasetItemsSearchResult'),
resolver=list_dataset_tables_folders,
)
listSampleData = gql.QueryField(
name='listSampleData',
args=[gql.Argument(name='tableUri', type=gql.NonNullableType(gql.String))],
    type=gql.Ref('QueryPreviewResult'),  # returns the table preview (fields and rows)
    resolver=read_sample_data,
)  # the caller can feed this preview back into generateMetadata as sampleData; mirrors the existing table preview API
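The comments above hint at the intended round trip: fetch a sample of the table with listSampleData, then pass the returned fields and rows back to generateMetadata as sampleData. A short sketch, assuming the QueryPreviewResult type exposes `fields` and `rows` as in the existing table preview API; endpoint, token, and table URI are placeholders.

```python
# Hypothetical two-step flow: preview the table, then reuse the preview as sampleData.
import requests

GRAPHQL_URL = 'https://example.com/graphql/api'  # placeholder endpoint
HEADERS = {'Authorization': 'Bearer <token>'}  # placeholder auth header

preview_query = """
query ListSampleData($tableUri: String!) {
  listSampleData(tableUri: $tableUri) { fields rows }
}
"""
resp = requests.post(
    GRAPHQL_URL,
    json={'query': preview_query, 'variables': {'tableUri': 'abcd1234'}},  # placeholder table URI
    headers=HEADERS,
    timeout=30,
)
preview = resp.json()['data']['listSampleData']

# The preview can now be passed to generateMetadata as its sampleData argument
sample_data = {'fields': preview['fields'], 'rows': preview['rows']}
```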
70 changes: 69 additions & 1 deletion backend/dataall/modules/s3_datasets/api/dataset/resolvers.py
@@ -1,5 +1,5 @@
import logging

import re
from dataall.base.api.context import Context
from dataall.base.feature_toggle_checker import is_feature_enabled
from dataall.base.utils.expiration_util import Expiration
@@ -11,6 +11,9 @@
from dataall.modules.s3_datasets.db.dataset_models import S3Dataset
from dataall.modules.datasets_base.services.datasets_enums import DatasetRole, ConfidentialityClassification
from dataall.modules.s3_datasets.services.dataset_service import DatasetService
from dataall.modules.s3_datasets.services.dataset_table_service import DatasetTableService
from dataall.modules.s3_datasets.services.dataset_location_service import DatasetLocationService
from dataall.modules.s3_datasets.services.dataset_enums import MetadataGenerationTargets, MetadataGenerationTypes

log = logging.getLogger(__name__)

@@ -156,6 +159,59 @@ def list_datasets_owned_by_env_group(
return DatasetService.list_datasets_owned_by_env_group(environmentUri, groupUri, filter)


# @ResourceThresholdRepository.invocation_handler('generate_metadata_ai')
# To make this threshold work, threshold limits must be added to the resource paths dictionary in resource_threshold_repository.
# For example, the existing entry 'nlq': 'modules.worksheets.features.max_count_per_day' would need an equivalent max_count_per_day defined for metadata generation,
# or the decorator could reuse an existing key (or even the same key) after the merge.
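# A possible entry (the key name and config path below are illustrative assumptions, not part of this PR):
#     'generate_metadata_ai': 'modules.s3_datasets.features.generate_metadata_ai.max_count_per_day'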
@is_feature_enabled('modules.s3_datasets.features.generate_metadata_ai.active')
def generate_metadata(
context: Context,
source: S3Dataset,
resourceUri: str,
targetType: str,
version: int,
metadataTypes: list,
sampleData: dict = {},
):
RequestValidator.validate_uri(param_name='resourceUri', param_value=resourceUri)
    allowed_values = [item.value for item in MetadataGenerationTypes]
    if any(metadata_type not in allowed_values for metadata_type in metadataTypes):
        raise InvalidInput(
            'metadataTypes',
            metadataTypes,
            f'a list of allowed values {allowed_values}',
        )
# TODO validate sampleData and make it generic for S3
if targetType == MetadataGenerationTargets.S3_Dataset.value:
return DatasetService.generate_metadata_for_dataset(
resourceUri=resourceUri, version=version, metadataTypes=metadataTypes
)
elif targetType == MetadataGenerationTargets.Table.value:
return DatasetTableService.generate_metadata_for_table(
resourceUri=resourceUri, version=version, metadataTypes=metadataTypes, sampleData=sampleData
)
elif targetType == MetadataGenerationTargets.Folder.value:
return DatasetLocationService.generate_metadata_for_folder(
resourceUri=resourceUri, version=version, metadataTypes=metadataTypes
)
else:
raise Exception('Unsupported target type for metadata generation')


def read_sample_data(context: Context, source: S3Dataset, tableUri: str):
RequestValidator.validate_uri(param_name='tableUri', param_value=tableUri)
return DatasetTableService.preview(uri=tableUri)


def update_dataset_metadata(context: Context, source: S3Dataset, resourceUri: str, input: dict = None):
    return DatasetService.update_dataset(uri=resourceUri, data=input)


def list_dataset_tables_folders(context: Context, source: S3Dataset, datasetUri: str, filter: dict = None):
if not filter:
filter = {}
return DatasetService.list_dataset_tables_folders(dataset_uri=datasetUri, filter=filter)


class RequestValidator:
@staticmethod
def validate_creation_request(data):
@@ -200,6 +256,18 @@ def validate_share_expiration_request(data):
'is of invalid type',
)

@staticmethod
def validate_uri(param_name: str, param_value: str):
if not param_value:
raise RequiredParameter(param_name)
pattern = r'^[a-z0-9]{8}$'
if not re.match(pattern, param_value):
raise InvalidInput(
param_name=param_name,
param_value=param_value,
constraint='8 characters long and contain only lowercase letters and numbers',
)

@staticmethod
def validate_import_request(data):
RequestValidator.validate_creation_request(data)
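To make the new validate_uri constraint concrete, a small sketch of values that pass or fail the `^[a-z0-9]{8}$` pattern; the sample URIs are made up.

```python
import re

pattern = r'^[a-z0-9]{8}$'  # the pattern used by RequestValidator.validate_uri above

assert re.match(pattern, 'abcd1234')          # 8 lowercase alphanumerics: accepted
assert not re.match(pattern, 'ABCD1234')      # uppercase letters: rejected
assert not re.match(pattern, 'abc123')        # shorter than 8 characters: rejected
assert not re.match(pattern, 'abcd1234xyz')   # longer than 8 characters: rejected
```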
44 changes: 44 additions & 0 deletions backend/dataall/modules/s3_datasets/api/dataset/types.py
@@ -140,3 +140,47 @@
gql.Field(name='status', type=gql.String),
],
)
SubitemDescription = gql.ObjectType(
name='SubitemDescription',
fields=[
gql.Field(name='label', type=gql.String),
gql.Field(name='description', type=gql.String),
gql.Field(name='subitem_id', type=gql.String),
],
)
GeneratedMetadata = gql.ObjectType(
name='GeneratedMetadata',
fields=[
gql.Field(name='type', type=gql.String), # Table, Column, Folder, Dataset
gql.Field(name='label', type=gql.String),
gql.Field(name='topics', type=gql.ArrayType(gql.String)),
gql.Field(name='tags', type=gql.ArrayType(gql.String)),
gql.Field(name='description', type=gql.String),
gql.Field(name='name', type=gql.String),
gql.Field(name='subitem_descriptions', type=gql.ArrayType(gql.Ref('SubitemDescription'))),
],
)

DatasetItem = gql.ObjectType(
name='DatasetItem',
fields=[
gql.Field(name='name', type=gql.String),
gql.Field(name='targetType', type=gql.String),
gql.Field(name='targetUri', type=gql.String),
],
)

DatasetItemsSearchResult = gql.ObjectType(
name='DatasetItemsSearchResult',
fields=[
gql.Field(name='count', type=gql.Integer),
gql.Field(name='nodes', type=gql.ArrayType(DatasetItem)),
gql.Field(name='pageSize', type=gql.Integer),
gql.Field(name='nextPage', type=gql.Integer),
gql.Field(name='pages', type=gql.Integer),
gql.Field(name='page', type=gql.Integer),
gql.Field(name='previousPage', type=gql.Integer),
gql.Field(name='hasNext', type=gql.Boolean),
gql.Field(name='hasPrevious', type=gql.Boolean),
],
)
@@ -18,3 +18,11 @@
gql.Argument('topics', gql.Integer),
],
)
SubitemDescription = gql.InputType(
name='SubitemDescriptionInput',
arguments=[
gql.Argument(name='label', type=gql.String),
gql.Argument(name='description', type=gql.String),
gql.Argument(name='subitem_id', type=gql.String),
],
)
@@ -1,5 +1,9 @@
from dataall.base.api import gql
from dataall.modules.s3_datasets.api.table_column.resolvers import sync_table_columns, update_table_column
from dataall.modules.s3_datasets.api.table_column.resolvers import (
sync_table_columns,
update_table_column,
batch_update_table_columns_description,
)

syncDatasetTableColumns = gql.MutationField(
name='syncDatasetTableColumns',
@@ -18,3 +22,9 @@
type=gql.Ref('DatasetTableColumn'),
resolver=update_table_column,
)
batchUpdateDatasetTableColumn = gql.MutationField(
name='batchUpdateDatasetTableColumn',
args=[gql.Argument(name='columns', type=gql.ArrayType(gql.Ref('SubitemDescriptionInput')))],
type=gql.String,
resolver=batch_update_table_columns_description,
)
@@ -41,3 +41,9 @@ def update_table_column(context: Context, source, columnUri: str = None, input:

description = input.get('description', 'No description provided')
return DatasetColumnService.update_table_column_description(column_uri=columnUri, description=description)


def batch_update_table_columns_description(context: Context, source, columns):
if columns is None:
return None
return DatasetColumnService.batch_update_table_columns_description(columns=columns)
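Finally, a hedged sketch of the columns payload the new batchUpdateDatasetTableColumn mutation appears to expect, mirroring the SubitemDescriptionInput type added above; the subitem IDs, labels, and descriptions are placeholders.

```python
# Illustrative columns payload for batchUpdateDatasetTableColumn (values are placeholders).
columns = [
    {'subitem_id': 'col00001', 'label': 'customer_id', 'description': 'Unique identifier of the customer'},
    {'subitem_id': 'col00002', 'label': 'order_date', 'description': 'Date on which the order was placed'},
]
# The resolver forwards this list to
# DatasetColumnService.batch_update_table_columns_description(columns=columns).
```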