Skip to content

Commit

Permalink
feat(ingest): allow custom SF API version (#11145)
Browse files Browse the repository at this point in the history
  • Loading branch information
skrydal authored Aug 16, 2024
1 parent 12b3da3 commit 11890e5
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 23 deletions.
41 changes: 26 additions & 15 deletions metadata-ingestion/src/datahub/ingestion/source/salesforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
from datetime import datetime
from enum import Enum
from typing import Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional

import requests
from pydantic import Field, validator
Expand Down Expand Up @@ -124,6 +124,9 @@ class SalesforceConfig(DatasetSourceConfigMixin):
default=dict(),
description='Regex patterns for tables/schemas to describe domain_key domain key (domain_key can be any string like "sales".) There can be multiple domain keys specified.',
)
api_version: Optional[str] = Field(
description="If specified, overrides default version used by the Salesforce package. Example value: '59.0'"
)

profiling: SalesforceProfilingConfig = SalesforceProfilingConfig()

Expand Down Expand Up @@ -222,6 +225,12 @@ def __init__(self, config: SalesforceConfig, ctx: PipelineContext) -> None:
self.session = requests.Session()
self.platform: str = "salesforce"
self.fieldCounts = {}
common_args: Dict[str, Any] = {
"domain": "test" if self.config.is_sandbox else None,
"session": self.session,
}
if self.config.api_version:
common_args["version"] = self.config.api_version

try:
if self.config.auth is SalesforceAuthType.DIRECT_ACCESS_TOKEN:
Expand All @@ -236,8 +245,7 @@ def __init__(self, config: SalesforceConfig, ctx: PipelineContext) -> None:
self.sf = Salesforce(
instance_url=self.config.instance_url,
session_id=self.config.access_token,
session=self.session,
domain="test" if self.config.is_sandbox else None,
**common_args,
)
elif self.config.auth is SalesforceAuthType.USERNAME_PASSWORD:
logger.debug("Username/Password Provided in Config")
Expand All @@ -255,8 +263,7 @@ def __init__(self, config: SalesforceConfig, ctx: PipelineContext) -> None:
username=self.config.username,
password=self.config.password,
security_token=self.config.security_token,
session=self.session,
domain="test" if self.config.is_sandbox else None,
**common_args,
)

elif self.config.auth is SalesforceAuthType.JSON_WEB_TOKEN:
Expand All @@ -275,32 +282,36 @@ def __init__(self, config: SalesforceConfig, ctx: PipelineContext) -> None:
username=self.config.username,
consumer_key=self.config.consumer_key,
privatekey=self.config.private_key,
session=self.session,
domain="test" if self.config.is_sandbox else None,
**common_args,
)

except Exception as e:
logger.error(e)
raise ConfigurationError("Salesforce login failed") from e
else:
if not self.config.api_version:
# List all REST API versions and use latest one
versions_url = "https://{instance}/services/data/".format(
instance=self.sf.sf_instance,
)
versions_response = self.sf._call_salesforce("GET", versions_url).json()
latest_version = versions_response[-1]
version = latest_version["version"]
# we could avoid setting the version like below (after the Salesforce object has been already initiated
# above), since, according to the docs:
# https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/dome_versions.htm
# we don't need to be authenticated to list the versions (so we could perform this call before even
# authenticating)
self.sf.sf_version = version

self.base_url = "https://{instance}/services/data/v{sf_version}/".format(
instance=self.sf.sf_instance, sf_version=version
)
self.base_url = "https://{instance}/services/data/v{sf_version}/".format(
instance=self.sf.sf_instance, sf_version=self.sf.sf_version
)

logger.debug(
"Using Salesforce REST API with {label} version: {version}".format(
label=latest_version["label"], version=latest_version["version"]
)
logger.debug(
"Using Salesforce REST API version: {version}".format(
version=self.sf.sf_version
)
)

def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
sObjects = self.get_salesforce_objects()
Expand Down
102 changes: 94 additions & 8 deletions metadata-ingestion/tests/integration/salesforce/test_salesforce.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import json
import pathlib
from unittest import mock
from unittest.mock import Mock

from freezegun import freeze_time

from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.salesforce import SalesforceConfig, SalesforceSource
from tests.test_helpers import mce_helpers

FROZEN_TIME = "2022-05-12 11:00:00"
Expand All @@ -19,15 +21,16 @@ def _read_response(file_name: str) -> dict:
return data


def side_effect_call_salesforce(type, url):
class MockResponse:
def __init__(self, json_data, status_code):
self.json_data = json_data
self.status_code = status_code
class MockResponse:
def __init__(self, json_data, status_code):
self.json_data = json_data
self.status_code = status_code

def json(self):
return self.json_data

def json(self):
return self.json_data

def side_effect_call_salesforce(type, url):
if url.endswith("/services/data/"):
return MockResponse(_read_response("versions_response.json"), 200)
if url.endswith("FROM EntityDefinition WHERE IsCustomizable = true"):
Expand Down Expand Up @@ -55,9 +58,92 @@ def json(self):
return MockResponse({}, 404)


@mock.patch("datahub.ingestion.source.salesforce.Salesforce")
def test_latest_version(mock_sdk):
mock_sf = mock.Mock()
mocked_call = mock.Mock()
mocked_call.side_effect = side_effect_call_salesforce
mock_sf._call_salesforce = mocked_call
mock_sdk.return_value = mock_sf

config = SalesforceConfig.parse_obj(
{
"auth": "DIRECT_ACCESS_TOKEN",
"instance_url": "https://mydomain.my.salesforce.com/",
"access_token": "access_token`",
"ingest_tags": True,
"object_pattern": {
"allow": [
"^Account$",
"^Property__c$",
],
},
"domain": {"sales": {"allow": {"^Property__c$"}}},
"profiling": {"enabled": True},
"profile_pattern": {
"allow": [
"^Property__c$",
]
},
}
)
SalesforceSource(config=config, ctx=Mock())
calls = mock_sf._call_salesforce.mock_calls
assert (
len(calls) == 1
), "We didn't specify version but source didn't call SF API to get the latest one"
assert calls[0].ends_with(
"/services/data"
), "Source didn't call proper SF API endpoint to get all versions"
assert (
mock_sf.sf_version == "54.0"
), "API version was not correctly set (see versions_responses.json)"


@mock.patch("datahub.ingestion.source.salesforce.Salesforce")
def test_custom_version(mock_sdk):
mock_sf = mock.Mock()
mocked_call = mock.Mock()
mocked_call.side_effect = side_effect_call_salesforce
mock_sf._call_salesforce = mocked_call
mock_sdk.return_value = mock_sf

config = SalesforceConfig.parse_obj(
{
"auth": "DIRECT_ACCESS_TOKEN",
"api_version": "46.0",
"instance_url": "https://mydomain.my.salesforce.com/",
"access_token": "access_token`",
"ingest_tags": True,
"object_pattern": {
"allow": [
"^Account$",
"^Property__c$",
],
},
"domain": {"sales": {"allow": {"^Property__c$"}}},
"profiling": {"enabled": True},
"profile_pattern": {
"allow": [
"^Property__c$",
]
},
}
)
SalesforceSource(config=config, ctx=Mock())

calls = mock_sf._call_salesforce.mock_calls
assert (
len(calls) == 0
), "Source called API to get all versions even though we specified proper version"
assert (
mock_sdk.call_args.kwargs["version"] == "46.0"
), "API client object was not correctly initialized with the custom version"


@freeze_time(FROZEN_TIME)
def test_salesforce_ingest(pytestconfig, tmp_path):
with mock.patch("simple_salesforce.Salesforce") as mock_sdk:
with mock.patch("datahub.ingestion.source.salesforce.Salesforce") as mock_sdk:
mock_sf = mock.Mock()
mocked_call = mock.Mock()
mocked_call.side_effect = side_effect_call_salesforce
Expand Down

0 comments on commit 11890e5

Please sign in to comment.