Skip to content

Commit

Permalink
feat(policy): enable support for 10k+ policies (#9177)
Browse files Browse the repository at this point in the history
Co-authored-by: Pedro Silva <pedro@acryl.io>
  • Loading branch information
david-leifker and pedro93 authored Nov 7, 2023
1 parent 88cde08 commit 23c98ec
Show file tree
Hide file tree
Showing 11 changed files with 153 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,15 @@ public CompletableFuture<ListPoliciesResult> get(final DataFetchingEnvironment e
final Integer count = input.getCount() == null ? DEFAULT_COUNT : input.getCount();
final String query = input.getQuery() == null ? DEFAULT_QUERY : input.getQuery();

return CompletableFuture.supplyAsync(() -> {
try {
// First, get all policy Urns.
final PolicyFetcher.PolicyFetchResult policyFetchResult =
_policyFetcher.fetchPolicies(start, count, query, context.getAuthentication());

// Now that we have entities we can bind this to a result.
final ListPoliciesResult result = new ListPoliciesResult();
result.setStart(start);
result.setCount(count);
result.setTotal(policyFetchResult.getTotal());
result.setPolicies(mapEntities(policyFetchResult.getPolicies()));
return result;
} catch (Exception e) {
throw new RuntimeException("Failed to list policies", e);
}
});
return _policyFetcher.fetchPolicies(start, query, count, context.getAuthentication())
.thenApply(policyFetchResult -> {
final ListPoliciesResult result = new ListPoliciesResult();
result.setStart(start);
result.setCount(count);
result.setTotal(policyFetchResult.getTotal());
result.setPolicies(mapEntities(policyFetchResult.getPolicies()));
return result;
});
}
throw new AuthorizationException("Unauthorized to perform this action. Please contact your DataHub administrator.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ public SearchResult searchAcrossEntities(
@Nonnull
@Override
public ScrollResult scrollAcrossEntities(@Nonnull List<String> entities, @Nonnull String input,
@Nullable Filter filter, @Nullable String scrollId, @Nonnull String keepAlive, int count,
@Nullable Filter filter, @Nullable String scrollId, @Nullable String keepAlive, int count,
@Nullable SearchFlags searchFlags, @Nonnull Authentication authentication)
throws RemoteInvocationException {
final SearchFlags finalFlags = searchFlags != null ? searchFlags : new SearchFlags().setFulltext(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,23 @@ public SearchResult searchAcrossEntities(@Nonnull List<String> entities, @Nonnul
return result;
}

/**
* If no entities are provided, fallback to the list of non-empty entities
* @param inputEntities the requested entities
* @return some entities to search
*/
private List<String> getEntitiesToSearch(@Nonnull List<String> inputEntities) {
List<String> nonEmptyEntities;
List<String> lowercaseEntities = inputEntities.stream().map(String::toLowerCase).collect(Collectors.toList());
try (Timer.Context ignored = MetricUtils.timer(this.getClass(), "getNonEmptyEntities").time()) {
nonEmptyEntities = _entityDocCountCache.getNonEmptyEntities();
}
if (!inputEntities.isEmpty()) {
nonEmptyEntities = nonEmptyEntities.stream().filter(lowercaseEntities::contains).collect(Collectors.toList());

if (lowercaseEntities.isEmpty()) {
try (Timer.Context ignored = MetricUtils.timer(this.getClass(), "getNonEmptyEntities").time()) {
nonEmptyEntities = _entityDocCountCache.getNonEmptyEntities();
}
} else {
nonEmptyEntities = lowercaseEntities;
}

return nonEmptyEntities;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,13 @@ public DataHubAuthorizer(
final EntityClient entityClient,
final int delayIntervalSeconds,
final int refreshIntervalSeconds,
final AuthorizationMode mode) {
final AuthorizationMode mode,
final int policyFetchSize) {
_systemAuthentication = Objects.requireNonNull(systemAuthentication);
_mode = Objects.requireNonNull(mode);
_policyEngine = new PolicyEngine(systemAuthentication, Objects.requireNonNull(entityClient));
_policyRefreshRunnable = new PolicyRefreshRunnable(systemAuthentication, new PolicyFetcher(entityClient), _policyCache, readWriteLock.writeLock());
_policyRefreshRunnable = new PolicyRefreshRunnable(systemAuthentication, new PolicyFetcher(entityClient), _policyCache,
readWriteLock.writeLock(), policyFetchSize);
_refreshExecutorService.scheduleAtFixedRate(_policyRefreshRunnable, delayIntervalSeconds, refreshIntervalSeconds, TimeUnit.SECONDS);
}

Expand Down Expand Up @@ -244,29 +246,28 @@ static class PolicyRefreshRunnable implements Runnable {
private final PolicyFetcher _policyFetcher;
private final Map<String, List<DataHubPolicyInfo>> _policyCache;
private final Lock writeLock;
private final int count;

@Override
public void run() {
try {
// Populate new cache and swap.
Map<String, List<DataHubPolicyInfo>> newCache = new HashMap<>();
Integer total = null;
String scrollId = null;

int start = 0;
int count = 30;
int total = 30;

while (start < total) {
while (total == null || scrollId != null) {
try {
final PolicyFetcher.PolicyFetchResult
policyFetchResult = _policyFetcher.fetchPolicies(start, count, _systemAuthentication);
policyFetchResult = _policyFetcher.fetchPolicies(count, scrollId, _systemAuthentication);

addPoliciesToCache(newCache, policyFetchResult.getPolicies());

total = policyFetchResult.getTotal();
start = start + count;
scrollId = policyFetchResult.getScrollId();
} catch (Exception e) {
log.error(
"Failed to retrieve policy urns! Skipping updating policy cache until next refresh. start: {}, count: {}", start, count, e);
"Failed to retrieve policy urns! Skipping updating policy cache until next refresh. count: {}, scrollId: {}", count, scrollId, e);
return;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import com.linkedin.metadata.query.SearchFlags;
import com.linkedin.metadata.query.filter.SortCriterion;
import com.linkedin.metadata.query.filter.SortOrder;
import com.linkedin.metadata.search.ScrollResult;
import com.linkedin.metadata.search.SearchEntity;
import com.linkedin.metadata.search.SearchResult;
import com.linkedin.policy.DataHubPolicyInfo;
import com.linkedin.r2.RemoteInvocationException;
import java.net.URISyntaxException;
Expand All @@ -18,11 +18,14 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import lombok.RequiredArgsConstructor;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;

import javax.annotation.Nullable;

import static com.linkedin.metadata.Constants.DATAHUB_POLICY_INFO_ASPECT_NAME;
import static com.linkedin.metadata.Constants.POLICY_ENTITY_NAME;

Expand All @@ -38,22 +41,53 @@ public class PolicyFetcher {
private static final SortCriterion POLICY_SORT_CRITERION =
new SortCriterion().setField("lastUpdatedTimestamp").setOrder(SortOrder.DESCENDING);

public PolicyFetchResult fetchPolicies(int start, int count, Authentication authentication)
throws RemoteInvocationException, URISyntaxException {
return fetchPolicies(start, count, "", authentication);
/**
* This is to provide a scroll implementation using the start/count api. It is not efficient
* and the scroll native functions should be used instead. This does fix a failure to fetch
* policies when deep pagination happens where there are >10k policies.
* Exists primarily to prevent breaking change to the graphql api.
*/
@Deprecated
public CompletableFuture<PolicyFetchResult> fetchPolicies(int start, String query, int count, Authentication authentication) {
return CompletableFuture.supplyAsync(() -> {
try {
PolicyFetchResult result = PolicyFetchResult.EMPTY;
String scrollId = "";
int fetchedResults = 0;

while (PolicyFetchResult.EMPTY.equals(result) && scrollId != null) {
PolicyFetchResult tmpResult = fetchPolicies(query, count, scrollId.isEmpty() ? null : scrollId, authentication);
fetchedResults += tmpResult.getPolicies().size();
scrollId = tmpResult.getScrollId();
if (fetchedResults > start) {
result = tmpResult;
}
}

return result;
} catch (Exception e) {
throw new RuntimeException("Failed to list policies", e);
}
});
}

public PolicyFetchResult fetchPolicies(int start, int count, String query, Authentication authentication)
public PolicyFetchResult fetchPolicies(int count, @Nullable String scrollId, Authentication authentication)
throws RemoteInvocationException, URISyntaxException {
return fetchPolicies("", count, scrollId, authentication);
}

public PolicyFetchResult fetchPolicies(String query, int count, @Nullable String scrollId, Authentication authentication)
throws RemoteInvocationException, URISyntaxException {
log.debug(String.format("Batch fetching policies. start: %s, count: %s ", start, count));
// First fetch all policy urns from start - start + count
SearchResult result =
_entityClient.search(POLICY_ENTITY_NAME, query, null, POLICY_SORT_CRITERION, start, count, authentication,
new SearchFlags().setFulltext(true));
log.debug(String.format("Batch fetching policies. count: %s, scroll: %s", count, scrollId));

// First fetch all policy urns
ScrollResult result = _entityClient.scrollAcrossEntities(List.of(POLICY_ENTITY_NAME), query, null, scrollId,
null, count, new SearchFlags().setSkipCache(true).setSkipAggregates(true)
.setSkipHighlighting(true).setFulltext(true), authentication);
List<Urn> policyUrns = result.getEntities().stream().map(SearchEntity::getEntity).collect(Collectors.toList());

if (policyUrns.isEmpty()) {
return new PolicyFetchResult(Collections.emptyList(), 0);
return PolicyFetchResult.EMPTY;
}

// Fetch DataHubPolicyInfo aspects for each urn
Expand All @@ -64,7 +98,7 @@ public PolicyFetchResult fetchPolicies(int start, int count, String query, Authe
.filter(Objects::nonNull)
.map(this::extractPolicy)
.filter(Objects::nonNull)
.collect(Collectors.toList()), result.getNumEntities());
.collect(Collectors.toList()), result.getNumEntities(), result.getScrollId());
}

private Policy extractPolicy(EntityResponse entityResponse) {
Expand All @@ -82,6 +116,10 @@ private Policy extractPolicy(EntityResponse entityResponse) {
public static class PolicyFetchResult {
List<Policy> policies;
int total;
@Nullable
String scrollId;

public static final PolicyFetchResult EMPTY = new PolicyFetchResult(Collections.emptyList(), 0, null);
}

@Value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.linkedin.entity.EnvelopedAspectMap;
import com.linkedin.entity.client.EntityClient;
import com.linkedin.metadata.query.SearchFlags;
import com.linkedin.metadata.search.ScrollResult;
import com.linkedin.metadata.search.SearchEntity;
import com.linkedin.metadata.search.SearchEntityArray;
import com.linkedin.metadata.search.SearchResult;
Expand All @@ -35,6 +36,8 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

Expand Down Expand Up @@ -89,30 +92,58 @@ public void setupTest() throws Exception {
final EnvelopedAspectMap childDomainPolicyAspectMap = new EnvelopedAspectMap();
childDomainPolicyAspectMap.put(DATAHUB_POLICY_INFO_ASPECT_NAME, new EnvelopedAspect().setValue(new Aspect(childDomainPolicy.data())));

final SearchResult policySearchResult = new SearchResult();
policySearchResult.setNumEntities(3);
policySearchResult.setEntities(
new SearchEntityArray(
ImmutableList.of(
new SearchEntity().setEntity(activePolicyUrn),
new SearchEntity().setEntity(inactivePolicyUrn),
new SearchEntity().setEntity(parentDomainPolicyUrn),
new SearchEntity().setEntity(childDomainPolicyUrn)
)
)
);

when(_entityClient.search(eq("dataHubPolicy"), eq(""), isNull(), any(), anyInt(), anyInt(), any(),
eq(new SearchFlags().setFulltext(true)))).thenReturn(policySearchResult);
when(_entityClient.batchGetV2(eq(POLICY_ENTITY_NAME),
eq(ImmutableSet.of(activePolicyUrn, inactivePolicyUrn, parentDomainPolicyUrn, childDomainPolicyUrn)), eq(null), any())).thenReturn(
ImmutableMap.of(
activePolicyUrn, new EntityResponse().setUrn(activePolicyUrn).setAspects(activeAspectMap),
inactivePolicyUrn, new EntityResponse().setUrn(inactivePolicyUrn).setAspects(inactiveAspectMap),
parentDomainPolicyUrn, new EntityResponse().setUrn(parentDomainPolicyUrn).setAspects(parentDomainPolicyAspectMap),
childDomainPolicyUrn, new EntityResponse().setUrn(childDomainPolicyUrn).setAspects(childDomainPolicyAspectMap)
)
);
final ScrollResult policySearchResult1 = new ScrollResult()
.setScrollId("1")
.setNumEntities(4)
.setEntities(
new SearchEntityArray(
ImmutableList.of(new SearchEntity().setEntity(activePolicyUrn))));

final ScrollResult policySearchResult2 = new ScrollResult()
.setScrollId("2")
.setNumEntities(4)
.setEntities(
new SearchEntityArray(
ImmutableList.of(new SearchEntity().setEntity(inactivePolicyUrn))));

final ScrollResult policySearchResult3 = new ScrollResult()
.setScrollId("3")
.setNumEntities(4)
.setEntities(
new SearchEntityArray(
ImmutableList.of(new SearchEntity().setEntity(parentDomainPolicyUrn))));

final ScrollResult policySearchResult4 = new ScrollResult()
.setNumEntities(4)
.setEntities(
new SearchEntityArray(
ImmutableList.of(
new SearchEntity().setEntity(childDomainPolicyUrn))));

when(_entityClient.scrollAcrossEntities(eq(List.of("dataHubPolicy")), eq(""), isNull(), any(), isNull(),
anyInt(), eq(new SearchFlags().setFulltext(true).setSkipAggregates(true).setSkipHighlighting(true).setSkipCache(true)), any()))
.thenReturn(policySearchResult1)
.thenReturn(policySearchResult2)
.thenReturn(policySearchResult3)
.thenReturn(policySearchResult4);

when(_entityClient.batchGetV2(eq(POLICY_ENTITY_NAME), any(), eq(null), any())).thenAnswer(args -> {
Set<Urn> inputUrns = args.getArgument(1);
Urn urn = inputUrns.stream().findFirst().get();

switch (urn.toString()) {
case "urn:li:dataHubPolicy:0":
return Map.of(activePolicyUrn, new EntityResponse().setUrn(activePolicyUrn).setAspects(activeAspectMap));
case "urn:li:dataHubPolicy:1":
return Map.of(inactivePolicyUrn, new EntityResponse().setUrn(inactivePolicyUrn).setAspects(inactiveAspectMap));
case "urn:li:dataHubPolicy:2":
return Map.of(parentDomainPolicyUrn, new EntityResponse().setUrn(parentDomainPolicyUrn).setAspects(parentDomainPolicyAspectMap));
case "urn:li:dataHubPolicy:3":
return Map.of(childDomainPolicyUrn, new EntityResponse().setUrn(childDomainPolicyUrn).setAspects(childDomainPolicyAspectMap));
default:
throw new IllegalStateException();
}
});

final List<Urn> userUrns = ImmutableList.of(Urn.createFromString("urn:li:corpuser:user3"), Urn.createFromString("urn:li:corpuser:user4"));
final List<Urn> groupUrns = ImmutableList.of(Urn.createFromString("urn:li:corpGroup:group3"), Urn.createFromString("urn:li:corpGroup:group4"));
Expand Down Expand Up @@ -146,7 +177,8 @@ childDomainPolicyUrn, new EntityResponse().setUrn(childDomainPolicyUrn).setAspec
_entityClient,
10,
10,
DataHubAuthorizer.AuthorizationMode.DEFAULT
DataHubAuthorizer.AuthorizationMode.DEFAULT,
1 // force pagination logic
);
_dataHubAuthorizer.init(Collections.emptyMap(), createAuthorizerContext(systemAuthentication, _entityClient));
_dataHubAuthorizer.invalidateCache();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ authorization:
defaultAuthorizer:
enabled: ${AUTH_POLICIES_ENABLED:true}
cacheRefreshIntervalSecs: ${POLICY_CACHE_REFRESH_INTERVAL_SECONDS:120}
cachePolicyFetchSize: ${POLICY_CACHE_FETCH_SIZE:1000}
# Enables authorization of reads, writes, and deletes on REST APIs. Defaults to false for backwards compatibility, but should become true down the road
restApiAuthorization: ${REST_API_AUTHORIZATION_ENABLED:false}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ public class DataHubAuthorizerFactory {
@Value("${authorization.defaultAuthorizer.cacheRefreshIntervalSecs}")
private Integer policyCacheRefreshIntervalSeconds;

@Value("${authorization.defaultAuthorizer.cachePolicyFetchSize}")
private Integer policyCacheFetchSize;

@Value("${authorization.defaultAuthorizer.enabled:true}")
private Boolean policiesEnabled;

Expand All @@ -44,6 +47,6 @@ protected DataHubAuthorizer getInstance() {
: DataHubAuthorizer.AuthorizationMode.ALLOW_ALL;

return new DataHubAuthorizer(systemAuthentication, entityClient, 10,
policyCacheRefreshIntervalSeconds, mode);
policyCacheRefreshIntervalSeconds, mode, policyCacheFetchSize);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ public SearchResult searchAcrossEntities(@Nonnull List<String> entities, @Nonnul
*/
@Nonnull
ScrollResult scrollAcrossEntities(@Nonnull List<String> entities, @Nonnull String input,
@Nullable Filter filter, @Nullable String scrollId, @Nonnull String keepAlive, int count, @Nullable SearchFlags searchFlags,
@Nullable Filter filter, @Nullable String scrollId, @Nullable String keepAlive, int count, @Nullable SearchFlags searchFlags,
@Nonnull Authentication authentication)
throws RemoteInvocationException;

Expand Down
Loading

0 comments on commit 23c98ec

Please sign in to comment.