diff --git a/metadata-ingestion/src/datahub/ingestion/source/salesforce.py b/metadata-ingestion/src/datahub/ingestion/source/salesforce.py index 42128123c6144..7a7f1f30950eb 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/salesforce.py +++ b/metadata-ingestion/src/datahub/ingestion/source/salesforce.py @@ -3,7 +3,7 @@ import time from datetime import datetime from enum import Enum -from typing import Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional import requests from pydantic import Field, validator @@ -124,6 +124,9 @@ class SalesforceConfig(DatasetSourceConfigMixin): default=dict(), description='Regex patterns for tables/schemas to describe domain_key domain key (domain_key can be any string like "sales".) There can be multiple domain keys specified.', ) + api_version: Optional[str] = Field( + description="If specified, overrides default version used by the Salesforce package. Example value: '59.0'" + ) profiling: SalesforceProfilingConfig = SalesforceProfilingConfig() @@ -222,6 +225,12 @@ def __init__(self, config: SalesforceConfig, ctx: PipelineContext) -> None: self.session = requests.Session() self.platform: str = "salesforce" self.fieldCounts = {} + common_args: Dict[str, Any] = { + "domain": "test" if self.config.is_sandbox else None, + "session": self.session, + } + if self.config.api_version: + common_args["version"] = self.config.api_version try: if self.config.auth is SalesforceAuthType.DIRECT_ACCESS_TOKEN: @@ -236,8 +245,7 @@ def __init__(self, config: SalesforceConfig, ctx: PipelineContext) -> None: self.sf = Salesforce( instance_url=self.config.instance_url, session_id=self.config.access_token, - session=self.session, - domain="test" if self.config.is_sandbox else None, + **common_args, ) elif self.config.auth is SalesforceAuthType.USERNAME_PASSWORD: logger.debug("Username/Password Provided in Config") @@ -255,8 +263,7 @@ def __init__(self, config: SalesforceConfig, ctx: PipelineContext) -> None: username=self.config.username, password=self.config.password, security_token=self.config.security_token, - session=self.session, - domain="test" if self.config.is_sandbox else None, + **common_args, ) elif self.config.auth is SalesforceAuthType.JSON_WEB_TOKEN: @@ -275,14 +282,13 @@ def __init__(self, config: SalesforceConfig, ctx: PipelineContext) -> None: username=self.config.username, consumer_key=self.config.consumer_key, privatekey=self.config.private_key, - session=self.session, - domain="test" if self.config.is_sandbox else None, + **common_args, ) except Exception as e: logger.error(e) raise ConfigurationError("Salesforce login failed") from e - else: + if not self.config.api_version: # List all REST API versions and use latest one versions_url = "https://{instance}/services/data/".format( instance=self.sf.sf_instance, @@ -290,17 +296,22 @@ def __init__(self, config: SalesforceConfig, ctx: PipelineContext) -> None: versions_response = self.sf._call_salesforce("GET", versions_url).json() latest_version = versions_response[-1] version = latest_version["version"] + # we could avoid setting the version like below (after the Salesforce object has been already initiated + # above), since, according to the docs: + # https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/dome_versions.htm + # we don't need to be authenticated to list the versions (so we could perform this call before even + # authenticating) self.sf.sf_version = version - self.base_url = "https://{instance}/services/data/v{sf_version}/".format( - instance=self.sf.sf_instance, sf_version=version - ) + self.base_url = "https://{instance}/services/data/v{sf_version}/".format( + instance=self.sf.sf_instance, sf_version=self.sf.sf_version + ) - logger.debug( - "Using Salesforce REST API with {label} version: {version}".format( - label=latest_version["label"], version=latest_version["version"] - ) + logger.debug( + "Using Salesforce REST API version: {version}".format( + version=self.sf.sf_version ) + ) def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: sObjects = self.get_salesforce_objects() diff --git a/metadata-ingestion/tests/integration/salesforce/test_salesforce.py b/metadata-ingestion/tests/integration/salesforce/test_salesforce.py index 8b6b883b2148d..89a37a372df84 100644 --- a/metadata-ingestion/tests/integration/salesforce/test_salesforce.py +++ b/metadata-ingestion/tests/integration/salesforce/test_salesforce.py @@ -1,10 +1,12 @@ import json import pathlib from unittest import mock +from unittest.mock import Mock from freezegun import freeze_time from datahub.ingestion.run.pipeline import Pipeline +from datahub.ingestion.source.salesforce import SalesforceConfig, SalesforceSource from tests.test_helpers import mce_helpers FROZEN_TIME = "2022-05-12 11:00:00" @@ -19,15 +21,16 @@ def _read_response(file_name: str) -> dict: return data -def side_effect_call_salesforce(type, url): - class MockResponse: - def __init__(self, json_data, status_code): - self.json_data = json_data - self.status_code = status_code +class MockResponse: + def __init__(self, json_data, status_code): + self.json_data = json_data + self.status_code = status_code + + def json(self): + return self.json_data - def json(self): - return self.json_data +def side_effect_call_salesforce(type, url): if url.endswith("/services/data/"): return MockResponse(_read_response("versions_response.json"), 200) if url.endswith("FROM EntityDefinition WHERE IsCustomizable = true"): @@ -55,9 +58,92 @@ def json(self): return MockResponse({}, 404) +@mock.patch("datahub.ingestion.source.salesforce.Salesforce") +def test_latest_version(mock_sdk): + mock_sf = mock.Mock() + mocked_call = mock.Mock() + mocked_call.side_effect = side_effect_call_salesforce + mock_sf._call_salesforce = mocked_call + mock_sdk.return_value = mock_sf + + config = SalesforceConfig.parse_obj( + { + "auth": "DIRECT_ACCESS_TOKEN", + "instance_url": "https://mydomain.my.salesforce.com/", + "access_token": "access_token`", + "ingest_tags": True, + "object_pattern": { + "allow": [ + "^Account$", + "^Property__c$", + ], + }, + "domain": {"sales": {"allow": {"^Property__c$"}}}, + "profiling": {"enabled": True}, + "profile_pattern": { + "allow": [ + "^Property__c$", + ] + }, + } + ) + SalesforceSource(config=config, ctx=Mock()) + calls = mock_sf._call_salesforce.mock_calls + assert ( + len(calls) == 1 + ), "We didn't specify version but source didn't call SF API to get the latest one" + assert calls[0].ends_with( + "/services/data" + ), "Source didn't call proper SF API endpoint to get all versions" + assert ( + mock_sf.sf_version == "54.0" + ), "API version was not correctly set (see versions_responses.json)" + + +@mock.patch("datahub.ingestion.source.salesforce.Salesforce") +def test_custom_version(mock_sdk): + mock_sf = mock.Mock() + mocked_call = mock.Mock() + mocked_call.side_effect = side_effect_call_salesforce + mock_sf._call_salesforce = mocked_call + mock_sdk.return_value = mock_sf + + config = SalesforceConfig.parse_obj( + { + "auth": "DIRECT_ACCESS_TOKEN", + "api_version": "46.0", + "instance_url": "https://mydomain.my.salesforce.com/", + "access_token": "access_token`", + "ingest_tags": True, + "object_pattern": { + "allow": [ + "^Account$", + "^Property__c$", + ], + }, + "domain": {"sales": {"allow": {"^Property__c$"}}}, + "profiling": {"enabled": True}, + "profile_pattern": { + "allow": [ + "^Property__c$", + ] + }, + } + ) + SalesforceSource(config=config, ctx=Mock()) + + calls = mock_sf._call_salesforce.mock_calls + assert ( + len(calls) == 0 + ), "Source called API to get all versions even though we specified proper version" + assert ( + mock_sdk.call_args.kwargs["version"] == "46.0" + ), "API client object was not correctly initialized with the custom version" + + @freeze_time(FROZEN_TIME) def test_salesforce_ingest(pytestconfig, tmp_path): - with mock.patch("simple_salesforce.Salesforce") as mock_sdk: + with mock.patch("datahub.ingestion.source.salesforce.Salesforce") as mock_sdk: mock_sf = mock.Mock() mocked_call = mock.Mock() mocked_call.side_effect = side_effect_call_salesforce