diff --git a/backend/dataall/modules/s3_datasets/api/dataset/enums.py b/backend/dataall/modules/s3_datasets/api/dataset/enums.py
new file mode 100644
index 000000000..16aa95907
--- /dev/null
+++ b/backend/dataall/modules/s3_datasets/api/dataset/enums.py
@@ -0,0 +1,9 @@
+from dataall.base.api.constants import GraphQLEnumMapper
+
+
+class MetadataGenerationTargets(GraphQLEnumMapper):
+    """Describes the s3_datasets metadata generation targets"""
+
+    Table = 'Table'
+    Folder = 'Folder'
+    S3_Dataset = 'S3_Dataset'
diff --git a/backend/dataall/modules/s3_datasets/api/dataset/input_types.py b/backend/dataall/modules/s3_datasets/api/dataset/input_types.py
index ced7ddf6a..ada6315b5 100644
--- a/backend/dataall/modules/s3_datasets/api/dataset/input_types.py
+++ b/backend/dataall/modules/s3_datasets/api/dataset/input_types.py
@@ -46,6 +46,15 @@
         gql.Argument(name='expiryMaxDuration', type=gql.Integer),
     ],
 )
+DatasetMetadataInput = gql.InputType(
+    name='DatasetMetadataInput',
+    arguments=[
+        gql.Argument('label', gql.String),
+        gql.Argument('description', gql.String),
+        gql.Argument('tags', gql.ArrayType(gql.String)),
+        gql.Argument('topics', gql.ArrayType(gql.Ref('Topic'))),
+    ],
+)
 
 DatasetPresignedUrlInput = gql.InputType(
     name='DatasetPresignedUrlInput',
@@ -58,6 +67,14 @@
 
 CrawlerInput = gql.InputType(name='CrawlerInput', arguments=[gql.Argument(name='prefix', type=gql.String)])
 
+SampleDataInput = gql.InputType(
+    name='SampleDataInput',
+    arguments=[
+        gql.Argument(name='fields', type=gql.ArrayType(gql.String)),
+        gql.Argument(name='rows', type=gql.ArrayType(gql.String)),
+    ],
+)
+
 ImportDatasetInput = gql.InputType(
     name='ImportDatasetInput',
     arguments=[
diff --git a/backend/dataall/modules/s3_datasets/api/dataset/mutations.py b/backend/dataall/modules/s3_datasets/api/dataset/mutations.py
index d82f98194..5df48bf4f 100644
--- a/backend/dataall/modules/s3_datasets/api/dataset/mutations.py
+++ b/backend/dataall/modules/s3_datasets/api/dataset/mutations.py
@@ -1,9 +1,5 @@
 from dataall.base.api import gql
-from dataall.modules.s3_datasets.api.dataset.input_types import (
-    ModifyDatasetInput,
-    NewDatasetInput,
-    ImportDatasetInput,
-)
+from dataall.modules.s3_datasets.api.dataset.input_types import ModifyDatasetInput, NewDatasetInput, ImportDatasetInput
 from dataall.modules.s3_datasets.api.dataset.resolvers import (
     create_dataset,
     update_dataset,
@@ -11,7 +7,9 @@
     delete_dataset,
     import_dataset,
     start_crawler,
+    generate_metadata,
 )
+from dataall.modules.s3_datasets.api.dataset.enums import MetadataGenerationTargets
 
 createDataset = gql.MutationField(
     name='createDataset',
@@ -68,3 +66,15 @@
     resolver=start_crawler,
     type=gql.Ref('GlueCrawler'),
 )
+generateMetadata = gql.MutationField(
+    name='generateMetadata',
+    args=[
+        gql.Argument(name='resourceUri', type=gql.NonNullableType(gql.String)),
+        gql.Argument(name='targetType', type=gql.NonNullableType(MetadataGenerationTargets.toGraphQLEnum())),
+        gql.Argument(name='version', type=gql.Integer),
+        gql.Argument(name='metadataTypes', type=gql.NonNullableType(gql.ArrayType(gql.String))),
+        gql.Argument(name='sampleData', type=gql.Ref('SampleDataInput')),
+    ],
+    type=gql.Ref('GeneratedMetadata'),
+    resolver=generate_metadata,
+)
diff --git a/backend/dataall/modules/s3_datasets/api/dataset/queries.py b/backend/dataall/modules/s3_datasets/api/dataset/queries.py
index 5043b868d..fdd9ed5aa 100644
--- a/backend/dataall/modules/s3_datasets/api/dataset/queries.py
+++ b/backend/dataall/modules/s3_datasets/api/dataset/queries.py
@@ -4,6 +4,8 @@
     get_dataset_assume_role_url,
     get_file_upload_presigned_url,
     list_datasets_owned_by_env_group,
+    list_dataset_tables_folders,
+    read_sample_data,
 )
 
 getDataset = gql.QueryField(
@@ -45,3 +47,18 @@
     resolver=list_datasets_owned_by_env_group,
     test_scope='Dataset',
 )
+listDatasetTablesFolders = gql.QueryField(
+    name='listDatasetTablesFolders',
+    args=[
+        gql.Argument(name='datasetUri', type=gql.NonNullableType(gql.String)),
+        gql.Argument(name='filter', type=gql.Ref('DatasetFilter')),
+    ],
+    type=gql.Ref('DatasetItemsSearchResult'),
+    resolver=list_dataset_tables_folders,
+)
+listSampleData = gql.QueryField(
+    name='listSampleData',
+    args=[gql.Argument(name='tableUri', type=gql.NonNullableType(gql.String))],
+    type=gql.Ref('QueryPreviewResult'),  # reuses the existing table preview result type
+    resolver=read_sample_data,
+)  # returns sample rows; the user reviews them and invokes generateMetadata again with the sample data (a similar preview API already exists)
diff --git a/backend/dataall/modules/s3_datasets/api/dataset/resolvers.py b/backend/dataall/modules/s3_datasets/api/dataset/resolvers.py
index 90f6fd3d9..d97eb8429 100644
--- a/backend/dataall/modules/s3_datasets/api/dataset/resolvers.py
+++ b/backend/dataall/modules/s3_datasets/api/dataset/resolvers.py
@@ -1,5 +1,5 @@
 import logging
-
+import re
 from dataall.base.api.context import Context
 from dataall.base.feature_toggle_checker import is_feature_enabled
 from dataall.base.utils.expiration_util import Expiration
@@ -11,6 +11,9 @@
 from dataall.modules.s3_datasets.db.dataset_models import S3Dataset
 from dataall.modules.datasets_base.services.datasets_enums import DatasetRole, ConfidentialityClassification
 from dataall.modules.s3_datasets.services.dataset_service import DatasetService
+from dataall.modules.s3_datasets.services.dataset_table_service import DatasetTableService
+from dataall.modules.s3_datasets.services.dataset_location_service import DatasetLocationService
+from dataall.modules.s3_datasets.services.dataset_enums import MetadataGenerationTargets, MetadataGenerationTypes
 
 log = logging.getLogger(__name__)
@@ -156,6 +159,59 @@ def list_datasets_owned_by_env_group(
     return DatasetService.list_datasets_owned_by_env_group(environmentUri, groupUri, filter)
 
+
+# @ResourceThresholdRepository.invocation_handler('generate_metadata_ai')
+# To make this threshold work, threshold limits must be added to the resource paths dictionary in resource_threshold_repository.
+# For example, 'nlq': 'modules.worksheets.features.max_count_per_day'; an equivalent max_count_per_day entry would need to be defined
+# for metadata generation, either under a different key or by reusing the same key after merge.
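+# A hypothetical sketch of such an entry (the key name and config path below are assumptions, not merged code):
+#     'generate_metadata_ai': 'modules.s3_datasets.features.generate_metadata_ai.max_count_per_day'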
+@is_feature_enabled('modules.s3_datasets.features.generate_metadata_ai.active')
+def generate_metadata(
+    context: Context,
+    source: S3Dataset,
+    resourceUri: str,
+    targetType: str,
+    version: int,
+    metadataTypes: list,
+    sampleData: dict = None,
+):
+    RequestValidator.validate_uri(param_name='resourceUri', param_value=resourceUri)
+    allowed_metadata_types = [item.value for item in MetadataGenerationTypes]
+    if any(metadata_type not in allowed_metadata_types for metadata_type in metadataTypes):
+        raise InvalidInput(
+            'metadataTypes',
+            metadataTypes,
+            f'a list of allowed values {allowed_metadata_types}',
+        )
+    # TODO validate sampleData and make it generic for S3
+    if targetType == MetadataGenerationTargets.S3_Dataset.value:
+        return DatasetService.generate_metadata_for_dataset(
+            resourceUri=resourceUri, version=version, metadataTypes=metadataTypes
+        )
+    elif targetType == MetadataGenerationTargets.Table.value:
+        return DatasetTableService.generate_metadata_for_table(
+            resourceUri=resourceUri, version=version, metadataTypes=metadataTypes, sampleData=sampleData
+        )
+    elif targetType == MetadataGenerationTargets.Folder.value:
+        return DatasetLocationService.generate_metadata_for_folder(
+            resourceUri=resourceUri, version=version, metadataTypes=metadataTypes
+        )
+    else:
+        raise InvalidInput('targetType', targetType, f'one of {[item.value for item in MetadataGenerationTargets]}')
+
+
+def read_sample_data(context: Context, source: S3Dataset, tableUri: str):
+    RequestValidator.validate_uri(param_name='tableUri', param_value=tableUri)
+    return DatasetTableService.preview(uri=tableUri)
+
+
+def update_dataset_metadata(context: Context, source: S3Dataset, resourceUri: str, input: dict = None):
+    return DatasetService.update_dataset(uri=resourceUri, data=input)
+
+
+def list_dataset_tables_folders(context: Context, source: S3Dataset, datasetUri: str, filter: dict = None):
+    if not filter:
+        filter = {}
+    return DatasetService.list_dataset_tables_folders(dataset_uri=datasetUri, filter=filter)
+
+
 class RequestValidator:
     @staticmethod
     def validate_creation_request(data):
@@ -200,6 +256,18 @@ def validate_share_expiration_request(data):
                 'is of invalid type',
             )
 
+    @staticmethod
+    def validate_uri(param_name: str, param_value: str):
+        if not param_value:
+            raise RequiredParameter(param_name)
+        pattern = r'^[a-z0-9]{8}$'
+        if not re.match(pattern, param_value):
+            raise InvalidInput(
+                param_name=param_name,
+                param_value=param_value,
+                constraint='8 characters long and contain only lowercase letters and numbers',
+            )
+
     @staticmethod
     def validate_import_request(data):
         RequestValidator.validate_creation_request(data)
diff --git a/backend/dataall/modules/s3_datasets/api/dataset/types.py b/backend/dataall/modules/s3_datasets/api/dataset/types.py
index a8e11a14e..9a5c4f3f7 100644
--- a/backend/dataall/modules/s3_datasets/api/dataset/types.py
+++ b/backend/dataall/modules/s3_datasets/api/dataset/types.py
@@ -140,3 +140,47 @@
         gql.Field(name='status', type=gql.String),
     ],
 )
+SubitemDescription = gql.ObjectType(
+    name='SubitemDescription',
+    fields=[
+        gql.Field(name='label', type=gql.String),
+        gql.Field(name='description', type=gql.String),
+        gql.Field(name='subitem_id', type=gql.String),
+    ],
+)
+GeneratedMetadata = gql.ObjectType(
+    name='GeneratedMetadata',
+    fields=[
+        gql.Field(name='type', type=gql.String),  # Table, Column, Folder, Dataset
+        gql.Field(name='label', type=gql.String),
+        gql.Field(name='topics', type=gql.ArrayType(gql.String)),
+        gql.Field(name='tags', type=gql.ArrayType(gql.String)),
+        gql.Field(name='description', type=gql.String),
+        gql.Field(name='name', type=gql.String),
+
gql.Field(name='subitem_descriptions', type=gql.ArrayType(gql.Ref('SubitemDescription'))), + ], +) + +DatasetItem = gql.ObjectType( + name='DatasetItem', + fields=[ + gql.Field(name='name', type=gql.String), + gql.Field(name='targetType', type=gql.String), + gql.Field(name='targetUri', type=gql.String), + ], +) + +DatasetItemsSearchResult = gql.ObjectType( + name='DatasetItemsSearchResult', + fields=[ + gql.Field(name='count', type=gql.Integer), + gql.Field(name='nodes', type=gql.ArrayType(DatasetItem)), + gql.Field(name='pageSize', type=gql.Integer), + gql.Field(name='nextPage', type=gql.Integer), + gql.Field(name='pages', type=gql.Integer), + gql.Field(name='page', type=gql.Integer), + gql.Field(name='previousPage', type=gql.Integer), + gql.Field(name='hasNext', type=gql.Boolean), + gql.Field(name='hasPrevious', type=gql.Boolean), + ], +) diff --git a/backend/dataall/modules/s3_datasets/api/table_column/input_types.py b/backend/dataall/modules/s3_datasets/api/table_column/input_types.py index ca32c83f9..2d8f90c77 100644 --- a/backend/dataall/modules/s3_datasets/api/table_column/input_types.py +++ b/backend/dataall/modules/s3_datasets/api/table_column/input_types.py @@ -18,3 +18,11 @@ gql.Argument('topics', gql.Integer), ], ) +SubitemDescription = gql.InputType( + name='SubitemDescriptionInput', + arguments=[ + gql.Argument(name='label', type=gql.String), + gql.Argument(name='description', type=gql.String), + gql.Argument(name='subitem_id', type=gql.String), + ], +) diff --git a/backend/dataall/modules/s3_datasets/api/table_column/mutations.py b/backend/dataall/modules/s3_datasets/api/table_column/mutations.py index d9ae99b6d..3ee266ff6 100644 --- a/backend/dataall/modules/s3_datasets/api/table_column/mutations.py +++ b/backend/dataall/modules/s3_datasets/api/table_column/mutations.py @@ -1,5 +1,9 @@ from dataall.base.api import gql -from dataall.modules.s3_datasets.api.table_column.resolvers import sync_table_columns, update_table_column +from dataall.modules.s3_datasets.api.table_column.resolvers import ( + sync_table_columns, + update_table_column, + batch_update_table_columns_description, +) syncDatasetTableColumns = gql.MutationField( name='syncDatasetTableColumns', @@ -18,3 +22,9 @@ type=gql.Ref('DatasetTableColumn'), resolver=update_table_column, ) +batchUpdateDatasetTableColumn = gql.MutationField( + name='batchUpdateDatasetTableColumn', + args=[gql.Argument(name='columns', type=gql.ArrayType(gql.Ref('SubitemDescriptionInput')))], + type=gql.String, + resolver=batch_update_table_columns_description, +) diff --git a/backend/dataall/modules/s3_datasets/api/table_column/resolvers.py b/backend/dataall/modules/s3_datasets/api/table_column/resolvers.py index 07cb82d5a..3acbe2408 100644 --- a/backend/dataall/modules/s3_datasets/api/table_column/resolvers.py +++ b/backend/dataall/modules/s3_datasets/api/table_column/resolvers.py @@ -41,3 +41,9 @@ def update_table_column(context: Context, source, columnUri: str = None, input: description = input.get('description', 'No description provided') return DatasetColumnService.update_table_column_description(column_uri=columnUri, description=description) + + +def batch_update_table_columns_description(context: Context, source, columns): + if columns is None: + return None + return DatasetColumnService.batch_update_table_columns_description(columns=columns) diff --git a/backend/dataall/modules/s3_datasets/aws/bedrock_metadata_client.py b/backend/dataall/modules/s3_datasets/aws/bedrock_metadata_client.py new file mode 100644 index 
000000000..c5940e5b4
--- /dev/null
+++ b/backend/dataall/modules/s3_datasets/aws/bedrock_metadata_client.py
@@ -0,0 +1,145 @@
+import logging
+import json
+from dataall.base.config import config
+from dataall.base.aws.sts import SessionHelper
+
+
+log = logging.getLogger(__name__)
+
+# TODO: refactor to use prompt templates
+
+
+class BedrockClient:
+    def __init__(self, account_id: str, region: str):
+        session = SessionHelper.remote_session(accountid=account_id, region=region)
+        self._client = session.client('bedrock-runtime', region_name=region)
+
+    def _generate_prompt(self, **kwargs):
+        prompt_type = kwargs.get('prompt_type', 'Table')
+        common_data = {
+            'label': kwargs.get('label', ''),
+            'description': kwargs.get('description', ''),
+            'tags': kwargs.get('tags', ''),
+            'columns': kwargs.get('columns', []),
+            'subitem_descriptions': kwargs.get('subitem_descriptions', []),
+            'file_names': kwargs.get('file_names', []),
+            'folder_name': kwargs.get('folder_name', ''),
+            'folder_description': kwargs.get('folder_description', ''),
+            'folder_tags': kwargs.get('folder_tags', ''),
+            'tables': kwargs.get('tables', []),
+            'table_description': kwargs.get('table_description', ''),
+            'metadata_types': kwargs.get('metadata_type', []),
+            'folders': kwargs.get('folders', []),
+            'sample_data': kwargs.get('sample_data', []),
+        }
+        if prompt_type == 'Table':
+            return f"""
+            Generate or improve metadata for a table using the following provided data:
+            - Table name: {common_data['label'] if common_data['label'] else 'No description provided'}
+            - Column names: {common_data['columns'] if common_data['columns'] else 'No description provided'}
+            - Table description: {common_data['description'] if common_data['description'] else 'No description provided'}
+            - Tags: {common_data['tags'] if common_data['tags'] else 'No description provided'}
+            - Subitem Descriptions: {common_data['subitem_descriptions'] if common_data['subitem_descriptions'] else 'No description provided'}
+            - (Only Input) Sample data: {common_data['sample_data'] if common_data['sample_data'] else 'No sample data'}
+            **Important**:
+            - If the data indicates "No description provided," do not use that particular input for generating metadata; this data is optional, and you should still generate metadata in that case.
+            - Only focus on generating the following metadata types as specified by the user: {common_data['metadata_types']}. Do not include any other metadata types.
+            - Sample data is only input for you to understand the table better, do not generate sample data.
+            Your response must strictly contain all the requested metadata types; do not include any metadata type that is not specified by the user. Don't use single quotes (') in your response, use double quotes (").
+            Subitem Descriptions correspond to column descriptions. If the user did not specifically ask for subitem descriptions, do not include them in the response.
+            subitem_descriptions is a nested dictionary within the response dictionary; the rest of the values are strings. Never change the order of the columns when you generate descriptions for them.
+            For example, if the requested metadata types are "Tags" and "Subitem Descriptions", the response should be:
+            tags: <list of tags>
+            subitem_descriptions:
+                <column name>: <column description>,
+                ...,
+                <column name>: <column description>
+            Evaluate if the given parameters are sufficient for generating the requested metadata; if not, respond with "NotEnoughData" for all dictionary values.
+            Return the result as a Python dictionary where the keys are the requested metadata types, all the keys must be lowercase and the values are the corresponding generated metadata.
+            For tags and topics, ensure the output is a string list. Label is singular, so you should return only one label as a string.
+
+            """
+
+        elif prompt_type == 'S3_Dataset':
+            return f"""
+            Generate or improve metadata for a dataset using the following provided data:
+            - Dataset name: {common_data['label'] if common_data['label'] else 'No description provided'}
+            - Table names in the dataset: {common_data['tables'] if common_data['tables'] else 'No description provided'}
+            - Folder names in the dataset: {common_data['folders'] if common_data['folders'] else 'No description provided'}
+            - Current tags for dataset: {common_data['tags'] if common_data['tags'] else 'No description provided'}
+            - Current dataset description: {common_data['description'] if common_data['description'] else 'No description provided'}
+            **Important**:
+            - If the data indicates "No description provided," do not use that particular input for generating metadata.
+            - Only focus on generating the following metadata types as specified by the user: {common_data['metadata_types']}. Do not include any other metadata types.
+            - Return the result as a Python dictionary.
+            Your response should strictly contain the requested metadata types. Don't use single quotes (') in your response, use double quotes (").
+            For example, if the requested metadata types are "tags" and "description", the response should be:
+            "tags": <list of tags>
+            "description": <generated description>
+            Evaluate if the given parameters are sufficient for generating the requested metadata; if not, build the description by listing the table and folder names, and keep the current name as the label.
+            For tags and topics, ensure the output is a string list. Label is singular, so you should return only one label as a string.
+            Return the result as a Python dictionary where the keys are the requested metadata types, all the keys must be lowercase and the values are the corresponding generated metadata.
+
+            """
+        elif prompt_type == 'Folder':
+            return f"""
+            Generate a detailed metadata description for a folder using the following provided data:
+            folder name: {common_data['label']},
+            file names: {common_data['file_names'] if common_data['file_names'] else 'No description provided'}
+            folder_description: {common_data['description'] if common_data['description'] else 'No description provided'}
+            folder_tags: {common_data['tags'] if common_data['tags'] else 'No description provided'}
+            **Important**:
+            - If the data indicates "No description provided," do not use that particular input for generating metadata.
+            - Only focus on generating the following metadata types as specified by the user: {common_data['metadata_types']}. Do not include any other metadata types.
+            - Return the result as a Python dictionary.
+            Your response should strictly contain the requested metadata types.
+            For example, if the requested metadata types are "tags" and "description", the response should be:
+            "tags": <list of tags>
+            "description": <generated description>
+            For tags and topics, ensure the output is a string list. Label is singular, so you should return only one label as a string.
+            Return a Python dictionary, all the keys must be lowercase. Don't use single quotes (') in your response, use double quotes ("). Include file types (such as pdf) and mention the file names in the description.
+            Evaluate if the given parameters are sufficient for generating the requested metadata; if not, respond with "NotEnoughData" for all dictionary values.
+            """
+
+    def _invoke_model(self, prompt):
+        messages = [{'role': 'user', 'content': [{'type': 'text', 'text': prompt}]}]
+        body = json.dumps(
+            {
+                'anthropic_version': 'bedrock-2023-05-31',
+                'max_tokens': 4096,
+                'messages': messages,
+                'temperature': 0.5,
+                'top_p': 0.5,
+                'stop_sequences': ['\n\nHuman:'],
+                'top_k': 250,
+            }
+        )
+        # TODO: adjust the request body depending on the model
+        model_id = config.get_property('modules.s3_datasets.features.generate_metadata_ai.model_id')
+        response = self._client.invoke_model(body=body, modelId=model_id)
+        response_body = json.loads(response.get('body').read())
+        return response_body.get('content', [])
+
+    def _parse_response(self, response_content, targetName, subitem_ids):
+        output_str = response_content[0]['text']
+
+        output_dict = json.loads(output_str)
+        if not output_dict.get('name'):
+            output_dict['name'] = targetName
+
+        if output_dict.get('subitem_descriptions'):
+            # subitem_ids is a comma-separated string of column URIs, aligned with the generated descriptions
+            subitem_ids = subitem_ids.split(',')
+            subitem_ids = subitem_ids[: len(output_dict['subitem_descriptions'])]
+            subitem_descriptions = []
+            for index, (key, value) in enumerate(output_dict['subitem_descriptions'].items()):
+                subitem_descriptions.append({'label': key, 'description': value, 'subitem_id': subitem_ids[index]})
+            output_dict['subitem_descriptions'] = subitem_descriptions
+        return output_dict
+
+    def generate_metadata(self, **kwargs):
+        # TODO: refactor to use explicit params instead of kwargs
+        prompt = self._generate_prompt(**kwargs)
+        response_content = self._invoke_model(prompt)
+        # TODO: add templated output so that we can avoid parsing the response too much
+        return self._parse_response(response_content, kwargs.get('label', ''), kwargs.get('subitem_ids', ''))
diff --git a/backend/dataall/modules/s3_datasets/aws/s3_dataset_client.py b/backend/dataall/modules/s3_datasets/aws/s3_dataset_client.py
index 94db4d056..778edb508 100644
--- a/backend/dataall/modules/s3_datasets/aws/s3_dataset_client.py
+++ b/backend/dataall/modules/s3_datasets/aws/s3_dataset_client.py
@@ -73,3 +73,16 @@ def get_bucket_encryption(self) -> (str, str, str):
                 f'Data.all Environment Pivot Role does not have s3:GetEncryptionConfiguration Permission for {dataset.S3BucketName} bucket: {e}'
             )
             raise Exception(f'Cannot fetch the bucket encryption configuration for {dataset.S3BucketName}: {e}')
+
+    def list_bucket_files(self, bucket_name, prefix):
+        dataset = self._dataset
+        try:
+            response = self._client.list_objects_v2(
+                Bucket=bucket_name,
+                Prefix=prefix,
+                ExpectedBucketOwner=dataset.AwsAccountId,
+                MaxKeys=1000,
+            )
+            return response.get('Contents', [])
+        except ClientError as e:
+            raise Exception(f'Cannot list the bucket files for {dataset.S3BucketName}: {e}')
diff --git a/backend/dataall/modules/s3_datasets/db/dataset_column_repositories.py b/backend/dataall/modules/s3_datasets/db/dataset_column_repositories.py
index c2038084b..17e7eb733 100644
--- a/backend/dataall/modules/s3_datasets/db/dataset_column_repositories.py
+++ b/backend/dataall/modules/s3_datasets/db/dataset_column_repositories.py
@@ -1,5 +1,5 @@
 from operator import or_
-
+from sqlalchemy import func
 from dataall.base.db import paginate
 from dataall.base.db.exceptions import ObjectNotFound
 from dataall.modules.s3_datasets.db.dataset_models import DatasetTableColumn
@@ -42,3 +42,30 @@
         ).order_by(DatasetTableColumn.columnType.asc())

         return paginate(query=q, page=filter.get('page', 1), page_size=filter.get('pageSize', 10)).to_dict()
+
+    @staticmethod
+    def get_table_info_metadata_generation(session, table_uri: str):
+        result = (
+            session.query(
+                DatasetTableColumn.GlueTableName,
+                DatasetTableColumn.AWSAccountId,
+                func.array_agg(DatasetTableColumn.description).label('description'),
+                func.array_agg(DatasetTableColumn.label).label('label'),
+                func.array_agg(DatasetTableColumn.columnUri).label('columnUri'),
+            )
+            .filter(DatasetTableColumn.tableUri == table_uri)
+            .group_by(DatasetTableColumn.GlueTableName, DatasetTableColumn.AWSAccountId)
+            .first()
+        )
+        return result
+
+    @staticmethod
+    def query_active_columns_for_table(session, table_uri: str):
+        return (
+            session.query(DatasetTableColumn)
+            .filter(
+                DatasetTableColumn.tableUri == table_uri,
+                DatasetTableColumn.deleted.is_(None),
+            )
+            .order_by(DatasetTableColumn.columnType.asc())
+        )
diff --git a/backend/dataall/modules/s3_datasets/db/dataset_models.py b/backend/dataall/modules/s3_datasets/db/dataset_models.py
index 3e9291485..819d9065d 100644
--- a/backend/dataall/modules/s3_datasets/db/dataset_models.py
+++ b/backend/dataall/modules/s3_datasets/db/dataset_models.py
@@ -46,7 +46,10 @@ class DatasetStorageLocation(Resource, Base):
     S3BucketName = Column(String, nullable=False)
     S3Prefix = Column(String, nullable=False)
     S3AccessPoint = Column(String, nullable=True)
+    label = Column(String, nullable=False)
     region = Column(String, default='eu-west-1')
+    description = Column(String, nullable=True)
+    tags = Column(ARRAY(String))
     locationCreated = Column(Boolean, default=False)
     userRoleForStorageLocation = query_expression()
     projectPermission = query_expression()
diff --git a/backend/dataall/modules/s3_datasets/db/dataset_repositories.py b/backend/dataall/modules/s3_datasets/db/dataset_repositories.py
index 075575a2a..feedf7fb5 100644
--- a/backend/dataall/modules/s3_datasets/db/dataset_repositories.py
+++ b/backend/dataall/modules/s3_datasets/db/dataset_repositories.py
@@ -1,6 +1,7 @@
 import logging

-from sqlalchemy import and_, or_
+import sqlalchemy
+from sqlalchemy import and_, or_, literal
 from sqlalchemy.orm import Query
 from dataall.core.activity.db.activity_models import Activity
 from dataall.core.environment.db.environment_models import Environment
@@ -9,7 +10,7 @@
 from dataall.base.db.exceptions import ObjectNotFound
 from dataall.modules.datasets_base.services.datasets_enums import ConfidentialityClassification, Language
 from dataall.core.environment.services.environment_resource_manager import EnvironmentResource
-from dataall.modules.s3_datasets.db.dataset_models import DatasetTable, S3Dataset
+from dataall.modules.s3_datasets.db.dataset_models import DatasetTable, S3Dataset, DatasetStorageLocation
 from dataall.base.utils.naming_convention import (
     NamingConventionService,
     NamingConventionPattern,
@@ -278,3 +279,41 @@
         dataset.importedAdminRole = True if data.get('adminRoleName') else False
         if data.get('imported'):
             dataset.KmsAlias = data.get('KmsKeyAlias') if data.get('KmsKeyAlias') else 'SSE-S3'
+
+    @staticmethod
+    def query_dataset_tables_folders(session, dataset_uri):
+        q1 = (
+            session.query(
+                S3Dataset.datasetUri,
+                DatasetTable.tableUri.label('targetUri'),
+                DatasetTable.name.label('name'),
+                literal('Table', type_=sqlalchemy.types.String).label('targetType'),
+            )
+            .join(
+                DatasetTable,
+                DatasetTable.datasetUri == S3Dataset.datasetUri,
+            )
+            .filter(S3Dataset.datasetUri == dataset_uri)
+        )
+        q2 = (
+            session.query(
+                S3Dataset.datasetUri,
+                DatasetStorageLocation.locationUri.label('targetUri'),
+                DatasetStorageLocation.name.label('name'),
+                literal('Folder', type_=sqlalchemy.types.String).label('targetType'),
+            )
+            .join(
+                DatasetStorageLocation,
+                DatasetStorageLocation.datasetUri == S3Dataset.datasetUri,
+            )
+            .filter(S3Dataset.datasetUri == dataset_uri)
+        )
+        return q1.union(q2)
+
+    @staticmethod
+    def paginated_dataset_tables_folders(session, dataset_uri, data):
+        return paginate(
+            query=DatasetRepository.query_dataset_tables_folders(session, dataset_uri),
+            page=data.get('page', 1),
+            page_size=data.get('pageSize', 10),
+        ).to_dict()
diff --git a/backend/dataall/modules/s3_datasets/services/dataset_column_service.py b/backend/dataall/modules/s3_datasets/services/dataset_column_service.py
index 987b855a4..40987c4c7 100644
--- a/backend/dataall/modules/s3_datasets/services/dataset_column_service.py
+++ b/backend/dataall/modules/s3_datasets/services/dataset_column_service.py
@@ -70,3 +70,11 @@ def update_table_column_description(column_uri: str, description) -> DatasetTabl
             Worker.queue(engine=get_context().db_engine, task_ids=[task.taskUri])

         return column
+
+    @staticmethod
+    def batch_update_table_columns_description(columns):
+        for column_ in columns:
+            DatasetColumnService.update_table_column_description(
+                column_uri=column_['subitem_id'], description=column_['description']
+            )
+        return 'Success'
diff --git a/backend/dataall/modules/s3_datasets/services/dataset_enums.py b/backend/dataall/modules/s3_datasets/services/dataset_enums.py
new file mode 100644
index 000000000..a62d0fd5b
--- /dev/null
+++ b/backend/dataall/modules/s3_datasets/services/dataset_enums.py
@@ -0,0 +1,18 @@
+from enum import Enum
+
+
+class MetadataGenerationTargets(Enum):
+    """Describes the s3_datasets metadata generation targets"""
+
+    Table = 'Table'
+    Folder = 'Folder'
+    S3_Dataset = 'S3_Dataset'
+
+
+class MetadataGenerationTypes(Enum):
+    """Describes the s3_datasets metadata generation types"""
+
+    Description = 'Description'
+    Label = 'Label'
+    Tag = 'Tag'
+    Topic = 'Topic'
diff --git a/backend/dataall/modules/s3_datasets/services/dataset_location_service.py b/backend/dataall/modules/s3_datasets/services/dataset_location_service.py
index a4ac2b33f..23827c9c1 100644
--- a/backend/dataall/modules/s3_datasets/services/dataset_location_service.py
+++ b/backend/dataall/modules/s3_datasets/services/dataset_location_service.py
@@ -1,3 +1,5 @@
+import logging
+
 from dataall.modules.s3_datasets.indexers.dataset_indexer import DatasetIndexer
 from dataall.base.context import get_context
 from dataall.core.permissions.services.resource_policy_service import ResourcePolicyService
@@ -18,6 +20,11 @@
 from dataall.modules.s3_datasets.services.dataset_permissions import DATASET_FOLDER_READ, GET_DATASET_FOLDER
 from dataall.modules.s3_datasets.db.dataset_repositories import DatasetRepository
 from dataall.modules.s3_datasets.db.dataset_models import DatasetStorageLocation, S3Dataset
+from dataall.modules.s3_datasets.aws.bedrock_metadata_client import BedrockClient
+from dataall.modules.s3_datasets.aws.s3_dataset_client import S3DatasetClient
+from dataall.modules.s3_datasets.services.dataset_enums import MetadataGenerationTargets
+
+log = logging.getLogger(__name__)
 
 
 class DatasetLocationService:
@@ -137,3 +144,22 @@ def _delete_dataset_folder_read_permission(session, dataset: S3Dataset, location
         }
         for group in permission_group:
             ResourcePolicyService.delete_resource_policy(session=session, group=group, resource_uri=location_uri)
+
+    @staticmethod
+    def generate_metadata_for_folder(resourceUri, version, metadataTypes):
+        context = get_context()
+        # TODO decide what to do with version
+        with context.db_engine.scoped_session() as session:
+            folder = DatasetLocationRepository.get_location_by_uri(session, resourceUri)
+            dataset = DatasetRepository.get_dataset_by_uri(session, folder.datasetUri)
+            files = S3DatasetClient(dataset).list_bucket_files(folder.S3BucketName, folder.S3Prefix)
+            file_names = [f['Key'] for f in files]
+            log.info('file names: %s', file_names)
+            return BedrockClient(folder.AWSAccountId, 'us-east-1').generate_metadata(
+                prompt_type=MetadataGenerationTargets.Folder.value,
+                label=folder.label,
+                file_names=file_names,
+                description=folder.description,
+                tags=folder.tags,
+                metadata_type=metadataTypes,
+            )
diff --git a/backend/dataall/modules/s3_datasets/services/dataset_service.py b/backend/dataall/modules/s3_datasets/services/dataset_service.py
index 14cfdc2fd..e80b8e4ba 100644
--- a/backend/dataall/modules/s3_datasets/services/dataset_service.py
+++ b/backend/dataall/modules/s3_datasets/services/dataset_service.py
@@ -47,6 +47,8 @@
 from dataall.modules.datasets_base.db.dataset_models import DatasetBase
 from dataall.modules.s3_datasets.services.dataset_permissions import DATASET_TABLE_ALL
 from dataall.modules.datasets_base.services.dataset_service_interface import DatasetServiceInterface
+from dataall.modules.s3_datasets.aws.bedrock_metadata_client import BedrockClient
+from dataall.modules.s3_datasets.services.dataset_enums import MetadataGenerationTargets
 
 log = logging.getLogger(__name__)
@@ -557,3 +559,29 @@ def delete_dataset_term_links(session, dataset_uri):
         for table_uri in tables:
             GlossaryRepository.delete_glossary_terms_links(session, table_uri, 'DatasetTable')
         GlossaryRepository.delete_glossary_terms_links(session, dataset_uri, 'Dataset')
+
+    @staticmethod
+    def list_dataset_tables_folders(dataset_uri, filter):
+        context = get_context()
+        with context.db_engine.scoped_session() as session:
+            return DatasetRepository.paginated_dataset_tables_folders(session, dataset_uri, filter)
+
+    @staticmethod
+    def generate_metadata_for_dataset(resourceUri, version, metadataTypes):
+        # TODO decide what to do with version
+        context = get_context()
+        with context.db_engine.scoped_session() as session:
+            dataset = DatasetBaseRepository.get_dataset_by_uri(session, resourceUri)
+            tables = DatasetRepository.get_dataset_tables(session, resourceUri)
+            table_labels = [t.label for t in tables]
+            table_descriptions = [t.description for t in tables]
+            folders = [f.label for f in DatasetLocationRepository.get_dataset_folders(session, resourceUri)]
+            return BedrockClient(dataset.AwsAccountId, 'us-east-1').generate_metadata(
+                prompt_type=MetadataGenerationTargets.S3_Dataset.value,
+                label=dataset.label,
+                tables=table_labels,
+                description=dataset.description,
+                table_description=table_descriptions,
+                tags=dataset.tags,
+                metadata_type=metadataTypes,
+                folders=folders,
+            )
diff --git a/backend/dataall/modules/s3_datasets/services/dataset_table_service.py b/backend/dataall/modules/s3_datasets/services/dataset_table_service.py
index 5c3c3228f..d4a7f7878 100644
--- a/backend/dataall/modules/s3_datasets/services/dataset_table_service.py
+++ b/backend/dataall/modules/s3_datasets/services/dataset_table_service.py
@@ -29,6 +29,10 @@
 from dataall.modules.s3_datasets.services.dataset_service import DatasetService
 from dataall.base.utils import json_utils
 from dataall.base.db import exceptions
+from dataall.modules.s3_datasets.aws.bedrock_metadata_client import BedrockClient
+from dataall.modules.s3_datasets.db.dataset_column_repositories import DatasetColumnRepository
+from dataall.modules.s3_datasets.services.dataset_enums import MetadataGenerationTargets
+
 
 log = logging.getLogger(__name__)
@@ -183,3 +187,23 @@ def _delete_dataset_table_read_permission(session, table_uri):
         ResourcePolicyService.delete_resource_policy(
             session=session, group=None, resource_uri=table_uri, resource_type=DatasetTable.__name__
         )
+
+    # TODO: add permission checks!
+    @staticmethod
+    def generate_metadata_for_table(resourceUri, version, metadataTypes, sampleData):
+        # TODO decide what to do with version
+        context = get_context()
+        with context.db_engine.scoped_session() as session:
+            table = DatasetTableRepository.get_dataset_table_by_uri(session, resourceUri)
+            table_column = DatasetColumnRepository.get_table_info_metadata_generation(session, resourceUri)
+            return BedrockClient(table_column.AWSAccountId, 'us-east-1').generate_metadata(
+                prompt_type=MetadataGenerationTargets.Table.value,
+                label=table.label,
+                columns=','.join(table_column.label),
+                subitem_descriptions=','.join([d or '' for d in table_column.description]),
+                subitem_ids=','.join(table_column.columnUri),
+                description=table.description,
+                tags=table.tags,
+                metadata_type=metadataTypes,
+                sample_data=sampleData,
+            )
diff --git a/config.json b/config.json
index 26d0891b4..9c4b80f8f 100644
--- a/config.json
+++ b/config.json
@@ -58,7 +58,11 @@
             "preview_data": true,
             "glue_crawler": true,
             "metrics_data": true,
-            "show_stack_logs": "enabled"
+            "show_stack_logs": "enabled",
+            "generate_metadata_ai": {
+                "active": true,
+                "model_id": "anthropic.claude-3-sonnet-20240229-v1:0"
+            }
         }
     },
     "shares_base": {
diff --git a/frontend/src/modules/S3_Datasets/components/GenerateMetadataComponent.js b/frontend/src/modules/S3_Datasets/components/GenerateMetadataComponent.js
new file mode 100644
index 000000000..f50933bfb
--- /dev/null
+++ b/frontend/src/modules/S3_Datasets/components/GenerateMetadataComponent.js
@@ -0,0 +1,424 @@
+import {
+  Avatar,
+  Box,
+  Button,
+  Checkbox,
+  Chip,
+  Divider,
+  FormControl,
+  FormGroup,
+  FormControlLabel,
+  FormLabel,
+  Grid,
+  InputLabel,
+  MenuItem,
+  Select,
+  Switch,
+  Typography
+} from '@mui/material';
+import { DataGrid } from '@mui/x-data-grid';
+import { useSnackbar } from 'notistack';
+import PropTypes from 'prop-types';
+import { useState, useCallback } from 'react';
+import AutoModeIcon from '@mui/icons-material/AutoMode';
+import { Defaults, Scrollbar } from 'design';
+import { SET_ERROR, useDispatch } from 'globalErrors';
+import { useClient } from 'services';
+import { listDatasetTablesFolders, generateMetadataBedrock } from '../services';
+
+/* eslint-disable no-console */
+export const GenerateMetadataComponent = (props) => {
+  const {
+    dataset,
+    targetType,
+    setTargetType,
+    targets,
+    setTargets,
+    targetOptions,
+    setTargetOptions,
+    selectedMetadataTypes,
+    setSelectedMetadataTypes,
+    currentView,
+    setCurrentView,
+    loadingMetadata,
+    setLoadingMetadata,
+    version,
+    setVersion,
+    ...other
+  } = props;
+  const { enqueueSnackbar } = useSnackbar();
+  const dispatch = useDispatch();
+
+  const client = useClient();
+  const [loadingTableFolder, setLoadingTableFolder] = useState(false);
+  const [tableFolderFilter, setTableFolderFilter] = useState(Defaults.filter);
+  const handleChange = useCallback(
+    async (event) => {
+      setTargetType(event.target.value);
+      if (event.target.value === 'Dataset') {
+        setTargets([
+          {
+            targetUri: dataset.datasetUri,
+            targetType: 'S3_Dataset'
+          }
+        ]);
+      } else {
+        setTargets([]);
+        setLoadingTableFolder(true);
+        const response = await client.query(
+          listDatasetTablesFolders({
+            datasetUri: dataset.datasetUri,
+            filter: tableFolderFilter
+          })
+        );
+        if (!response.errors) {
+          setTargetOptions(response.data.listDatasetTablesFolders);
+        } else {
+          dispatch({
+            type: SET_ERROR,
+            error: response.errors[0].message
+          });
+        }
+        setLoadingTableFolder(false);
+      }
+    },
+    [client, dispatch, dataset.datasetUri, tableFolderFilter, setTargetType, setTargets, setTargetOptions]
+  );
+
+  const handleMetadataChange = (event) => {
+    setSelectedMetadataTypes({
+      ...selectedMetadataTypes,
+      [event.target.name]: event.target.checked
+    });
+  };
+
+  const handlePageChange = async (page) => {
+    page += 1; // DataGrid pages are 0-indexed; the API expects 1-indexed pages
+    if (page <= targetOptions.pages && page !== targetOptions.page) {
+      setTableFolderFilter({ ...tableFolderFilter, page: page });
+    }
+  };
+
+  const generateMetadata = async () => {
+    setCurrentView('REVIEW_METADATA');
+    for (const target of targets) {
+      const response = await client.mutate(
+        generateMetadataBedrock({
+          resourceUri: target.targetUri,
+          targetType: target.targetType,
+          metadataTypes: Object.entries(selectedMetadataTypes)
+            .filter(([key, value]) => value === true)
+            .map(([key]) => key),
+          version: version,
+          sampleData: {}
+        })
+      );
+      if (!response.errors) {
+        target.description = response.data.generateMetadata.description;
+        target.label = response.data.generateMetadata.label;
+        target.name = response.data.generateMetadata.name;
+        target.tags = response.data.generateMetadata.tags;
+        target.topics = response.data.generateMetadata.topics;
+        target.subitem_descriptions = (
+          response.data.generateMetadata.subitem_descriptions || []
+        ).map((item) => ({
+          description: item.description,
+          label: item.label,
+          subitem_id: item.subitem_id
+        }));
+        const hasNotEnoughData = [
+          target.description,
+          target.label,
+          target.name,
+          target.tags,
+          target.topics,
+          target.subitem_descriptions
+        ].some((value) => value === 'NotEnoughData');
+
+        if (hasNotEnoughData) {
+          enqueueSnackbar(
+            `Not enough data to generate metadata for ${target.name}`,
+            {
+              anchorOrigin: {
+                horizontal: 'right',
+                vertical: 'top'
+              },
+              variant: 'warning'
+            }
+          );
+        } else {
+          enqueueSnackbar(
+            `Metadata generation is successful for ${target.name}`,
+            {
+              anchorOrigin: {
+                horizontal: 'right',
+                vertical: 'top'
+              },
+              variant: 'success'
+            }
+          );
+        }
+        setVersion((v) => v + 1);
+      }
+    }
+  };
+  return (
+    <>
+ + 1} label="Select target type" color="primary" variant="outlined" /> + + {targetType && ( + 2} label="Select target resources" color="primary" variant="outlined" /> + )} + + {targetType && !!targets.length && ( + 3} label="Select type of metadata" color="primary" variant="outlined" /> + )} + + + + + + Target Type * + + + {targetType === 'Dataset' && ( + Data.all will use the table and folder metadata to generate + Dataset label, description, tags and/or topics using Amazon + Bedrock.
+ + )} + {targetType === 'TablesAndFolders' && ( + + Data.all will use table column names and table descriptions and + folder S3 prefix names to generate Tables and Folders label, + description, tags and/or topics using Amazon Bedrock. + + )} + + + {targetType === 'Dataset' && ( + } + label={dataset.name} + /> + )} + {targetType === 'TablesAndFolders' && ( + + + node.targetUri} + rows={targetOptions.nodes} + columns={[ + { field: 'targetUri', hide: true }, + { + field: 'name', + headerName: 'Name', + flex: 1.5, + editable: false + }, + { + field: 'targetType', + headerName: 'Type', + flex: 1, + editable: false + } + ]} + rowCount={targetOptions.count} + page={targetOptions.page - 1} + pageSize={targetOptions.pageSize} + paginationMode="server" + onPageChange={handlePageChange} + loading={loadingTableFolder} + onPageSizeChange={(pageSize) => { + setTableFolderFilter({ + ...tableFolderFilter, + pageSize: pageSize + }); + }} + getRowHeight={() => 'auto'} + disableSelectionOnClick + onSelectionModelChange={(newSelectionModel) => { + const selectedTargets = newSelectionModel.map((id) => + targetOptions.nodes.find( + (option) => option.targetUri === id + ) + ); + setTargets(selectedTargets); + if (newSelectionModel.length === 0) { + setSelectedMetadataTypes({}); + } + }} + sx={{ + wordWrap: 'break-word', + '& .MuiDataGrid-row': { + borderBottom: '1px solid rgba(145, 158, 171, 0.24)' + }, + '& .MuiDataGrid-columnHeaders': { + borderBottom: 0.5 + } + }} + /> + + + )} + + + {targetType && !!targets.length && ( + + Metadata + + } + label="Label" + /> + + } + label="Description" + /> + + } + label="Tags" + /> + + } + label="Subitem Descriptions" + /> + + } + label="Topics" + /> + + )} + + + {!loadingMetadata && ( + + )} + + ); +}; + +GenerateMetadataComponent.propTypes = { + dataset: PropTypes.object.isRequired, + targetType: PropTypes.string.isRequired, + setTargetType: PropTypes.func.isRequired, + targets: PropTypes.array.isRequired, + setTargets: PropTypes.func.isRequired, + targetOptions: PropTypes.array.isRequired, + setTargetOptions: PropTypes.func.isRequired, + selectedMetadataTypes: PropTypes.object.isRequired, + setSelectedMetadataTypes: PropTypes.func.isRequired, + currentView: PropTypes.string.isRequired, + setCurrentView: PropTypes.func.isRequired, + version: PropTypes.number.isRequired, + setVersion: PropTypes.func.isRequired +}; diff --git a/frontend/src/modules/S3_Datasets/components/MetadataMainModal.js b/frontend/src/modules/S3_Datasets/components/MetadataMainModal.js new file mode 100644 index 000000000..efde24871 --- /dev/null +++ b/frontend/src/modules/S3_Datasets/components/MetadataMainModal.js @@ -0,0 +1,77 @@ +import { Dialog } from '@mui/material'; +import PropTypes from 'prop-types'; +import { useEffect, useState } from 'react'; +import { Defaults } from 'design'; +import { GenerateMetadataComponent } from './GenerateMetadataComponent'; +import { ReviewMetadataComponent } from './ReviewMetadataComponent'; + +export const MetadataMainModal = (props) => { + const { dataset, onApply, onClose, open, ...other } = props; + const [currentView, setCurrentView] = useState('GENERATE_FORM'); + const [targetType, setTargetType] = useState(''); + const [targets, setTargets] = useState([]); + const [targetOptions, setTargetOptions] = useState([]); + const [version, setVersion] = useState(0); + const [selectedMetadataTypes, setSelectedMetadataTypes] = useState({ + label: false, + description: false, + tags: false, + topics: false, + subitem_descriptions: false + }); + + useEffect(() => { + 
if (!open) { + setCurrentView('GENERATE_FORM'); + setTargetType(''); + setTargets([]); + setTargetOptions(Defaults.pagedResponse); + setSelectedMetadataTypes({}); + setVersion(0); + } + }, [open]); + + if (!dataset) { + return null; + } + + return ( + + {currentView === 'GENERATE_FORM' && ( + + )} + {currentView === 'REVIEW_METADATA' && ( + + )} + + ); +}; + +MetadataMainModal.propTypes = { + dataset: PropTypes.object.isRequired, + onApply: PropTypes.func, + onClose: PropTypes.func, + open: PropTypes.bool.isRequired +}; diff --git a/frontend/src/modules/S3_Datasets/components/ReviewMetadataComponent.js b/frontend/src/modules/S3_Datasets/components/ReviewMetadataComponent.js new file mode 100644 index 000000000..46488d4b2 --- /dev/null +++ b/frontend/src/modules/S3_Datasets/components/ReviewMetadataComponent.js @@ -0,0 +1,494 @@ +import { + Button, + Box, + Chip, + Typography, + CircularProgress +} from '@mui/material'; +import { DataGrid } from '@mui/x-data-grid'; +import { useSnackbar } from 'notistack'; +import PropTypes from 'prop-types'; +import AutoModeIcon from '@mui/icons-material/AutoMode'; +import { Scrollbar } from 'design'; +import { SET_ERROR, useDispatch } from 'globalErrors'; +import { useClient } from 'services'; +import { updateDataset, generateMetadataBedrock } from '../services'; +import { updateDatasetTable } from 'modules/Tables/services'; +import { BatchUpdateDatasetTableColumn } from '../services/batchUpdateTableColumnDescriptions'; +import { listSampleData } from '../services/listSampleData'; +import { updateDatasetStorageLocation } from 'modules/Folders/services'; +import SampleDataPopup from './SampleDataPopup'; +import React, { useState } from 'react'; +import SubitemDescriptionsGrid from './SubitemDescriptionsGrid'; + +export const ReviewMetadataComponent = (props) => { + const { + dataset, + targets, + setTargets, + selectedMetadataTypes, + version + } = props; + const { enqueueSnackbar } = useSnackbar(); + const dispatch = useDispatch(); + const client = useClient(); + const [popupOpen, setPopupOpen] = useState(false); + const [sampleData, setSampleData] = useState(null); + const [targetUri, setTargetUri] = useState(null); + const [showPopup, setShowPopup] = React.useState(false); + const [subitemDescriptions, setSubitemDescriptions] = React.useState([]); + + const showSubItemsPopup = (subitemDescriptions) => { + setSubitemDescriptions(subitemDescriptions); + setShowPopup(true); + }; + + const closeSubItemsPopup = () => { + setShowPopup(false); + }; + const openSampleDataPopup = (data) => { + setSampleData(data); + setPopupOpen(true); + }; + + const closeSampleDataPopup = () => { + setPopupOpen(false); + setSampleData(null); + }; + async function handleSaveSubitemDescriptions() { + try { + const columns_ = subitemDescriptions.map((item) => ({ + description: item.description, + label: item.label, + subitem_id: item.subitem_id + })); + const response = await client.mutate( + BatchUpdateDatasetTableColumn(columns_) + ); + if (!response.errors) { + enqueueSnackbar('Successfully updated subitem descriptions', { + variant: 'success' + }); + closeSubItemsPopup(); + } else { + dispatch({ type: SET_ERROR, error: response.errors[0].message }); + } + } catch (err) { + dispatch({ type: SET_ERROR, error: err.message }); + } + } + + async function handleRegenerate(table) { + try { + const response = await client.query( + listSampleData({ + tableUri: table.targetUri + }) + ); + openSampleDataPopup(response.data.listSampleData); + setTargetUri(table.targetUri); + if 
(!response.errors) { + enqueueSnackbar('Successfully read sample data', { + variant: 'success' + }); + } else { + dispatch({ type: SET_ERROR, error: response.errors[0].message }); + } + } catch (err) { + dispatch({ type: SET_ERROR, error: err.message }); + } + } + const handleAcceptAndRegenerate = async () => { + // Perform any necessary actions for accepting and regenerating the data + try { + const targetIndex = targets.findIndex((t) => t.targetUri === targetUri); + if (targetIndex !== -1) { + const { __typename, ...sampleDataWithoutTypename } = sampleData; + const response = await client.mutate( + generateMetadataBedrock({ + resourceUri: targets[targetIndex].targetUri, + targetType: targets[targetIndex].targetType, + metadataTypes: Object.entries(selectedMetadataTypes) + .filter(([key, value]) => value === true) + .map(([key]) => key), + version: version, + sampleData: sampleDataWithoutTypename + }) + ); + + if (!response.errors) { + const updatedTarget = { + ...targets[targetIndex], + description: response.data.generateMetadata.description, + label: response.data.generateMetadata.label, + name: response.data.generateMetadata.name, + tags: response.data.generateMetadata.tags, + topics: response.data.generateMetadata.topics, + subitem_descriptions: + response.data.generateMetadata.subitem_descriptions + }; + + const updatedTargets = [...targets]; + updatedTargets[targetIndex] = updatedTarget; + + setTargets(updatedTargets); + + enqueueSnackbar( + `Metadata generation is successful for ${updatedTarget.name}`, + { + anchorOrigin: { + horizontal: 'right', + vertical: 'top' + }, + variant: 'success' + } + ); + } + } else { + console.error(`Target with targetUri not found`); + enqueueSnackbar(`Metadata generation is unsuccessful`, { + anchorOrigin: { + horizontal: 'right', + vertical: 'top' + }, + variant: 'error' + }); + } + + closeSampleDataPopup(); + } catch (err) { + dispatch({ type: SET_ERROR, error: err.message }); + } + }; + + async function saveMetadata(targets) { + try { + const updatedTargets = targets.map(async (target) => { + const updatedMetadata = {}; + + // Loop through selectedMetadataTypes and add the corresponding key-value pairs to updatedMetadata + Object.entries(selectedMetadataTypes).forEach( + ([metadataType, checked]) => { + if (checked) { + updatedMetadata[metadataType] = target[metadataType]; + } + } + ); + if (target.targetType === 'S3_Dataset') { + updatedMetadata.KmsAlias = dataset.KmsAlias; + const response = await client.mutate( + updateDataset({ + datasetUri: target.targetUri, + input: updatedMetadata + }) + ); + + if (!response.errors) { + return { ...target, success: true }; + } else { + dispatch({ type: SET_ERROR, error: response.errors[0].message }); + return { ...target, success: false }; + } + } else if (target.targetType === 'Table') { + const response = await client.mutate( + updateDatasetTable({ + tableUri: target.targetUri, + input: updatedMetadata + }) + ); + + if (!response.errors) { + return { ...target, success: true }; + } else { + dispatch({ type: SET_ERROR, error: response.errors[0].message }); + return { ...target, success: false }; + } + } else if (target.targetType === 'Folder') { + const response = await client.mutate( + updateDatasetStorageLocation({ + locationUri: target.targetUri, + input: updatedMetadata + }) + ); + + if (!response.errors) { + return { ...target, success: true }; // Return the updated target with success flag + } else { + dispatch({ type: SET_ERROR, error: response.errors[0].message }); + return { ...target, success: 
false }; // Return the target with success flag set to false + } + } + }); + + const updatedTargetsResolved = await Promise.all(updatedTargets); + + const successfulTargets = updatedTargetsResolved.filter( + (target) => target.success + ); + const failedTargets = updatedTargetsResolved.filter( + (target) => !target.success + ); + + if (successfulTargets.length > 0) { + enqueueSnackbar( + `${successfulTargets.length} target(s) updated successfully`, + { + anchorOrigin: { + horizontal: 'right', + vertical: 'top' + }, + variant: 'success' + } + ); + } + + if (failedTargets.length > 0) { + enqueueSnackbar(`${failedTargets.length} target(s) failed to update`, { + anchorOrigin: { + horizontal: 'right', + vertical: 'top' + }, + variant: 'error' + }); + } + } catch (err) { + console.error(err); + dispatch({ type: SET_ERROR, error: err.message }); + } + } + return ( + <> + {Array.isArray(targets) && targets.length > 0 ? ( + + + + node.targetUri} + rowHeight={80} + columns={[ + { field: 'targetUri', hide: true }, + { + field: 'name', + headerName: 'Name', + flex: 1.5, + editable: false + }, + { + field: 'targetType', + headerName: 'Target Type', + flex: 1.5, + editable: false + }, + { + field: 'label', + headerName: 'Label', + flex: 2, + editable: true, + renderCell: (params) => + params.value === undefined ? ( + + ) : params.value === 'NotEnoughData' ? ( + + ) : ( +
+ {params.value} +
+ ) + }, + { + field: 'description', + headerName: 'Description', + flex: 3, + editable: true, + renderCell: (params) => + params.value === undefined ? ( + + ) : params.value === 'NotEnoughData' ? ( + + ) : ( +
+ {params.value} +
+ ) + }, + { + field: 'tags', + headerName: 'Tags', + flex: 2, + editable: true, + valueSetter: (params) => { + const { id, row, newValue } = params; + const tags = + typeof newValue === 'string' + ? newValue.split(',') + : newValue; + return { ...row, targetUri: id, tags }; + }, + renderCell: (params) => + params.value === undefined ? ( + + ) : params.value[0] === 'NotEnoughData' ? ( + + ) : ( +
+ {Array.isArray(params.value) + ? params.value.join(', ') + : params.value} +
+ ) + }, + { + field: 'topics', + headerName: 'Topics', + flex: 2, + editable: true, + renderCell: (params) => + params.value === undefined ? ( + + ) : params.value[0] === 'NotEnoughData' ? ( + + ) : ( +
+ {params.value} +
+ ) + }, + { + field: 'subitem_descriptions', + headerName: 'Subitem Descriptions', + flex: 3, + editable: false, + renderCell: (params) => + params.value === undefined ? ( + + ) : params.value[0] === 'NotEnoughData' ? ( + + ) : ( + + ) + }, + { + field: 'regenerate', + headerName: 'Regenerate', + flex: 3, + type: 'boolean', + renderCell: (params) => + params.row.targetType === 'Table' ? ( + + ) : ( + '-' + ) + } + ]} + columnVisibilityModel={{ + targetUri: false, + label: selectedMetadataTypes['label'] + ? selectedMetadataTypes['label'] + : false, + description: selectedMetadataTypes['description'] + ? selectedMetadataTypes['description'] + : false, + tags: selectedMetadataTypes['tags'] + ? selectedMetadataTypes['tags'] + : false, + topics: selectedMetadataTypes['topics'] + ? selectedMetadataTypes['topics'] + : false, + subitem_descriptions: selectedMetadataTypes[ + 'subitem_descriptions' + ] + ? selectedMetadataTypes['subitem_descriptions'] + : false + }} + pageSize={10} + rowsPerPageOptions={[5, 10, 20]} + pagination + disableSelectionOnClick + onCellEditCommit={(params) => { + const { value, id, field } = params; + const updatedTargets = targets.map((target) => { + const newTarget = { ...target }; + if (newTarget.targetUri === id) { + newTarget[field] = value; + } + return newTarget; + }); + setTargets(updatedTargets); + }} + onProcessRowUpdateError={(error) => { + console.error('Error updating row:', error); + }} + sx={{ + wordWrap: 'break-word', + '& .MuiDataGrid-row': { + borderBottom: '1px solid rgba(145, 158, 171, 0.24)' + }, + '& .MuiDataGrid-columnHeaders': { + borderBottom: 0.5 + } + }} + /> +
+
+
+ ) : ( + No metadata available + )} + {showPopup && ( + + )} + + + + ); +}; + +ReviewMetadataComponent.propTypes = { + dataset: PropTypes.object.isRequired, + targetType: PropTypes.string.isRequired, + targets: PropTypes.array.isRequired, + setTargets: PropTypes.func.isRequired, + selectedMetadataTypes: PropTypes.object.isRequired, + version: PropTypes.number.isRequired, + setVersion: PropTypes.func.isRequired +}; diff --git a/frontend/src/modules/S3_Datasets/components/SampleDataPopup.js b/frontend/src/modules/S3_Datasets/components/SampleDataPopup.js new file mode 100644 index 000000000..a271db9c7 --- /dev/null +++ b/frontend/src/modules/S3_Datasets/components/SampleDataPopup.js @@ -0,0 +1,43 @@ +import React from 'react'; +import { Modal, Box, Button, Typography } from '@mui/material'; +import SampleDataTableComponent from './SampleDataTableComponent'; + +const SampleDataPopup = ({ + open, + sampleData, + handleClose, + handleRegenerate +}) => { + return ( + + + + + By clicking the button below, you agree to share this sample data with + a third-party language model. + + + {' '} + + + + ); +}; + +export default SampleDataPopup; diff --git a/frontend/src/modules/S3_Datasets/components/SampleDataTableComponent.js b/frontend/src/modules/S3_Datasets/components/SampleDataTableComponent.js new file mode 100644 index 000000000..a15d3b32b --- /dev/null +++ b/frontend/src/modules/S3_Datasets/components/SampleDataTableComponent.js @@ -0,0 +1,52 @@ +import React from 'react'; +import { DataGrid } from '@mui/x-data-grid'; +import { styled } from '@mui/styles'; +import { Card } from '@mui/material'; + +const StyledDataGrid = styled(DataGrid)(({ theme }) => ({ + '& .MuiDataGrid-columnsContainer': { + backgroundColor: + theme.palette.mode === 'dark' + ? 'rgba(29,29,29,0.33)' + : 'rgba(255,255,255,0.38)' + } +})); +const buildHeader = (fields) => + fields.map((field) => ({ + field: JSON.parse(field).name, + headerName: JSON.parse(field).name, + editable: false + })); +const buildRows = (rows, fields) => { + const header = fields.map((field) => JSON.parse(field).name); + const newRows = rows.map((row) => JSON.parse(row)); + const builtRows = newRows.map((row) => + header.map((h, index) => ({ [h]: row[index] })) + ); + const objects = []; + builtRows.forEach((row) => { + const obj = {}; + row.forEach((r) => { + Object.entries(r).forEach(([key, value]) => { + obj[key] = value; + }); + obj.id = Math.random(); + }); + objects.push(obj); + }); + return objects; +}; +const SampleDataTableComponent = ({ data }) => { + return ( + + + + ); +}; + +export default SampleDataTableComponent; diff --git a/frontend/src/modules/S3_Datasets/components/SubitemDescriptionsGrid.js b/frontend/src/modules/S3_Datasets/components/SubitemDescriptionsGrid.js new file mode 100644 index 000000000..4d5fe3613 --- /dev/null +++ b/frontend/src/modules/S3_Datasets/components/SubitemDescriptionsGrid.js @@ -0,0 +1,78 @@ +import React from 'react'; +import { Grid, Typography, Paper, Button } from '@mui/material'; + +const SubitemDescriptionsGrid = ({ subitemDescriptions, onClose, onSave }) => { + return ( +
diff --git a/frontend/src/modules/S3_Datasets/components/SubitemDescriptionsGrid.js b/frontend/src/modules/S3_Datasets/components/SubitemDescriptionsGrid.js
new file mode 100644
index 000000000..4d5fe3613
--- /dev/null
+++ b/frontend/src/modules/S3_Datasets/components/SubitemDescriptionsGrid.js
@@ -0,0 +1,78 @@
+import React from 'react';
+import { Grid, Typography, Paper, Button } from '@mui/material';
+
+const SubitemDescriptionsGrid = ({ subitemDescriptions, onClose, onSave }) => {
+  return (
+
+
+
+
+
+          Label
+
+
+
+
+          Description
+
+
+
+
+      {subitemDescriptions.map((item) => (
+
+
+
+
+            {item.label}
+
+
+
+
+            {item.description}
+
+
+
+
+      ))}
+
+
+
+
+
+
+
+
+  );
+};
+
+export default SubitemDescriptionsGrid;
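After review, the edited descriptions still have to be written back to the catalog. A minimal sketch of that save step, assuming the `BatchUpdateDatasetTableColumn` service added below and an Apollo `client` handle from the app's client hook; the `saveSubitemDescriptions` helper itself is illustrative and not part of this diff:

```js
import { BatchUpdateDatasetTableColumn } from '../services';

// Hypothetical helper: map reviewed rows onto the SubitemDescriptionInput
// shape and persist them in a single mutation.
async function saveSubitemDescriptions(client, subitemDescriptions) {
  const columns = subitemDescriptions.map((item) => ({
    subitem_id: item.subitem_id,
    label: item.label,
    description: item.description
  }));
  const response = await client.mutate(BatchUpdateDatasetTableColumn(columns));
  if (response.errors) {
    throw new Error(response.errors[0].message);
  }
  return response.data.batchUpdateDatasetTableColumn;
}
```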
diff --git a/frontend/src/modules/S3_Datasets/components/index.js b/frontend/src/modules/S3_Datasets/components/index.js
index 7e899832c..2fda5268c 100644
--- a/frontend/src/modules/S3_Datasets/components/index.js
+++ b/frontend/src/modules/S3_Datasets/components/index.js
@@ -6,3 +6,4 @@ export * from './DatasetOverview';
 export * from './DatasetStartCrawlerModal';
 export * from './DatasetTables';
 export * from './DatasetUpload';
+export * from './MetadataMainModal';
diff --git a/frontend/src/modules/S3_Datasets/services/batchUpdateTableColumnDescriptions.js b/frontend/src/modules/S3_Datasets/services/batchUpdateTableColumnDescriptions.js
new file mode 100644
index 000000000..e09d883c9
--- /dev/null
+++ b/frontend/src/modules/S3_Datasets/services/batchUpdateTableColumnDescriptions.js
@@ -0,0 +1,14 @@
+import { gql } from 'apollo-boost';
+
+export const BatchUpdateDatasetTableColumn = (columns) => ({
+  variables: {
+    columns
+  },
+  mutation: gql`
+    mutation BatchUpdateDatasetTableColumn(
+      $columns: [SubitemDescriptionInput]
+    ) {
+      batchUpdateDatasetTableColumn(columns: $columns)
+    }
+  `
+});
diff --git a/frontend/src/modules/S3_Datasets/services/generateMetadataBedrock.js b/frontend/src/modules/S3_Datasets/services/generateMetadataBedrock.js
new file mode 100644
index 000000000..15176ce08
--- /dev/null
+++ b/frontend/src/modules/S3_Datasets/services/generateMetadataBedrock.js
@@ -0,0 +1,46 @@
+import { gql } from 'apollo-boost';
+
+export const generateMetadataBedrock = ({
+  resourceUri,
+  targetType,
+  metadataTypes,
+  version,
+  sampleData
+}) => ({
+  variables: {
+    resourceUri,
+    targetType,
+    metadataTypes,
+    version,
+    sampleData
+  },
+  mutation: gql`
+    mutation generateMetadata(
+      $resourceUri: String!
+      $targetType: MetadataGenerationTargets!
+      $metadataTypes: [String]!
+      $version: Int
+      $sampleData: SampleDataInput
+    ) {
+      generateMetadata(
+        resourceUri: $resourceUri
+        targetType: $targetType
+        metadataTypes: $metadataTypes
+        version: $version
+        sampleData: $sampleData
+      ) {
+        type
+        label
+        description
+        tags
+        topics
+        name
+        subitem_descriptions {
+          description
+          label
+          subitem_id
+        }
+      }
+    }
+  `
+});
diff --git a/frontend/src/modules/S3_Datasets/services/index.js b/frontend/src/modules/S3_Datasets/services/index.js
index 53a75912e..a32c84061 100644
--- a/frontend/src/modules/S3_Datasets/services/index.js
+++ b/frontend/src/modules/S3_Datasets/services/index.js
@@ -1,9 +1,11 @@
 export * from './createDataset';
 export * from './deleteDataset';
 export * from './generateDatasetAccessToken';
+export * from './generateMetadataBedrock';
 export * from './getDatasetPresignedUrl';
 export * from './importDataset';
 export * from './listDatasetStorageLocations';
+export * from './listDatasetTablesFolders';
 export * from './startGlueCrawler';
 export * from './syncTables';
 export * from './updateDataset';
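These service wrappers follow the existing pattern in this module: each returns an object handed straight to `client.query(...)` or `client.mutate(...)`, the same way `DatasetView` calls `client.query(countUpVotes(...))` below. A sketch of the regenerate-with-sample-data flow that `SampleDataPopup` drives; `regenerateTableMetadata` is an invented name for illustration, not part of this diff:

```js
import { generateMetadataBedrock, listSampleData } from '../services';

// Hypothetical flow: pull a data preview for the table, then re-run
// generation with the sample attached so the model sees real values.
async function regenerateTableMetadata(client, tableUri, metadataTypes, version) {
  const preview = await client.query(listSampleData({ tableUri }));
  const { fields, rows } = preview.data.listSampleData;

  const response = await client.mutate(
    generateMetadataBedrock({
      resourceUri: tableUri,
      targetType: 'Table',
      metadataTypes,
      version,
      sampleData: { fields, rows }
    })
  );
  return response.data.generateMetadata;
}
```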
diff --git a/frontend/src/modules/S3_Datasets/services/listDatasetTablesFolders.js b/frontend/src/modules/S3_Datasets/services/listDatasetTablesFolders.js
new file mode 100644
index 000000000..4acc23017
--- /dev/null
+++ b/frontend/src/modules/S3_Datasets/services/listDatasetTablesFolders.js
@@ -0,0 +1,27 @@
+import { gql } from 'apollo-boost';
+
+export const listDatasetTablesFolders = ({ datasetUri, filter }) => ({
+  variables: {
+    datasetUri,
+    filter
+  },
+  query: gql`
+    query listDatasetTablesFolders(
+      $datasetUri: String!
+      $filter: DatasetFilter
+    ) {
+      listDatasetTablesFolders(datasetUri: $datasetUri, filter: $filter) {
+        count
+        page
+        pages
+        hasNext
+        hasPrevious
+        nodes {
+          name
+          targetType
+          targetUri
+        }
+      }
+    }
+  `
+});
diff --git a/frontend/src/modules/S3_Datasets/services/listSampleData.js b/frontend/src/modules/S3_Datasets/services/listSampleData.js
new file mode 100644
index 000000000..b9863fa2a
--- /dev/null
+++ b/frontend/src/modules/S3_Datasets/services/listSampleData.js
@@ -0,0 +1,15 @@
+import { gql } from 'apollo-boost';
+
+export const listSampleData = ({ tableUri }) => ({
+  variables: {
+    tableUri
+  },
+  query: gql`
+    query listSampleData($tableUri: String!) {
+      listSampleData(tableUri: $tableUri) {
+        fields
+        rows
+      }
+    }
+  `
+});
diff --git a/frontend/src/modules/S3_Datasets/views/DatasetView.js b/frontend/src/modules/S3_Datasets/views/DatasetView.js
index 6f1800a64..c9ea04eda 100644
--- a/frontend/src/modules/S3_Datasets/views/DatasetView.js
+++ b/frontend/src/modules/S3_Datasets/views/DatasetView.js
@@ -7,6 +7,7 @@ import {
   Upload,
   ViewArrayOutlined
 } from '@mui/icons-material';
+import AutoModeIcon from '@mui/icons-material/AutoMode';
 import {
   Box,
   Breadcrumbs,
@@ -42,7 +43,8 @@ import {
   DatasetAWSActions,
   DatasetData,
   DatasetOverview,
-  DatasetUpload
+  DatasetUpload,
+  MetadataMainModal
 } from '../components';
 import { isFeatureEnabled, isModuleEnabled, ModuleNames } from 'utils';
 import { RequestAccessModal } from 'modules/Catalog/components';
@@ -127,6 +129,7 @@ const DatasetView = () => {
   const [isRequestAccessOpen, setIsRequestAccessOpen] = useState(false);
   const [isOpeningModal, setIsOpeningModal] = useState(false);
+  const [isMetadataModalOpen, setIsMetadataModalOpen] = useState(false);
   const handleRequestAccessModalOpen = () => {
     setIsOpeningModal(true);
     setIsRequestAccessOpen(true);
@@ -135,6 +138,15 @@
   const handleRequestAccessModalClose = () => {
     setIsRequestAccessOpen(false);
   };
+
+  const handleMetadataModalOpen = () => {
+    setIsMetadataModalOpen(true);
+  };
+
+  const handleMetadataModalClose = () => {
+    setIsMetadataModalOpen(false);
+  };
+
   const reloadVotes = async () => {
     const response = await client.query(countUpVotes(params.uri, 'dataset'));
     if (!response.errors && response.data.countUpVotes !== null) {
@@ -266,6 +278,28 @@
+              {isFeatureEnabled('s3_datasets', 'generate_metadata_ai') && (
+
+              )}
+
+              {isFeatureEnabled('s3_datasets', 'generate_metadata_ai') && (
+
+              )}
+
               {isAdmin && (