diff --git a/stix_shifter_modules/aws_athena/stix_transmission/post_query_connector_error_handling.py b/stix_shifter_modules/aws_athena/stix_transmission/post_query_connector_error_handling.py
new file mode 100644
index 000000000..d8d4b846b
--- /dev/null
+++ b/stix_shifter_modules/aws_athena/stix_transmission/post_query_connector_error_handling.py
@@ -0,0 +1,112 @@
+
+import logging
+import time
+
+import regex as re
+
+from stix_shifter_modules.aws_athena.stix_transmission import status_connector
+
+logger = logging.getLogger(__name__)
+
+
+class PostQueryConnectorErrorHandler():
+    @staticmethod
+    async def check_status_for_missing_column(client, search_id, query):
+        """Polls the query status to see whether the query failed with a "column cannot be resolved"
+        exception. If it did, return the query with the offending column removed; otherwise return True.
+
+        Args:
+            client (RestApiClientAsync)
+            search_id (String): Each query sent to Athena creates a job that runs until it finishes.
+                This ID is used to find the job and access its results.
+            query (String): The query that will be modified if a missing-column error occurs.
+
+        Returns:
+            String or bool: Either a modified query, or True to signal that no missing-column error
+            occurred and the caller can stop retrying.
+        """
+        status = status_connector.StatusConnector(client)
+        column_to_delete = ""
+
+        # Poll for up to ten seconds, one second at a time, while the status is RUNNING.
+        # Exits early once the status changes or a missing-column error is reported.
+        for i in range(10):
+            time.sleep(1)
+            status_response = await status.create_status_connection(search_id)
+            # Check whether there is a readable message and whether it matches the column-not-found error.
+            if status_response is not None and "message" in status_response:
+                match = re.search("Column '(.*)' cannot be resolved", status_response["message"])
+                if match:
+                    # If there is a match, capture the offending column name.
+                    column_to_delete = match.group(1)
+                    break
+            elif status_response is not None and "status" in status_response and status_response["status"] != "RUNNING":
+                # No match and the job is no longer running, so stop polling and report success.
+                break
+
+        if column_to_delete != "":
+            return PostQueryConnectorErrorHandler._remove_invalid_column_table(column_to_delete, query)
+        else:
+            # The query may still have failed for another reason; this only means no
+            # missing-column error occurred.
+            return True
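+
+    # Illustrative contract (the search_id and query below are hypothetical; comments only):
+    #   await check_status_for_missing_column(client, "abc123:ocsf",
+    #       "SELECT ... WHERE lower(dst_endpoint.ip) = lower('192.168.0.52')")
+    #   -> True once the job leaves RUNNING without a missing-column message
+    #   -> "SELECT ... WHERE FALSE ..." when the status message reports
+    #      "Column 'dst_endpoint.ip' cannot be resolved"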
+ """ + #These are the possible forms for a left expression in a comparison + COLUMN_NAME_PATTERN="([\\w\\d]*(?:\\.[\\w\\d]+)*)" + VARCHAR_CAST_LEFT_EXPRESSION = f"CAST\\({COLUMN_NAME_PATTERN} as varchar\\)" + REAL_CAST_LEFT_EXPRESSION = f"CAST\\({COLUMN_NAME_PATTERN} as real\\)" + LOWER_LEFT_EXPRESSION = f"lower\\({COLUMN_NAME_PATTERN}\\)" + #These are the possible forms for a right expression in a comparison + LOWER_RIGHT_EXPRESSION="lower\\(.*?\\)" + BRACKET_RIGHT_EXPRESSION="\\(.*?\\)" + QUOTE_RIGHT_EXPRESSION="\\\'.*?\\\'" + + #These are the possible forms for an operator + OPERATORS=">|>=|<|<=|!=|LIKE|IN|=" + + #The general format for the pattern is {left expression} {operator} {right expression}. In order to get the column name, it needs to match on the left expression. + standard_pattern = f"((?:{VARCHAR_CAST_LEFT_EXPRESSION}|{REAL_CAST_LEFT_EXPRESSION}|{LOWER_LEFT_EXPRESSION}|{COLUMN_NAME_PATTERN}) ({OPERATORS}) (?:{LOWER_RIGHT_EXPRESSION}|{BRACKET_RIGHT_EXPRESSION}|{QUOTE_RIGHT_EXPRESSION}))" + #Match Expressions are unique. They act as a function, for example regexp(string, pattern). Standard pattern is like ID = 5 format. + match_pattern = f"((REGEXP_LIKE)\\((?:{VARCHAR_CAST_LEFT_EXPRESSION}|{REAL_CAST_LEFT_EXPRESSION}|{LOWER_LEFT_EXPRESSION}|{COLUMN_NAME_PATTERN}|{QUOTE_RIGHT_EXPRESSION}), '.*?'\\))" + + logger.debug(f"The failing column name : {column_to_remove}") + logger.debug(f"Current query : {query}") + logger.debug(f"Attempt to match the standard comparison pattern : {standard_pattern}") + + #Matches against the standard pattern, this gets all of the comparison expressions in the query except for TRUE/FALSE. + #It checks each comparison against the offending column and replaces any that match with TRUE or FALSE + all_comparison_strings = re.findall(standard_pattern, query, flags=re.IGNORECASE) + if (len(all_comparison_strings) > 0): + for comparison in all_comparison_strings: + filtered_comparison_list = [item for item in comparison if item != ""] + if(column_to_remove in filtered_comparison_list[1]): + logger.debug(f"The following comparison expression will be replaced (standard): {filtered_comparison_list[0]}" ) + + #If the column doesn't exist and the comparison is a != it will always be true. + #In the case of <,>,>=,<=, the number doesn't exist, thus it is false. + #In the case of =, it will never = the value, thus it must be false. + #In the case of IN or LIKE, the value will always resolve to FALSE because something, can't be in nothing. + #Match will always resolve to false. This one may be true if the match is impossible (that however is a weird edge case), thus false. + if("!=" in filtered_comparison_list[2]): + query = query.replace(comparison[0], f"TRUE") + else: + query = query.replace(comparison[0], f"FALSE") + + #Matches against the match pattern, this gets all of the comparison expressions in the query except for TRUE/FALSE. + #It checks each comparison against the offending column and replaces any that match with TRUE or FALSE + all_match_strings = re.findall(match_pattern, query, flags=re.IGNORECASE) + if (len(all_match_strings) > 0): + for comparison in all_match_strings: + filtered_comparison_list = [item for item in comparison if item != ""] + if(column_to_remove in filtered_comparison_list[2]): + logger.debug(f"The following comparison expression will be replaced (match) : {filtered_comparison_list[0]}" ) + + #If the column doesn't exist and the comparison is a != it will always be true. 
+        # Matches against the match pattern; this finds the REGEXP_LIKE expressions in the query.
+        # Each one is checked against the offending column and replaced with FALSE if it matches.
+        all_match_strings = re.findall(match_pattern, query, flags=re.IGNORECASE)
+        for comparison in all_match_strings:
+            filtered_comparison_list = [item for item in comparison if item != ""]
+            if column_to_remove in filtered_comparison_list[2]:
+                logger.debug(f"The following comparison expression will be replaced (match): {filtered_comparison_list[0]}")
+
+                # A REGEXP_LIKE on a missing column has no value for the pattern to match,
+                # so it always resolves to FALSE.
+                query = query.replace(comparison[0], "FALSE")
+
+        return query
\ No newline at end of file
diff --git a/stix_shifter_modules/aws_athena/stix_transmission/query_connector.py b/stix_shifter_modules/aws_athena/stix_transmission/query_connector.py
index e40dbb886..18364b501 100644
--- a/stix_shifter_modules/aws_athena/stix_transmission/query_connector.py
+++ b/stix_shifter_modules/aws_athena/stix_transmission/query_connector.py
@@ -1,3 +1,6 @@
+import logging
+
+from stix_shifter_modules.aws_athena.stix_transmission.post_query_connector_error_handling import PostQueryConnectorErrorHandler
 from stix_shifter_utils.modules.base.stix_transmission.base_connector import BaseQueryConnector
 from stix_shifter_utils.utils.error_response import ErrorResponder
 import json
+
+logger = logging.getLogger(__name__)
@@ -25,6 +28,7 @@ def __init__(self, client, connection):
         self.client = client
         self.connection = connection
         self.connector = __name__.split('.')[1]
+        self.total_try_count = 0

     async def create_query_connection(self, query):
         """
@@ -47,6 +51,7 @@ async def create_query_connection(self, query):
                     raise InvalidParameterException("{} is required for {} query operation".format(config, query_service_type))

             table_config = self.connection[config_details[0]] + '."' + self.connection[config_details[1]] + '"'
+
             other_tables = ''
             findall = re.finditer("##UNNEST.*?##", query[query_service_type])
             if findall:
@@ -57,7 +62,7 @@ async def create_query_connection(self, query):
                 other_tables += ' %s%s%s ' % ('LEFT JOIN ', match_str.replace('##', ''), ' ON TRUE ')

             if query_service_type == 'ocsf':
-                columns = await self.column_list(self.connection[config_details[1]])
+                columns = await self.column_list(self.connection[config_details[0]], self.connection[config_details[1]])
                 column_cast = []
                 for column in columns:
                     column_cast.append("CAST(%s as JSON) AS %s" % (column, column))
@@ -65,28 +70,64 @@ async def create_query_connection(self, query):
                 select_statement = "SELECT %s FROM %s%s WHERE " % (", ".join(column_cast), table_config, other_tables)
             else:
                 select_statement = "SELECT %s.* FROM %s%s WHERE " % (table_config, table_config, other_tables)
+
             # for multiple observation operators union and intersect, select statement will be added
             if 'UNION' in query[query_service_type] or 'INTERSECT' in query[query_service_type]:
                 query_string = re.sub(r'\(\(', '(({}'.format(select_statement), query[query_service_type], 1)
-                query = query_string.replace('UNION (', 'UNION ({}'.format(select_statement)).\
+                query_with_select = query_string.replace('UNION (', 'UNION ({}'.format(select_statement)).\
                     replace('INTERSECT (', 'INTERSECT ({}'.format(select_statement))
             else:
-                query = select_statement + query[query_service_type]
+                query_with_select = select_statement + query[query_service_type]

             result_config = self.get_result_config()
-            query_args = {"QueryString": query, "ResultConfiguration": result_config}
-            response_dict = await self.client.makeRequest('athena', 'start_query_execution', **query_args)
-            return_obj['success'] = True
-            return_obj['search_id'] = response_dict['QueryExecutionId'] + ":" + query_service_type
+
+            return await self.query_api(query_with_select, result_config, query_service_type, return_obj)

         except Exception as ex:
             response_dict['__type'] = ex.__class__.__name__
             response_dict['message'] = ex
             ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector)
         return return_obj
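+
+    # Illustrative expansion (table and column names taken from the unit-test fixtures): for an
+    # 'ocsf' query against cf_moose_db."cloudtrail", the generated select statement has the form
+    #   SELECT CAST(metadata as JSON) AS metadata, ..., CAST(unmapped as JSON) AS unmapped
+    #   FROM cf_moose_db."cloudtrail" WHERE <translated pattern>
+    # and is handed to query_api together with the result configuration.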
-    async def column_list(self, table):
+    async def query_api(self, query, result_config, query_service_type, return_obj):
+        """Creates a query job and ensures that none of the requested columns are missing from the table.
+
+        Args:
+            query (String): The original query, without modification.
+            result_config (Dict): The query result configuration.
+            query_service_type (String): The type of query ('ocsf', 'vpcflow', etc.).
+            return_obj (Dict): Contains the metadata about the request.
+
+        Returns:
+            Dict: The metadata about the request, such as the query/search ID and the status.
+        """
+        logger.debug(f"The current query is: {query}")
+
+        # Creates the initial query job.
+        query_args = {"QueryString": query, "ResultConfiguration": result_config}
+        response_dict = await self.client.makeRequest('athena', 'start_query_execution', **query_args)
+        return_obj['search_id'] = response_dict['QueryExecutionId'] + ":" + query_service_type
+
+        modified_query = await PostQueryConnectorErrorHandler.check_status_for_missing_column(self.client, return_obj['search_id'], query)
+        # If the query succeeds (or the exception isn't column related), it is treated as a success and exits.
+        # If more than 10 missing columns are found, or a replacement fails to remove the offending column
+        # 10 times, it also exits, to prevent an endless loop. Otherwise, the modified query is retried.
+        if modified_query is True or self.total_try_count > 10:
+            logger.debug(f"The number of attempts to remove missing columns was {self.total_try_count}")
+            if self.total_try_count >= 10:
+                logger.warning("There were 10 failed exceptions related to columns. Either there were "
+                               "more than 10 invalid columns, or a replacement failed to remove the "
+                               "offending column.")
+            return_obj['success'] = True
+            return return_obj
+        else:
+            self.total_try_count = self.total_try_count + 1
+            return await self.query_api(modified_query, result_config, query_service_type, return_obj)
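+
+    # Illustrative retry trace (hypothetical column names, comments only): a query referencing
+    # two missing columns is submitted; the status poll reports "Column 'a' cannot be resolved",
+    # the comparison on 'a' is rewritten to TRUE/FALSE and the query is resubmitted, after which
+    # the same happens for 'b'. total_try_count caps this recursion, so a rewrite that never
+    # converges still returns success after 10 attempts instead of looping forever.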
+ if(data.startswith("row(".casefold()) or data.startwith("array(".casefold())): + remainder = data[data.find("("):data.rfind(")")] + self.get_rows_from_response(remainder, parent, row_list) + def get_result_config(self): """ Output location and encryption configuration are added @@ -152,3 +200,4 @@ def get_result_config(self): output_location = 's3://' + path result_config = {'OutputLocation': output_location} return result_config + \ No newline at end of file diff --git a/stix_shifter_modules/aws_athena/stix_transmission/status_connector.py b/stix_shifter_modules/aws_athena/stix_transmission/status_connector.py index b6f8e14ad..ef8e08070 100644 --- a/stix_shifter_modules/aws_athena/stix_transmission/status_connector.py +++ b/stix_shifter_modules/aws_athena/stix_transmission/status_connector.py @@ -56,6 +56,8 @@ async def create_status_connection(self, search_id): return_obj['status'] = self._getstatus(response_dict.get('QueryExecution', 'FAILED'). get('Status', 'FAILED'). get('State', 'FAILED')) + if (response_dict != None and "QueryExecution" in response_dict and "Status" in response_dict["QueryExecution"] and "StateChangeReason" in response_dict["QueryExecution"]["Status"]): + return_obj['message'] = response_dict["QueryExecution"]["Status"]["StateChangeReason"] if return_obj['status'] == 'COMPLETED': return_obj['progress'] = 100 elif return_obj['status'] == 'RUNNING': diff --git a/stix_shifter_modules/aws_athena/tests/stix_transmission/test_aws_athena.py b/stix_shifter_modules/aws_athena/tests/stix_transmission/test_aws_athena.py index cf88f0963..e9067d3fe 100644 --- a/stix_shifter_modules/aws_athena/tests/stix_transmission/test_aws_athena.py +++ b/stix_shifter_modules/aws_athena/tests/stix_transmission/test_aws_athena.py @@ -223,6 +223,47 @@ def get_query_execution(**kwargs): } } return json_response + + @staticmethod + def get_query_execution_running_status(**kwargs): + json_response = { + 'QueryExecution': { + 'QueryExecutionId': '3fdb8f84-6ad6-4f7c-8e9e-7bf3db87c274', + 'Query': 'SELECT * FROM logs_db.vpc_flow_logs limit 1', + 'StatementType': 'DML', + 'ResultConfiguration': { + 'OutputLocation': 's3://queryresults-athena-s3/3fdb8f84-6ad6-4f7c-8e9e-7bf3db87c274.csv' + }, + 'QueryExecutionContext': {}, + 'Status': { + 'State': 'RUNNING', + 'SubmissionDateTime': datetime.datetime(2020, 9, 30, 13, 31, 28, 856000, tzinfo=tzlocal()), + 'CompletionDateTime': datetime.datetime(2020, 9, 30, 13, 31, 31, 313000, tzinfo=tzlocal()) + }, + 'Statistics': { + 'EngineExecutionTimeInMillis': 2280, + 'DataScannedInBytes': 1493800, + 'TotalExecutionTimeInMillis': 2457, + 'QueryQueueTimeInMillis': 117, + 'QueryPlanningTimeInMillis': 1819, + 'ServiceProcessingTimeInMillis': 60 + }, + 'WorkGroup': 'primary' + }, + 'ResponseMetadata': { + 'RequestId': '870cfb1e-734f-40b2-bab0-cc44affa21d4', + 'HTTPStatusCode': 200, + 'HTTPHeaders': { + 'content-type': 'application/x-amz-json-1.1', + 'date': 'Wed, 30 Sep 2020 08:10:16 GMT', + 'x-amzn-requestid': '870cfb1e-734f-40b2-bab0-cc44affa21d4', + 'content-length': '1275', + 'connection': 'keep-alive' + }, + 'RetryAttempts': 0 + } + } + return json_response class MockStatusResponseRunning: @@ -393,7 +434,6 @@ def stop_query_execution(**kwargs): response = {'Error': {'Code': 'authentication_fail', 'Message': 'Unable to access the data'}} return ClientError(response, 'test4') - class TestAWSConnection(unittest.TestCase): @staticmethod def test_is_async(): @@ -402,10 +442,27 @@ def test_is_async(): assert check_async @staticmethod + 
+    @patch('stix_shifter_modules.aws_athena.stix_transmission.post_query_connector_error_handling.PostQueryConnectorErrorHandler.check_status_for_missing_column')
     @patch('stix_shifter_modules.aws_athena.stix_transmission.boto3_client.BOTO3Client.makeRequest')
-    def test_create_query_connection(mock_start_query):
+    def test_create_query_connection(mock_start_query, mock_status_query):
         mock_start_query.return_value = get_aws_mock_response(AWSMockJsonResponse.start_query_execution(**{}))
+        # check_status_for_missing_column returns True when no missing-column error occurred.
+        mock_status_query.return_value = True
         query = """{"vpcflow": "endtime >= 1588310653 AND starttime BETWEEN 1588322590 AND 1604054590"}"""
         transmission = stix_transmission.StixTransmission('aws_athena', CONNECTION, CONFIGURATION)
         query_response = transmission.query(query)

         assert query_response is not None
         assert 'success' in query_response
         assert query_response['success'] is True
         assert 'search_id' in query_response
         assert query_response['search_id'] == "4214e100-9990-4161-9038-b431ec45661a:vpcflow"

+    @staticmethod
+    @patch('stix_shifter_modules.aws_athena.stix_transmission.post_query_connector_error_handling.PostQueryConnectorErrorHandler.check_status_for_missing_column')
+    @patch('stix_shifter_modules.aws_athena.stix_transmission.boto3_client.BOTO3Client.makeRequest')
+    def test_create_query_connection_missing_column(mock_start_query, mock_status_query):
+        mock_start_query.return_value = get_aws_mock_response(AWSMockJsonResponse.start_query_execution(**{}))
+        # A string stands in for a modified query; query_api keeps retrying until its cap is reached.
+        mock_status_query.return_value = "KEEP_TRYING_QUERY"
+        query = """{"vpcflow": "endtime >= 1588310653 AND starttime BETWEEN 1588322590 AND 1604054590"}"""
+        transmission = stix_transmission.StixTransmission('aws_athena', CONNECTION, CONFIGURATION)
+        query_response = transmission.query(query)
+
+        assert query_response is not None
+        assert 'success' in query_response
+        assert query_response['success'] is True
+        assert 'search_id' in query_response
+        assert query_response['search_id'] == "4214e100-9990-4161-9038-b431ec45661a:vpcflow"
+
     @staticmethod
     @patch('stix_shifter_modules.aws_athena.stix_transmission.boto3_client.BOTO3Client.makeRequest')
     def test_create_query_exception(mock_start_query):
@@ -431,9 +489,11 @@ def test_create_query_exception(mock_start_query):
         assert query_response['code'] == "authentication_fail"

     @staticmethod
+    @patch('stix_shifter_modules.aws_athena.stix_transmission.post_query_connector_error_handling.PostQueryConnectorErrorHandler.check_status_for_missing_column')
     @patch('stix_shifter_modules.aws_athena.stix_transmission.boto3_client.BOTO3Client.makeRequest')
-    def test_iam_create_query_connection(mock_start_query):
+    def test_iam_create_query_connection(mock_start_query, mock_status_query):
         mock_start_query.return_value = get_aws_mock_response(AWSMockJsonResponse.start_query_execution(**{}))
+        mock_status_query.return_value = True
         query = """{ "vpcflow": "(CAST(destinationport AS varchar) IN ('38422', '38420') AND starttime BETWEEN 1603975773 AND \
             1603976073 LIMIT 100) UNION (CAST(destinationport AS varchar) = '32791' AND starttime BETWEEN 1603975773 \
diff --git a/stix_shifter_modules/aws_athena/tests/stix_transmission/test_aws_athena_post_query_error_handler.py b/stix_shifter_modules/aws_athena/tests/stix_transmission/test_aws_athena_post_query_error_handler.py
new file mode 100644
index 000000000..bd9e41dfb
--- /dev/null
+++ b/stix_shifter_modules/aws_athena/tests/stix_transmission/test_aws_athena_post_query_error_handler.py
@@ -0,0 +1,128 @@
+
+import unittest
+from unittest.mock import patch
+
+from stix_shifter_modules.aws_athena.stix_transmission.boto3_client import BOTO3Client
+from stix_shifter_modules.aws_athena.stix_transmission.post_query_connector_error_handling import PostQueryConnectorErrorHandler
+from stix_shifter_modules.aws_athena.tests.stix_transmission.test_aws_athena import AWSMockJsonResponse
+from tests.utils.async_utils import get_aws_mock_response
+
+CONFIGURATION = {
+    "auth": {
+        "aws_access_key_id": "abc",
+        "aws_secret_access_key": "xyz"
+    }
+}
+
+IAM_CONFIG = {
+    "auth": {
+        "aws_access_key_id": "abc",
+        "aws_secret_access_key": "xyz",
+        "aws_iam_role": "ABC"
+    }
+}
+
+CONNECTION = {
+    "region": "us-east-1",
+    "s3_bucket_location": "s3://queryresults-athena-s3/",
+    "vpcflow_database_name": "all",
+    "vpcflow_table_name": "gd_logs",
+    "guardduty_database_name": "gd_logs",
+    "guardduty_table_name": "gd_logs"
+}
+
+ip_address_column_not_found = {'success': True, 'status': 'ERROR', 'message': "COLUMN_NOT_FOUND: line 1:877: Column 'dst_endpoint.ip' cannot be resolved or requester is not authorized to access requested resources", 'progress': 0}
+
+# All test queries share the same OCSF SELECT list; build it once instead of repeating it in every constant.
+_OCSF_COLUMNS = ["metadata", "time", "cloud", "api", "ref_event_uid", "src_endpoint", "resources",
+                 "identity", "http_request", "class_name", "class_uid", "category_name", "category_uid",
+                 "severity_id", "severity", "activity_name", "activity_id", "type_uid", "type_name", "unmapped"]
+_SELECT = "SELECT " + ", ".join(f"CAST({c} as JSON) AS {c}" for c in _OCSF_COLUMNS) + ' FROM cf_moose_db."cloudtrail" WHERE '
+_TIME = " AND time BETWEEN 1641026590000 AND 1698748990000)"
+
+# Queries that reference the missing dst_endpoint.ip column, one per supported comparison operator.
+equal_query = _SELECT + "((lower(dst_endpoint.ip) = lower('192.168.0.52') OR lower(src_endpoint.ip) = lower('192.168.0.52'))" + _TIME
+doesnt_equal_query = _SELECT + "((lower(dst_endpoint.ip) != lower('192.168.0.52') OR lower(src_endpoint.ip) != lower('192.168.0.52'))" + _TIME
+greater_than_query = _SELECT + "((lower(dst_endpoint.ip) > lower('192.168.0.52') OR lower(src_endpoint.ip) > lower('192.168.0.52'))" + _TIME
+in_query = _SELECT + "((dst_endpoint.ip IN ('172.31.92.201', '172.31.30.227') OR src_endpoint.ip IN ('172.31.92.201', '172.31.30.227'))" + _TIME
+like_query = _SELECT + "((lower(dst_endpoint.ip) LIKE lower('192.168.0.52') OR lower(src_endpoint.ip) LIKE lower('192.168.0.52'))" + _TIME
+match_query = _SELECT + "((REGEXP_LIKE(CAST(dst_endpoint.ip as varchar), '192.168.0.52') OR REGEXP_LIKE(CAST(src_endpoint.ip as varchar), '192.168.0.52'))" + _TIME
+
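+# Expected rewrites: _remove_invalid_column_table should collapse each dst_endpoint.ip
+# comparison to FALSE (TRUE for !=) while leaving the src_endpoint.ip side of the OR intact.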
+equal_query_replaced_column = _SELECT + "((FALSE OR lower(src_endpoint.ip) = lower('192.168.0.52'))" + _TIME
+doesnt_equal_query_replaced_column = _SELECT + "((TRUE OR lower(src_endpoint.ip) != lower('192.168.0.52'))" + _TIME
+greater_than_query_replaced_column = _SELECT + "((FALSE OR lower(src_endpoint.ip) > lower('192.168.0.52'))" + _TIME
+in_query_replaced_column = _SELECT + "((FALSE OR src_endpoint.ip IN ('172.31.92.201', '172.31.30.227'))" + _TIME
+like_query_replaced_column = _SELECT + "((FALSE OR lower(src_endpoint.ip) LIKE lower('192.168.0.52'))" + _TIME
+match_query_replaced_column = _SELECT + "((FALSE OR REGEXP_LIKE(CAST(src_endpoint.ip as varchar), '192.168.0.52'))" + _TIME
+
+
+class TestAWSConnection(unittest.IsolatedAsyncioTestCase):
+    # ---- Tests against the check_status_for_missing_column method
+    @staticmethod
+    @patch('stix_shifter_modules.aws_athena.stix_transmission.boto3_client.BOTO3Client.makeRequest')
+    async def test_no_column_missing(mock_make_request):
+        mock_make_request.return_value = get_aws_mock_response(AWSMockJsonResponse.get_query_execution(**{}))
+        client = BOTO3Client(CONNECTION, CONFIGURATION)
+        test_query = "test"
+        query_response = await PostQueryConnectorErrorHandler.check_status_for_missing_column(client, "1", test_query)
+
+        assert query_response is True
+
+    @staticmethod
+    @patch('time.sleep')
+    @patch('stix_shifter_modules.aws_athena.stix_transmission.boto3_client.BOTO3Client.makeRequest')
+    async def test_stuck_waiting(mock_make_request, mock_time):
+        mock_make_request.return_value = get_aws_mock_response(AWSMockJsonResponse.get_query_execution_running_status(**{}))
+        mock_time.return_value = "do_nothing"
+        client = BOTO3Client(CONNECTION, CONFIGURATION)
+        test_query = "test"
+        query_response = await PostQueryConnectorErrorHandler.check_status_for_missing_column(client, "1", test_query)
+
+        assert query_response is True
+
+    # ---- Tests against the _remove_invalid_column_table method
+    @staticmethod
+    @patch('stix_shifter_modules.aws_athena.stix_transmission.status_connector.StatusConnector.create_status_connection')
+    async def test_column_missing_equals(mock_status_query):
+        mock_status_query.return_value = ip_address_column_not_found
+        client = BOTO3Client(CONNECTION, CONFIGURATION)
+        query_response = await PostQueryConnectorErrorHandler.check_status_for_missing_column(client, "1", equal_query)
+
+        assert query_response == equal_query_replaced_column
+
+    @staticmethod
+    @patch('stix_shifter_modules.aws_athena.stix_transmission.status_connector.StatusConnector.create_status_connection')
+    async def test_column_missing_doesnt_equal(mock_status_query):
+        mock_status_query.return_value = ip_address_column_not_found
+        client = BOTO3Client(CONNECTION, CONFIGURATION)
+        query_response = await PostQueryConnectorErrorHandler.check_status_for_missing_column(client, "1", doesnt_equal_query)
+
+        assert query_response == doesnt_equal_query_replaced_column
+
+    @staticmethod
+    @patch('stix_shifter_modules.aws_athena.stix_transmission.status_connector.StatusConnector.create_status_connection')
+    async def test_column_missing_greater_than(mock_status_query):
+        mock_status_query.return_value = ip_address_column_not_found
+        client = BOTO3Client(CONNECTION, CONFIGURATION)
+        query_response = await PostQueryConnectorErrorHandler.check_status_for_missing_column(client, "1", greater_than_query)
+
+        assert query_response == greater_than_query_replaced_column
+
+    @staticmethod
+    @patch('stix_shifter_modules.aws_athena.stix_transmission.status_connector.StatusConnector.create_status_connection')
+    async def test_column_missing_in_query(mock_status_query):
+        mock_status_query.return_value = ip_address_column_not_found
+        client = BOTO3Client(CONNECTION, CONFIGURATION)
+        query_response = await PostQueryConnectorErrorHandler.check_status_for_missing_column(client, "1", in_query)
+
+        assert query_response == in_query_replaced_column
+
+    @staticmethod
+    @patch('stix_shifter_modules.aws_athena.stix_transmission.status_connector.StatusConnector.create_status_connection')
+    async def test_column_missing_like_query(mock_status_query):
+        mock_status_query.return_value = ip_address_column_not_found
+        client = BOTO3Client(CONNECTION, CONFIGURATION)
+        query_response = await PostQueryConnectorErrorHandler.check_status_for_missing_column(client, "1", like_query)
+
+        assert query_response == like_query_replaced_column
+
+    @staticmethod
+    @patch('stix_shifter_modules.aws_athena.stix_transmission.status_connector.StatusConnector.create_status_connection')
+    async def test_column_missing_match_query(mock_status_query):
+        mock_status_query.return_value = ip_address_column_not_found
+        client = BOTO3Client(CONNECTION, CONFIGURATION)
+        query_response = await PostQueryConnectorErrorHandler.check_status_for_missing_column(client, "1", match_query)
+
+        assert query_response == match_query_replaced_column
\ No newline at end of file