Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DLS in Salesforce #2022

Merged
merged 3 commits into from
Jan 10, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Salesforce DLS
moxarth-rathod committed Dec 30, 2023
commit e820f183c141ea0b005808b2a65a6fc0dfd47f72
277 changes: 243 additions & 34 deletions connectors/sources/salesforce.py
Original file line number Diff line number Diff line change
@@ -14,6 +14,11 @@
import fastjsonschema
from aiohttp.client_exceptions import ClientResponseError

from connectors.access_control import (
ACCESS_CONTROL,
es_access_control_query,
prefix_identity,
)
from connectors.filtering.validation import (
AdvancedRulesValidator,
SyncRuleValidationResult,
@@ -41,13 +46,16 @@
TOKEN_ENDPOINT = "/services/oauth2/token" # noqa S105
QUERY_ENDPOINT = f"/services/data/{API_VERSION}/query"
SOSL_SEARCH_ENDPOINT = f"/services/data/{API_VERSION}/search"
CUSTOM_OBJECT_ENDPOINT = f"/services/data/{API_VERSION}/tooling/query"
DESCRIBE_ENDPOINT = f"/services/data/{API_VERSION}/sobjects"
DESCRIBE_SOBJECT_ENDPOINT = f"/services/data/{API_VERSION}/sobjects/<sobject>/describe"
# https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/resources_sobject_blob_retrieve.htm
CONTENT_VERSION_DOWNLOAD_ENDPOINT = f"/services/data/{API_VERSION}/sobjects/ContentVersion/<content_version_id>/VersionData"
OFFSET = 200

# SOQL template: AssigneeId of every PermissionSetAssignment whose permission set
# grants read permission on the object type substituted for {sobject}.
OBJECT_READ_PERMISSION_USERS = "SELECT AssigneeId FROM PermissionSetAssignment WHERE PermissionSetId IN (SELECT ParentId FROM ObjectPermissions WHERE PermissionsRead = true AND SObjectType = '{sobject}')"
# SOQL template: Name and Email of the users whose Id is in {user_list}. The
# placeholder is not wrapped in parentheses here, so the substituted value has
# to be the complete parenthesised list of ids.
USERNAME_FROM_IDS = "SELECT Name, Email FROM User WHERE Id IN {user_list}"
# SOQL template: ContentDocumentLink rows (linked entity id and name) for the
# content document whose id is substituted for {document_id}.
FILE_ACCESS = "SELECT ContentDocumentId, LinkedEntityId, LinkedEntity.Name FROM ContentDocumentLink WHERE ContentDocumentId = '{document_id}'"

RELEVANT_SOBJECTS = [
"Account",
"Campaign",
@@ -115,9 +123,23 @@
"VersionDataUrl",
"VersionNumber",
"Website",
"UserType",
]


def _prefix_user(user):
    """Return ``user`` as a ``user``-prefixed identity.

    Falsy values (``None`` or an empty string) are not forwarded to
    ``prefix_identity``; ``None`` is returned for them instead.
    """
    if not user:
        return None
    return prefix_identity("user", user)


def _prefix_user_id(user_id):
    """Return ``user_id`` tagged with the ``user_id`` identity prefix."""
    prefix = "user_id"
    return prefix_identity(prefix, user_id)


def _prefix_email(email):
    """Return ``email`` tagged with the ``email`` identity prefix."""
    prefix = "email"
    return prefix_identity(prefix, email)


class RateLimitedException(Exception):
"""Notifies that Salesforce has begun rate limiting the current account"""

@@ -249,36 +271,52 @@ async def get_sync_rules_results(self, rule):
for record in records:
yield record

async def get_custom_objects(self):
custom_object_query = (
SalesforceSoqlBuilder("CustomObject")
.with_id()
.with_default_metafields()
.with_fields(["DeveloperName"])
.build()
)

async def _custom_objects(self):
response = await self._get_json(f"{self.base_url}{DESCRIBE_ENDPOINT}")
custom_objects = []
async for records in self._yield_non_bulk_query_pages(
soql_query=custom_object_query, endpoint=CUSTOM_OBJECT_ENDPOINT
):
custom_objects.extend(records)

for _object in custom_objects:
if (
_object["DeveloperName"] == "Knowledge_kav"
): # kav stands for Knowledge Article Version
object_name = "Knowledge__kav"
elif _object["DeveloperName"] == "Knowledge":
object_name = f"{_object['DeveloperName']}__kav"
else:
object_name = f"{_object['DeveloperName']}__c"

query = await self._custom_object_query(custom_object=object_name)
for sobject in response.get("sobjects", []):
if sobject.get("custom"):
custom_objects.append(sobject.get("name"))
return custom_objects

async def get_custom_objects(self):
for custom_object in await self._custom_objects():
query = await self._custom_object_query(custom_object=custom_object)
async for records in self._yield_non_bulk_query_pages(query):
for record in records:
yield record

async def get_salesforce_users(self):
    """Yield Salesforce ``User`` records one at a time.

    When the ``User`` object is not queryable a warning is logged and
    nothing is yielded.
    """
    user_is_queryable = await self._is_queryable("User")
    if not user_is_queryable:
        self._logger.warning(
            "Object User is not queryable, so they won't be ingested."
        )
        return

    soql = await self._user_query()
    async for page in self._yield_non_bulk_query_pages(soql):
        for user in page:
            yield user

async def get_users_with_read_access(self, sobject):
    """Yield the permission set assignment records that grant read access to ``sobject``.

    The ``OBJECT_READ_PERMISSION_USERS`` template is filled with ``sobject``
    and every record of every result page is yielded individually.
    """
    soql = OBJECT_READ_PERMISSION_USERS.format(sobject=sobject)
    async for page in self._yield_non_bulk_query_pages(soql):
        for assignment in page:
            yield assignment

async def get_username_by_id(self, user_list):
    """Yield the ``User`` records (``Name``, ``Email``) for the given user ids.

    Args:
        user_list: Iterable of Salesforce user ids. A string is treated as an
            already rendered ``IN`` list and substituted unchanged.

    Yields:
        dict: One record per user returned by the query.
    """
    if isinstance(user_list, str):
        id_clause = user_list
    else:
        # De-duplicate while keeping order and drop missing ids, so the
        # query carries each id exactly once.
        unique_ids = list(dict.fromkeys(uid for uid in user_list if uid))
        if not unique_ids:
            # "IN ()" is not valid SOQL, so there is nothing to ask for.
            return
        # Render the list by hand: formatting a one-element tuple directly
        # produces "('id',)", and the trailing comma is malformed SOQL.
        id_clause = "(" + ", ".join(f"'{uid}'" for uid in unique_ids) + ")"

    query = USERNAME_FROM_IDS.format(user_list=id_clause)
    async for records in self._yield_non_bulk_query_pages(query):
        for record in records:
            yield record

async def get_file_access(self, document_id):
    """Yield the ``ContentDocumentLink`` records of the content document ``document_id``.

    The ``FILE_ACCESS`` template is filled with ``document_id`` and every
    record of every result page is yielded individually.
    """
    soql = FILE_ACCESS.format(document_id=document_id)
    async for page in self._yield_non_bulk_query_pages(soql):
        for link in page:
            yield link

async def get_accounts(self):
if not await self._is_queryable("Account"):
self._logger.warning(
@@ -664,6 +702,20 @@ async def _custom_object_query(self, custom_object):
.build()
)

async def _user_query(self):
    """Build the SOQL query used to fetch ``User`` records.

    Only the fields among ``Name``, ``Email`` and ``UserType`` that
    ``_select_queryable_fields`` returns are requested, in addition to the
    id and the default metafields.
    """
    wanted_fields = ["Name", "Email", "UserType"]
    queryable_fields = await self._select_queryable_fields("User", wanted_fields)

    builder = SalesforceSoqlBuilder("User")
    builder = builder.with_id()
    builder = builder.with_default_metafields()
    builder = builder.with_fields(queryable_fields)
    return builder.build()

async def _accounts_query(self):
queryable_fields = await self._select_queryable_fields(
"Account",
@@ -1233,6 +1285,7 @@ class SalesforceDataSource(BaseDataSource):
name = "Salesforce"
service_type = "salesforce"
advanced_rules_enabled = True
dls_enabled = True

def __init__(self, configuration):
super().__init__(configuration=configuration)
@@ -1247,6 +1300,7 @@ def __init__(self, configuration):
configuration=configuration, base_url=base_url
)
self.doc_mapper = SalesforceDocMapper(base_url)
self.permissions = {}

def _set_internal_logger(self):
self.salesforce_client.set_logger(self._logger)
@@ -1283,8 +1337,80 @@ def get_default_configuration(cls):
"ui_restrictions": ["advanced"],
"value": False,
},
"use_document_level_security": {
"display": "toggle",
"label": "Enable document level security",
"order": 5,
"tooltip": "Document level security ensures identities and permissions set in Salesforce are maintained in Elasticsearch. This enables you to restrict and personalize read-access users and groups have to documents in this index. Access control syncs ensure this metadata is kept up to date in your Elasticsearch documents.",
"type": "bool",
"value": False,
},
}

def _dls_enabled(self):
    """Tell whether document level security (DLS) applies to this source.

    DLS is only considered when the connector features are known and
    report DLS as enabled; in that case the result is the value of the
    ``use_document_level_security`` configuration field.

    Returns:
        bool: True if document level security is enabled, False otherwise.
    """
    features = self._features
    if features is None or not features.document_level_security_enabled():
        return False
    return self.configuration["use_document_level_security"]

def _decorate_with_access_control(self, document, access_control):
    """Merge ``access_control`` into the document's access control field.

    The document is modified in place and returned. Entries already stored
    under ``ACCESS_CONTROL`` are kept, duplicates are removed and the order
    of the resulting list is unspecified. When DLS is disabled the document
    is returned untouched.
    """
    if not self._dls_enabled():
        return document

    existing = document.get(ACCESS_CONTROL, [])
    document[ACCESS_CONTROL] = list(set(existing + access_control))
    return document

async def _user_access_control_doc(self, user):
    """Build the access control document for one Salesforce user record.

    The document is keyed by the user's ``Id``, carries the prefixed e-mail,
    user name and user id as ``identity``, and is merged with the result of
    ``es_access_control_query`` for those same three identities.
    ``CreatedDate`` and ``LastModifiedDate`` fall back to the current
    ``iso_utc()`` value when the record lacks them.
    """
    identity = {
        "email": _prefix_email(user.get("Email")),
        "username": _prefix_user(user.get("Name")),
        "user_id": _prefix_user_id(user.get("Id")),
    }
    document = {
        "_id": user.get("Id"),
        "identity": identity,
        "created_at": user.get("CreatedDate", iso_utc()),
        "_timestamp": user.get("LastModifiedDate", iso_utc()),
    }
    access_control = [identity["email"], identity["username"], identity["user_id"]]
    return document | es_access_control_query(access_control)

async def get_access_control(self):
"""Get access control documents for Salesforce users.
This method yields access control documents for Salesforce users when document level security (DLS)
is enabled. It starts by checking if DLS is enabled, and if not, it logs a warning message and skips further processing.
If DLS is enabled, the method fetches all users from the Salesforce API, skips users whose UserType is
"CloudIntegrationUser" or "AutomatedProcess", generates an access control document for each remaining user
using the _user_access_control_doc method and yields the results.
Yields:
dict: An access control document for each Salesforce user that is not skipped.
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do Atlassian users have to do with Salesforce DLS?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, my bad! Updated now ✅

if not self._dls_enabled():
self._logger.warning("DLS is not enabled. Skipping")
return

self._logger.info("Fetching Salesforce users")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self._logger.info("Fetching Salesforce users")
self._logger.debug("Fetching Salesforce users")

async for user in self.salesforce_client.get_salesforce_users():
if user.get("UserType") in ["CloudIntegrationUser", "AutomatedProcess"]:
continue
user_doc = await self._user_access_control_doc(user=user)
yield user_doc

async def validate_config(self):
await super().validate_config()

@@ -1304,8 +1430,41 @@ def advanced_rules_validators(self):

async def _get_advanced_sync_rules_result(self, rule):
    """Yield every document matched by the advanced sync ``rule``.

    Before a document is yielded, the users with read access to its object
    type (``attributes.type``) are fetched through
    ``_fetch_users_with_read_access``. Documents without an object type are
    yielded without that lookup.
    """
    matches = self.salesforce_client.get_sync_rules_results(rule=rule)
    async for record in matches:
        object_type = record.get("attributes", {}).get("type")
        if object_type:
            await self._fetch_users_with_read_access(sobject=object_type)
        yield record

async def _fetch_users_with_read_access(self, sobject):
if not self._dls_enabled():
self._logger.warning("DLS is not enabled. Skipping")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a warning? I feel that it could get quite noisy. This isn't called during ACL syncs as far as I can see, so debug might be more appropriate.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problem with that; I've changed it to a debug log.

return

self._logger.debug(
f"Fetching users who has Read access for Salesforce object: {sobject}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f"Fetching users who has Read access for Salesforce object: {sobject}"
f"Fetching users who have read access for Salesforce object: {sobject}"

)

if sobject in self.permissions:
return

user_list = []
access_control = []
async for assignee in self.salesforce_client.get_users_with_read_access(
sobject=sobject
):
user_list.append(assignee.get("AssigneeId"))
access_control.append(_prefix_user_id(assignee.get("AssigneeId")))

if user_list == []:
return

async for user in self.salesforce_client.get_username_by_id(
user_list=tuple(user_list)
):
access_control.append(_prefix_user(user.get("Name")))
access_control.append(_prefix_email(user.get("Email")))

self.permissions[sobject] = access_control

async def get_docs(self, filtering=None):
# We collect all content documents and de-duplicate them before downloading and yielding
content_docs = []
@@ -1316,40 +1475,90 @@ async def get_docs(self, filtering=None):
for rule in advanced_rules:
async for doc in self._get_advanced_sync_rules_result(rule=rule):
content_docs.extend(self._parse_content_documents(doc))
yield self.doc_mapper.map_salesforce_objects(doc), None
access_control = self.permissions.get(
doc.get("attributes", {}).get("type"), []
)
yield self.doc_mapper.map_salesforce_objects(
self._decorate_with_access_control(doc, access_control)
), None

else:
for sobject in [
"Account",
"Opportunity",
"Contact",
"Lead",
"Campaign",
"Case",
]:
await self._fetch_users_with_read_access(sobject=sobject)

for custom_object in await self.salesforce_client._custom_objects():
await self._fetch_users_with_read_access(sobject=custom_object)

async for account in self.salesforce_client.get_accounts():
content_docs.extend(self._parse_content_documents(account))
yield self.doc_mapper.map_salesforce_objects(account), None
access_control = self.permissions.get("Account", [])
yield self.doc_mapper.map_salesforce_objects(
self._decorate_with_access_control(account, access_control)
), None

async for opportunity in self.salesforce_client.get_opportunities():
content_docs.extend(self._parse_content_documents(opportunity))
yield self.doc_mapper.map_salesforce_objects(opportunity), None
access_control = self.permissions.get("Opportunity", [])
yield self.doc_mapper.map_salesforce_objects(
self._decorate_with_access_control(opportunity, access_control)
), None

async for contact in self.salesforce_client.get_contacts():
content_docs.extend(self._parse_content_documents(contact))
yield self.doc_mapper.map_salesforce_objects(contact), None
access_control = self.permissions.get("Contact", [])
yield self.doc_mapper.map_salesforce_objects(
self._decorate_with_access_control(contact, access_control)
), None

async for lead in self.salesforce_client.get_leads():
content_docs.extend(self._parse_content_documents(lead))
yield self.doc_mapper.map_salesforce_objects(lead), None
access_control = self.permissions.get("Lead", [])
yield self.doc_mapper.map_salesforce_objects(
self._decorate_with_access_control(lead, access_control)
), None

async for campaign in self.salesforce_client.get_campaigns():
content_docs.extend(self._parse_content_documents(campaign))
yield self.doc_mapper.map_salesforce_objects(campaign), None
access_control = self.permissions.get("Campaign", [])
yield self.doc_mapper.map_salesforce_objects(
self._decorate_with_access_control(campaign, access_control)
), None

async for case in self.salesforce_client.get_cases():
content_docs.extend(self._parse_content_documents(case))
yield self.doc_mapper.map_salesforce_objects(case), None
access_control = self.permissions.get("Case", [])
yield self.doc_mapper.map_salesforce_objects(
self._decorate_with_access_control(case, access_control)
), None

async for custom_object in self.salesforce_client.get_custom_objects():
content_docs.extend(self._parse_content_documents(custom_object))
yield self.doc_mapper.map_salesforce_objects(custom_object), None
access_control = self.permissions.get(
custom_object.get("attributes", {}).get("type"), []
)
yield self.doc_mapper.map_salesforce_objects(
self._decorate_with_access_control(custom_object, access_control)
), None

# Note: this could possibly be done on the fly if memory becomes an issue
content_docs = self._combine_duplicate_content_docs(content_docs)
for content_doc in content_docs:
access_control = []
async for permission in self.salesforce_client.get_file_access(
document_id=content_doc["Id"]
):
access_control.append(_prefix_user_id(permission.get("LinkedEntityId")))
access_control.append(
_prefix_user(permission.get("LinkedEntity", {}).get("Name"))
)

content_version_id = (
content_doc.get("LatestPublishedVersion", {}) or {}
).get("Id")
@@ -1362,7 +1571,7 @@ async def get_docs(self, filtering=None):
doc = self.doc_mapper.map_content_document(content_doc)
doc = await self.get_content(doc, content_version_id)

yield doc, None
yield self._decorate_with_access_control(doc, access_control), None

async def get_content(self, doc, content_version_id):
file_size = doc["content_size"]
152 changes: 151 additions & 1 deletion tests/sources/test_salesforce.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@
from contextlib import asynccontextmanager
from copy import deepcopy
from unittest import TestCase, mock
from unittest.mock import patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from aioresponses import CallbackResult
@@ -27,6 +27,9 @@
SalesforceServerError,
SalesforceSoqlBuilder,
TokenFetchException,
_prefix_email,
_prefix_user,
_prefix_user_id,
)
from tests.sources.support import create_source

@@ -39,6 +42,7 @@
)
TEST_CLIENT_ID = "1234"
TEST_CLIENT_SECRET = "9876"
TEST_DESCRIBE_ENDPOINT = f"/services/data/{API_VERSION}/sobjects"

ADVANCED_SNIPPET = "advanced_snippet"

@@ -611,6 +615,35 @@
"_attachment": "Y2h1bmsx",
}

# Query response holding a single user record (it has no "Email" key); served
# to the access control test as the result of the user query.
# NOTE(review): "attributes" describes a ContentDocumentLink although the record
# carries user fields — looks copy-pasted from another fixture; confirm.
USER_RESPONSE_PAYLOAD = {
"totalSize": 1,
"done": True,
"records": [
{
"attributes": {
"type": "ContentDocumentLink",
"url": f"/services/data/{API_VERSION}/sobjects/ContentDocumentLink/content_document_link_id",
},
"Id": "user_id",
"Name": "Dummy User",
"CreatedDate": "2023-12-25T01:01:01Z",
"LastModifiedDate": "2023-12-25T01:01:01Z",
}
],
}

# Query response with one permission set assignee; returned by
# salesforce_query_callback for the "ObjectPermissions" case.
PERMISSION_SET_RESPONSE_PAYLOAD = {
"totalSize": 1,
"done": True,
"records": [{"AssigneeId": "user_id"}],
}

# Query response with one linked entity of a content document; returned by
# salesforce_query_callback for the "ContentDocumentLink" case.
CONTENT_DOCUMENT_RESPONSE_PAYLOAD = {
"totalSize": 1,
"done": True,
"records": [{"LinkedEntityId": "user_id", "LinkedEntity": {"Name": "Dummy User"}}],
}


@asynccontextmanager
async def create_salesforce_source(
@@ -669,6 +702,10 @@ def salesforce_query_callback(url, **kwargs):
payload = deepcopy(CUSTOM_OBJECT_RESPONSE_PAYLOAD)
case "Connector__c":
payload = deepcopy(CONNECTOR_RESPONSE_PAYLOAD)
case "ObjectPermissions":
payload = deepcopy(PERMISSION_SET_RESPONSE_PAYLOAD)
case "ContentDocumentLink":
payload = deepcopy(CONTENT_DOCUMENT_RESPONSE_PAYLOAD)
case _:
payload = {"records": []}

@@ -1483,6 +1520,9 @@ async def test_get_all_with_content_docs_when_success(
mock_responses.get(
TEST_QUERY_MATCH_URL, repeat=True, callback=salesforce_query_callback
)
source.salesforce_client._custom_objects = AsyncMock(
return_value=["CustomObject"]
)

content_document_records = []
async for record, _ in source.get_docs():
@@ -1537,6 +1577,9 @@ async def test_get_all_with_content_docs_and_extraction_service(mock_responses):
mock_responses.get(
TEST_QUERY_MATCH_URL, repeat=True, callback=salesforce_query_callback
)
source.salesforce_client._custom_objects = AsyncMock(
return_value=["CustomObject"]
)

content_document_records = []
async for record, _ in source.get_docs():
@@ -1690,11 +1733,16 @@ async def test_get_docs_for_sosl_query(mock_responses, filtering):
mock_responses.get(
TEST_QUERY_MATCH_URL, status=200, payload=SOSL_RESPONSE_PAYLOAD
)
mock_responses.get(
TEST_QUERY_MATCH_URL, status=200, payload=CONTENT_DOCUMENT_RESPONSE_PAYLOAD
)

resultant_docs = []
async for record, _ in source.get_docs(filtering):
resultant_docs.append(record)

assert len(resultant_docs) == 2


@pytest.mark.asyncio
async def test_remote_validation(mock_responses):
@@ -1936,3 +1984,105 @@ async def test_combine_duplicate_content_docs_with_duplicates():

combined_docs = source._combine_duplicate_content_docs(content_docs)
TestCase().assertCountEqual(combined_docs, expected_docs)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "user, result",
    [("Alex Wilber", "user:Alex Wilber"), ("", None)],
)
async def test_prefix_user(user, result):
    """A non-empty name gets the ``user:`` prefix; an empty one maps to ``None``."""
    assert _prefix_user(user=user) == result


@pytest.mark.asyncio
async def test_prefix_user_id():
    """A user id is returned with the ``user_id:`` prefix."""
    assert _prefix_user_id(user_id="ae34fad12") == "user_id:ae34fad12"


@pytest.mark.asyncio
async def test_prefix_email():
    """An e-mail address is returned with the ``email:`` prefix."""
    assert _prefix_email(email="alex.wilber@gmail.com") == "email:alex.wilber@gmail.com"


@pytest.mark.asyncio
async def test_get_access_control_dls_disabled():
    """With DLS disabled, ``get_access_control`` yields no documents."""
    async with create_salesforce_source() as source:
        source._dls_enabled = MagicMock(return_value=False)

        access_control_list = [
            access_control async for access_control in source.get_access_control()
        ]

        assert len(access_control_list) == 0


@pytest.mark.asyncio
async def test_get_access_control_dls_enabled(mock_responses):
    """With DLS enabled, the mocked user record becomes one access control document."""
    # The mocked user has no "Email", so the e-mail identity is None and does
    # not appear among the access control entries of the query.
    expected_user_doc = {
        "_id": "user_id",
        "identity": {
            "email": None,
            "username": "user:Dummy User",
            "user_id": "user_id:user_id",
        },
        "created_at": "2023-12-25T01:01:01Z",
        "_timestamp": "2023-12-25T01:01:01Z",
        "query": {
            "template": {
                "params": {"access_control": ["user:Dummy User", "user_id:user_id"]}
            },
            "source": {
                "bool": {
                    "filter": {
                        "bool": {
                            "should": [
                                {
                                    "terms": {
                                        "_allow_access_control.enum": [
                                            "user:Dummy User",
                                            "user_id:user_id",
                                        ]
                                    }
                                }
                            ]
                        }
                    }
                }
            },
        },
    }

    async with create_salesforce_source() as source:
        source._dls_enabled = MagicMock(return_value=True)
        mock_responses.get(
            TEST_QUERY_MATCH_URL,
            status=200,
            payload=USER_RESPONSE_PAYLOAD,
        )

        # Collect first and compare the whole list: asserting inside the loop
        # would pass silently if no document were yielded at all.
        user_docs = [user_doc async for user_doc in source.get_access_control()]

        assert user_docs == [expected_user_doc]


@pytest.mark.asyncio
async def test_get_docs_with_dls_enabled(mock_responses):
    """With DLS enabled, every document from ``get_docs`` carries access control entries."""
    async with create_salesforce_source() as source:
        source._dls_enabled = MagicMock(return_value=True)
        source.salesforce_client._custom_objects = AsyncMock(
            return_value=["CustomObject"]
        )
        mock_responses.get(
            TEST_FILE_DOWNLOAD_URL,
            status=200,
            body=b"chunk1",
        )
        mock_responses.get(
            TEST_QUERY_MATCH_URL, repeat=True, callback=salesforce_query_callback
        )

        records = [record async for record, _ in source.get_docs()]

        # Guard against a vacuous pass: the per-record check below proves
        # nothing when get_docs() yields no documents.
        assert records, "get_docs() yielded no documents"
        for record in records:
            assert len(record["_allow_access_control"]) > 0