diff --git a/data/saas/config/saas_example_config.yml b/data/saas/config/saas_example_config.yml
index 11fe5e83b3..0c1ca003b9 100644
--- a/data/saas/config/saas_example_config.yml
+++ b/data/saas/config/saas_example_config.yml
@@ -105,7 +105,7 @@ saas_config:
           - name: On-Behalf-Of
             value:
           - name: Token
-            value: Custom
+            value: Custom
         query_params:
           - name: limit
             value:
@@ -134,12 +134,16 @@ saas_config:
         param_values:
           - name: placeholder
             identity: email
-    - name: user_feedback
+    - name: users
+      after: [ saas_connector_example.projects ]
       requests:
         read:
           method: GET
           path: /api/0/projects/<organization_slug>/<project_slug>/user-feedback/
-          grouped_inputs: [ organization_slug, project_slug ]
+          query_params:
+            - name: query
+              value: <query>
+          grouped_inputs: [ organization_slug, project_slug, query ]
           param_values:
             - name: organization_slug
               references:
@@ -151,3 +155,5 @@ saas_config:
                 - dataset: saas_connector_example
                   field: projects.slug
                   direction: from
+            - name: query
+              identity: email
\ No newline at end of file
diff --git a/data/saas/dataset/saas_example_dataset.yml b/data/saas/dataset/saas_example_dataset.yml
index e3f51ee6f3..6594beaf82 100644
--- a/data/saas/dataset/saas_example_dataset.yml
+++ b/data/saas/dataset/saas_example_dataset.yml
@@ -162,7 +162,7 @@ dataset:
           - name: slug
             fidesops_meta:
               data_type: string
-      - name: user_feedback
+      - name: users
         fields:
           - name: id
             data_categories: [ system.operations ]
diff --git a/docs/fidesops/docs/guides/saas_config.md b/docs/fidesops/docs/guides/saas_config.md
index cadfa6a653..786c121cbe 100644
--- a/docs/fidesops/docs/guides/saas_config.md
+++ b/docs/fidesops/docs/guides/saas_config.md
@@ -172,6 +172,7 @@ test_request:
 This is where we define how we are going to access and update each collection in the corresponding Dataset. The endpoint section contains the following members:
 
 - `name` This name corresponds to a Collection in the corresponding Dataset.
+- `after` An optional list of collection addresses that this endpoint should run after. For example, `after: [ mailchimp_connector_example.member ]` causes the current endpoint to run only after the `member` endpoint has completed.
 - `requests` A map of `read`, `update`, and `delete` requests for this collection. Each collection can define a way to read and a way to update the data. Each request is made up of:
     - `method` The HTTP method used for the endpoint.
     - `path` A static or dynamic resource path. The dynamic portions of the path are enclosed within angle brackets `<>` and are replaced with values from `param_values`.
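A note for reviewers of the config and docs changes above: each `after` entry is a `dataset.collection` string. A minimal sketch of how such an entry maps onto a collection address — the `CollectionAddress` dataclass below is a simplified stand-in, not the real class from `fidesops.graph.config`:

```python
# Minimal sketch: an `after` entry such as "saas_connector_example.projects"
# splits on "." into a (dataset, collection) pair, mirroring the
# CollectionAddress(*s.split(".")) call in the schema change below.
from dataclasses import dataclass
from typing import List, Set


@dataclass(frozen=True)
class CollectionAddress:  # simplified stand-in for fidesops.graph.config
    dataset: str
    collection: str


def parse_after(after: List[str]) -> Set[CollectionAddress]:
    """Split each 'dataset.collection' string into an address."""
    return {CollectionAddress(*entry.split(".")) for entry in after}


print(parse_after(["saas_connector_example.projects"]))
# {CollectionAddress(dataset='saas_connector_example', collection='projects')}
```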
diff --git a/src/fidesops/schemas/saas/saas_config.py b/src/fidesops/schemas/saas/saas_config.py
index c16b9c1ae9..a448b75ca5 100644
--- a/src/fidesops/schemas/saas/saas_config.py
+++ b/src/fidesops/schemas/saas/saas_config.py
@@ -4,8 +4,15 @@ from fidesops.service.pagination.pagination_strategy_factory import get_strategy
 from pydantic import BaseModel, validator, root_validator, Extra
 
 from fidesops.schemas.base_class import BaseSchema
-from fidesops.schemas.dataset import FidesopsDatasetReference
-from fidesops.graph.config import Collection, Dataset, FieldAddress, ScalarField
+from fidesops.schemas.dataset import FidesopsDatasetReference, FidesCollectionKey
+from fidesops.graph.config import (
+    Collection,
+    Dataset,
+    FieldAddress,
+    ScalarField,
+    CollectionAddress,
+    Field,
+)
 from fidesops.schemas.saas.strategy_configuration import ConnectorParamRef
 from fidesops.schemas.shared_schemas import FidesOpsKey
 
@@ -143,6 +150,7 @@ class Endpoint(BaseModel):
 
     name: str
     requests: Dict[Literal["read", "update", "delete"], SaaSRequest]
+    after: List[FidesCollectionKey] = []
 
 
 class ConnectorParam(BaseModel):
 
@@ -214,6 +222,9 @@ def get_graph(self) -> Dataset:
                     name=endpoint.name,
                     fields=fields,
                     grouped_inputs=grouped_inputs,
+                    after={
+                        CollectionAddress(*s.split(".")) for s in endpoint.after
+                    },
                 )
             )
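Because `after` defaults to an empty list on the `Endpoint` model, existing SaaS configs that omit the key keep validating unchanged. A small sketch of that behavior, with the pydantic model trimmed down and `FidesCollectionKey` replaced by plain strings for brevity:

```python
# Sketch: `after` is optional and defaults to []; pydantic copies the
# default per instance, so models do not share mutable state.
from typing import List

from pydantic import BaseModel


class Endpoint(BaseModel):  # trimmed: `requests` omitted, keys as plain str
    name: str
    after: List[str] = []


print(Endpoint(name="users").after)  # []
print(Endpoint(name="users", after=["saas_connector_example.projects"]).after)
# ['saas_connector_example.projects']
```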
diff --git a/src/fidesops/task/graph_task.py b/src/fidesops/task/graph_task.py
index 0ce356ee4c..0113a972d6 100644
--- a/src/fidesops/task/graph_task.py
+++ b/src/fidesops/task/graph_task.py
@@ -128,6 +128,15 @@ def grouped_fields(self) -> Set[Optional[str]]:
         """
         return self.traversal_node.node.collection.grouped_inputs or set()
 
+    @property
+    def dependent_identity_fields(self) -> bool:
+        """True if any of this collection's grouped input fields is an identity field, meaning the seed data must be combined with inputs from other collections."""
+        collection = self.traversal_node.node.collection
+        for field in self.grouped_fields:
+            if collection.field(FieldPath(field)).identity:
+                return True
+        return False
+
     def build_incoming_field_path_maps(
         self, group_dependent_fields: bool = False
     ) -> Tuple[COLLECTION_FIELD_PATH_MAP, COLLECTION_FIELD_PATH_MAP]:
@@ -168,6 +177,27 @@ def can_write_data(self) -> bool:
         connection_config: ConnectionConfig = self.connector.configuration
         return connection_config.access == AccessLevel.write
 
+    def _combine_seed_data(
+        self,
+        *data: List[Row],
+        grouped_data: Dict[str, Any],
+        dependent_field_mappings: COLLECTION_FIELD_PATH_MAP,
+    ) -> Dict[str, Any]:
+        """Combine the seed data with the other dependent inputs. This is used when the seed data in a collection
+        requires inputs from another collection to generate subsequent queries."""
+        # Get the identity values from the seeds that were passed into this collection.
+        seed_index = self.input_keys.index(ROOT_COLLECTION_ADDRESS)
+        seed_data = data[seed_index]
+
+        for (foreign_field_path, local_field_path) in dependent_field_mappings[
+            ROOT_COLLECTION_ADDRESS
+        ]:
+            dependent_values: List = consolidate_query_matches(
+                row=seed_data, target_path=foreign_field_path
+            )
+            grouped_data[local_field_path.string_path] = dependent_values
+        return grouped_data
+
     def pre_process_input_data(
         self, *data: List[Row], group_dependent_fields: bool = False
     ) -> NodeInput:
@@ -209,6 +239,14 @@ def pre_process_input_data(
         for i, rowset in enumerate(data):
             collection_address = self.input_keys[i]
 
+            if (
+                group_dependent_fields
+                and self.dependent_identity_fields
+                and collection_address == ROOT_COLLECTION_ADDRESS
+            ):
+                # Skip building data for the root collection if the seed data needs to be combined with other inputs
+                continue
+
             logger.info(
                 f"Consolidating incoming data into {self.traversal_node.node.address} from {collection_address}."
             )
@@ -234,6 +272,14 @@ def pre_process_input_data(
                         row=row, target_path=foreign_field_path
                     )
                     grouped_data[local_field_path.string_path] = dependent_values
+
+                if self.dependent_identity_fields:
+                    grouped_data = self._combine_seed_data(
+                        *data,
+                        grouped_data=grouped_data,
+                        dependent_field_mappings=dependent_field_mappings,
+                    )
+
                 output[FIDESOPS_GROUPED_INPUTS].append(grouped_data)
 
         return output
diff --git a/src/fidesops/util/saas_util.py b/src/fidesops/util/saas_util.py
index 7e7e8e1963..5289e93e67 100644
--- a/src/fidesops/util/saas_util.py
+++ b/src/fidesops/util/saas_util.py
@@ -2,8 +2,7 @@ from functools import reduce
 from typing import Any, Dict, List, Optional, Set
 
 from fidesops.common_exceptions import FidesopsException
-from fidesops.graph.config import Collection, Dataset, Field
-
+from fidesops.graph.config import Collection, Dataset, Field, CollectionAddress
 
 FIDESOPS_GROUPED_INPUTS = "fidesops_grouped_inputs"
 
@@ -43,6 +42,18 @@ def get_collection_grouped_inputs(
     return collection.grouped_inputs
 
 
+def get_collection_after(
+    collections: List[Collection], name: str
+) -> Set[CollectionAddress]:
+    """Return the addresses of any collections that must run before the given collection (used for SaaS configs)."""
+    collection: Optional[Collection] = next(
+        (collect for collect in collections if collect.name == name), None
+    )
+    if not collection:
+        return set()
+    return collection.after
+
+
 def merge_datasets(dataset: Dataset, config_dataset: Dataset) -> Dataset:
     """
     Merges all Collections and Fields from the config_dataset into the dataset.
@@ -63,6 +74,7 @@ def merge_datasets(dataset: Dataset, config_dataset: Dataset) -> Dataset:
                 grouped_inputs=get_collection_grouped_inputs(
                     config_dataset.collections, collection_name
                 ),
+                after=get_collection_after(config_dataset.collections, collection_name),
             )
         )
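To make the `_combine_seed_data` path above concrete: when a grouped input is an identity field, the root (seed) rows are folded into each grouped row instead of being emitted as an independent input. Below is a standalone approximation, with the field-path machinery and `consolidate_query_matches` reduced to a plain key lookup:

```python
# Standalone approximation of _combine_seed_data: every grouped row built
# from dependent collections is extended with the identity values
# collected from the seed rows.
from typing import Any, Dict, List


def combine_seed_data(
    seed_rows: List[Dict[str, Any]],
    grouped_data: Dict[str, Any],
    identity_field: str,
) -> Dict[str, Any]:
    # Gather every value of the identity field across the seed rows,
    # analogous to consolidate_query_matches over a field path.
    grouped_data[identity_field] = [row[identity_field] for row in seed_rows]
    return grouped_data


print(
    combine_seed_data(
        seed_rows=[{"email": "email@gmail.com"}],
        grouped_data={"project": ["abcde"], "organization": ["12345"]},
        identity_field="email",
    )
)
# {'project': ['abcde'], 'organization': ['12345'], 'email': ['email@gmail.com']}
```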
diff --git a/tests/models/test_saasconfig.py b/tests/models/test_saasconfig.py
index 81302482b3..e4b04f547b 100644
--- a/tests/models/test_saasconfig.py
+++ b/tests/models/test_saasconfig.py
@@ -2,7 +2,7 @@ import pytest
 from pydantic import ValidationError
 
-from fidesops.graph.config import FieldAddress
+from fidesops.graph.config import CollectionAddress, FieldAddress
 from fidesops.schemas.saas.saas_config import SaaSConfig, SaaSRequest
 
 
@@ -40,19 +40,23 @@ def test_saas_config_to_dataset(saas_configs: Dict[str, Dict]):
     assert query_field.name == "email"
     assert query_field.identity == "email"
 
-    user_feedback_collection = saas_dataset.collections[5]
-    assert user_feedback_collection.grouped_inputs == {
+    user_collection = saas_dataset.collections[5]
+    assert user_collection.after == {
+        CollectionAddress("saas_connector_example", "projects")
+    }
+    assert user_collection.grouped_inputs == {
         "organization_slug",
         "project_slug",
+        "query",
     }
 
-    org_slug_reference, direction = user_feedback_collection.fields[0].references[0]
+    org_slug_reference, direction = user_collection.fields[0].references[0]
     assert org_slug_reference == FieldAddress(
         saas_config.fides_key, "projects", "organization", "slug"
     )
     assert direction == "from"
 
-    project_slug_reference, direction = user_feedback_collection.fields[1].references[0]
+    project_slug_reference, direction = user_collection.fields[1].references[0]
     assert project_slug_reference == FieldAddress(
         saas_config.fides_key, "projects", "slug"
     )
diff --git a/tests/task/test_graph_task.py b/tests/task/test_graph_task.py
index be7688d001..6f5d6b8e72 100644
--- a/tests/task/test_graph_task.py
+++ b/tests/task/test_graph_task.py
@@ -249,15 +249,14 @@ def test_pre_process_input_conversation_collection(
         }
 
     def test_pre_process_input_data_group_dependent_fields(self):
-        """Test processing inputs where fields have been marked as dependent"""
+        """Test processing inputs where several reference fields and an identity field have
+        been marked as dependent.
+        """
         traversal_with_grouped_inputs = traversal_paired_dependency()
         n = traversal_with_grouped_inputs.traversal_node_dict[
             CollectionAddress("mysql", "User")
         ]
-
-        task = MockSqlTask(
-            n, TaskResources(EMPTY_REQUEST, Policy(), connection_configs)
-        )
+        task = MockSqlTask(n, TaskResources(EMPTY_REQUEST, Policy(), connection_configs))
 
         project_output = [
             {
@@ -277,22 +276,36 @@ def test_pre_process_input_data_group_dependent_fields(self):
             },
         ]
 
+        identity_output = [{"email": "email@gmail.com"}]
         # Typical output - project ids and organization ids would be completely independent from each other
-        assert task.pre_process_input_data(project_output) == {
-            "organization": ["12345", "54321", "54321"],
+        assert task.pre_process_input_data(identity_output, project_output) == {
+            "email": ["email@gmail.com"],
             "project": ["abcde", "fghij", "klmno"],
+            "organization": ["12345", "54321", "54321"],
             "fidesops_grouped_inputs": [],
         }
 
         # With group_dependent_fields = True. Fields are grouped together under a key that shouldn't overlap
         # with actual table keys "fidesops_grouped_inputs"
         assert task.pre_process_input_data(
-            project_output, group_dependent_fields=True
+            identity_output, project_output, group_dependent_fields=True
         ) == {
             "fidesops_grouped_inputs": [
-                {"organization": ["12345"], "project": ["abcde"]},
-                {"organization": ["54321"], "project": ["fghij"]},
-                {"organization": ["54321"], "project": ["klmno"]},
+                {
+                    "project": ["abcde"],
+                    "organization": ["12345"],
+                    "email": ["email@gmail.com"],
+                },
+                {
+                    "project": ["fghij"],
+                    "organization": ["54321"],
+                    "email": ["email@gmail.com"],
+                },
+                {
+                    "project": ["klmno"],
+                    "organization": ["54321"],
+                    "email": ["email@gmail.com"],
+                },
             ]
         }
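The tests above check the shape of the grouped data; the scheduling effect of `after` is a separate concern. Here is a hypothetical sketch of the ordering guarantee using the standard library's topological sorter rather than fidesops' actual traversal code:

```python
# Hypothetical ordering sketch (Python 3.9+): a collection listed in
# another collection's `after` set is always visited first, so "users"
# runs only after "projects" has produced its rows.
from graphlib import TopologicalSorter

predecessors = {
    "saas_connector_example.users": {"saas_connector_example.projects"},
    "saas_connector_example.projects": set(),
}
print(list(TopologicalSorter(predecessors).static_order()))
# ['saas_connector_example.projects', 'saas_connector_example.users']
```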
diff --git a/tests/task/traversal_data.py b/tests/task/traversal_data.py
index 46ff59d80d..e82b8df04e 100644
--- a/tests/task/traversal_data.py
+++ b/tests/task/traversal_data.py
@@ -610,6 +610,9 @@ def traversal_paired_dependency() -> Traversal:
     )
     users = Collection(
         name="User",
+        after={
+            CollectionAddress("mysql", "Project"),
+        },
         fields=[
             ScalarField(
                 name="project",
@@ -620,11 +623,11 @@ def traversal_paired_dependency() -> Traversal:
                 references=[(FieldAddress("mysql", "Project", "organization_id"), "from")],
             ),
             ScalarField(name="username"),
-            ScalarField(name="email"),
+            ScalarField(name="email", identity="email"),
             ScalarField(name="position"),
         ],
-        grouped_inputs= {"project", "organization"}
+        grouped_inputs={"project", "organization", "email"}
     )
 
     mysql = Dataset(
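Finally, a rough sketch of what the updated `User` fixture encodes, using simplified stand-ins for `Collection` and `ScalarField`; the closing assert mirrors the `dependent_identity_fields` check added in `graph_task.py`:

```python
# Simplified stand-ins: the User collection runs after mysql.Project, and
# its grouped inputs now include the "email" identity alongside the two
# reference fields.
from dataclasses import dataclass, field
from typing import Optional, Set, Tuple


@dataclass(frozen=True)
class CollectionAddress:
    dataset: str
    collection: str


@dataclass
class ScalarField:
    name: str
    identity: Optional[str] = None


@dataclass
class Collection:
    name: str
    fields: Tuple[ScalarField, ...]
    after: Set[CollectionAddress] = field(default_factory=set)
    grouped_inputs: Set[str] = field(default_factory=set)


users = Collection(
    name="User",
    fields=(
        ScalarField(name="project"),
        ScalarField(name="organization"),
        ScalarField(name="email", identity="email"),
    ),
    after={CollectionAddress("mysql", "Project")},
    grouped_inputs={"project", "organization", "email"},
)

# Mirrors dependent_identity_fields: at least one grouped input is an identity.
assert any(f.identity for f in users.fields if f.name in users.grouped_inputs)
```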