Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(policy): enable support for 10k+ policies #9177

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,15 @@ public CompletableFuture<ListPoliciesResult> get(final DataFetchingEnvironment e
final Integer count = input.getCount() == null ? DEFAULT_COUNT : input.getCount();
final String query = input.getQuery() == null ? DEFAULT_QUERY : input.getQuery();

return CompletableFuture.supplyAsync(() -> {
try {
// First, get all policy Urns.
final PolicyFetcher.PolicyFetchResult policyFetchResult =
_policyFetcher.fetchPolicies(start, count, query, context.getAuthentication());

// Now that we have entities we can bind this to a result.
final ListPoliciesResult result = new ListPoliciesResult();
result.setStart(start);
result.setCount(count);
result.setTotal(policyFetchResult.getTotal());
result.setPolicies(mapEntities(policyFetchResult.getPolicies()));
return result;
} catch (Exception e) {
throw new RuntimeException("Failed to list policies", e);
}
});
return _policyFetcher.fetchPolicies(start, query, count, context.getAuthentication())
.thenApply(policyFetchResult -> {
final ListPoliciesResult result = new ListPoliciesResult();
result.setStart(start);
result.setCount(count);
result.setTotal(policyFetchResult.getTotal());
result.setPolicies(mapEntities(policyFetchResult.getPolicies()));
return result;
});
}
throw new AuthorizationException("Unauthorized to perform this action. Please contact your DataHub administrator.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ public SearchResult searchAcrossEntities(
@Nonnull
@Override
public ScrollResult scrollAcrossEntities(@Nonnull List<String> entities, @Nonnull String input,
@Nullable Filter filter, @Nullable String scrollId, @Nonnull String keepAlive, int count,
@Nullable Filter filter, @Nullable String scrollId, @Nullable String keepAlive, int count,
@Nullable SearchFlags searchFlags, @Nonnull Authentication authentication)
throws RemoteInvocationException {
final SearchFlags finalFlags = searchFlags != null ? searchFlags : new SearchFlags().setFulltext(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,13 @@ public DataHubAuthorizer(
final EntityClient entityClient,
final int delayIntervalSeconds,
final int refreshIntervalSeconds,
final AuthorizationMode mode) {
final AuthorizationMode mode,
final int policyFetchSize) {
_systemAuthentication = Objects.requireNonNull(systemAuthentication);
_mode = Objects.requireNonNull(mode);
_policyEngine = new PolicyEngine(systemAuthentication, Objects.requireNonNull(entityClient));
_policyRefreshRunnable = new PolicyRefreshRunnable(systemAuthentication, new PolicyFetcher(entityClient), _policyCache, readWriteLock.writeLock());
_policyRefreshRunnable = new PolicyRefreshRunnable(systemAuthentication, new PolicyFetcher(entityClient), _policyCache,
readWriteLock.writeLock(), policyFetchSize);
_refreshExecutorService.scheduleAtFixedRate(_policyRefreshRunnable, delayIntervalSeconds, refreshIntervalSeconds, TimeUnit.SECONDS);
}

Expand Down Expand Up @@ -244,29 +246,28 @@ static class PolicyRefreshRunnable implements Runnable {
private final PolicyFetcher _policyFetcher;
private final Map<String, List<DataHubPolicyInfo>> _policyCache;
private final Lock writeLock;
private final int count;

@Override
public void run() {
try {
// Populate new cache and swap.
Map<String, List<DataHubPolicyInfo>> newCache = new HashMap<>();
Integer total = null;
String scrollId = null;

int start = 0;
int count = 30;
int total = 30;

while (start < total) {
while (total == null || scrollId != null) {
try {
final PolicyFetcher.PolicyFetchResult
policyFetchResult = _policyFetcher.fetchPolicies(start, count, _systemAuthentication);
policyFetchResult = _policyFetcher.fetchPolicies(count, scrollId, _systemAuthentication);

addPoliciesToCache(newCache, policyFetchResult.getPolicies());

total = policyFetchResult.getTotal();
start = start + count;
scrollId = policyFetchResult.getScrollId();
} catch (Exception e) {
log.error(
"Failed to retrieve policy urns! Skipping updating policy cache until next refresh. start: {}, count: {}", start, count, e);
"Failed to retrieve policy urns! Skipping updating policy cache until next refresh. count: {}, scrollId: {}", count, scrollId, e);
return;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import com.linkedin.metadata.query.SearchFlags;
import com.linkedin.metadata.query.filter.SortCriterion;
import com.linkedin.metadata.query.filter.SortOrder;
import com.linkedin.metadata.search.ScrollResult;
import com.linkedin.metadata.search.SearchEntity;
import com.linkedin.metadata.search.SearchResult;
import com.linkedin.policy.DataHubPolicyInfo;
import com.linkedin.r2.RemoteInvocationException;
import java.net.URISyntaxException;
Expand All @@ -18,11 +18,14 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import lombok.RequiredArgsConstructor;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;

import javax.annotation.Nullable;

import static com.linkedin.metadata.Constants.DATAHUB_POLICY_INFO_ASPECT_NAME;
import static com.linkedin.metadata.Constants.POLICY_ENTITY_NAME;

Expand All @@ -38,22 +41,53 @@ public class PolicyFetcher {
private static final SortCriterion POLICY_SORT_CRITERION =
new SortCriterion().setField("lastUpdatedTimestamp").setOrder(SortOrder.DESCENDING);

public PolicyFetchResult fetchPolicies(int start, int count, Authentication authentication)
throws RemoteInvocationException, URISyntaxException {
return fetchPolicies(start, count, "", authentication);
/**
* This is to provide a scroll implementation using the start/count api. It is not efficient
* and the scroll native functions should be used instead. This does fix a failure to fetch
* policies when deep pagination happens where there are >10k policies.
* Exists primarily to prevent breaking change to the graphql api.
*/
@Deprecated
public CompletableFuture<PolicyFetchResult> fetchPolicies(int start, String query, int count, Authentication authentication) {
return CompletableFuture.supplyAsync(() -> {
try {
PolicyFetchResult result = PolicyFetchResult.EMPTY;
String scrollId = "";
int fetchedResults = 0;

while (PolicyFetchResult.EMPTY.equals(result) && scrollId != null) {
PolicyFetchResult tmpResult = fetchPolicies(query, count, scrollId.isEmpty() ? null : scrollId, authentication);
fetchedResults += tmpResult.getPolicies().size();
scrollId = tmpResult.getScrollId();
if (fetchedResults > start) {
result = tmpResult;
}
}

return result;
} catch (Exception e) {
throw new RuntimeException("Failed to list policies", e);
}
});
}

public PolicyFetchResult fetchPolicies(int start, int count, String query, Authentication authentication)
public PolicyFetchResult fetchPolicies(int count, @Nullable String scrollId, Authentication authentication)
throws RemoteInvocationException, URISyntaxException {
return fetchPolicies("", count, scrollId, authentication);
david-leifker marked this conversation as resolved.
Show resolved Hide resolved
}

public PolicyFetchResult fetchPolicies(String query, int count, @Nullable String scrollId, Authentication authentication)
throws RemoteInvocationException, URISyntaxException {
log.debug(String.format("Batch fetching policies. start: %s, count: %s ", start, count));
// First fetch all policy urns from start - start + count
SearchResult result =
_entityClient.search(POLICY_ENTITY_NAME, query, null, POLICY_SORT_CRITERION, start, count, authentication,
new SearchFlags().setFulltext(true));
log.debug(String.format("Batch fetching policies. count: %s, scroll: %s", count, scrollId));

// First fetch all policy urns
ScrollResult result = _entityClient.scrollAcrossEntities(List.of(POLICY_ENTITY_NAME), query, null, scrollId,
david-leifker marked this conversation as resolved.
Show resolved Hide resolved
null, count, new SearchFlags().setSkipCache(true).setSkipAggregates(true)
.setSkipHighlighting(true).setFulltext(true), authentication);
List<Urn> policyUrns = result.getEntities().stream().map(SearchEntity::getEntity).collect(Collectors.toList());

if (policyUrns.isEmpty()) {
return new PolicyFetchResult(Collections.emptyList(), 0);
return PolicyFetchResult.EMPTY;
}

// Fetch DataHubPolicyInfo aspects for each urn
Expand All @@ -64,7 +98,7 @@ public PolicyFetchResult fetchPolicies(int start, int count, String query, Authe
.filter(Objects::nonNull)
.map(this::extractPolicy)
.filter(Objects::nonNull)
.collect(Collectors.toList()), result.getNumEntities());
.collect(Collectors.toList()), result.getNumEntities(), result.getScrollId());
}

private Policy extractPolicy(EntityResponse entityResponse) {
Expand All @@ -82,6 +116,10 @@ private Policy extractPolicy(EntityResponse entityResponse) {
public static class PolicyFetchResult {
List<Policy> policies;
int total;
@Nullable
String scrollId;

public static final PolicyFetchResult EMPTY = new PolicyFetchResult(Collections.emptyList(), 0, null);
}

@Value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.linkedin.entity.EnvelopedAspectMap;
import com.linkedin.entity.client.EntityClient;
import com.linkedin.metadata.query.SearchFlags;
import com.linkedin.metadata.search.ScrollResult;
import com.linkedin.metadata.search.SearchEntity;
import com.linkedin.metadata.search.SearchEntityArray;
import com.linkedin.metadata.search.SearchResult;
Expand All @@ -35,6 +36,8 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

Expand Down Expand Up @@ -89,30 +92,58 @@ public void setupTest() throws Exception {
final EnvelopedAspectMap childDomainPolicyAspectMap = new EnvelopedAspectMap();
childDomainPolicyAspectMap.put(DATAHUB_POLICY_INFO_ASPECT_NAME, new EnvelopedAspect().setValue(new Aspect(childDomainPolicy.data())));

final SearchResult policySearchResult = new SearchResult();
policySearchResult.setNumEntities(3);
policySearchResult.setEntities(
new SearchEntityArray(
ImmutableList.of(
new SearchEntity().setEntity(activePolicyUrn),
new SearchEntity().setEntity(inactivePolicyUrn),
new SearchEntity().setEntity(parentDomainPolicyUrn),
new SearchEntity().setEntity(childDomainPolicyUrn)
)
)
);

when(_entityClient.search(eq("dataHubPolicy"), eq(""), isNull(), any(), anyInt(), anyInt(), any(),
eq(new SearchFlags().setFulltext(true)))).thenReturn(policySearchResult);
when(_entityClient.batchGetV2(eq(POLICY_ENTITY_NAME),
eq(ImmutableSet.of(activePolicyUrn, inactivePolicyUrn, parentDomainPolicyUrn, childDomainPolicyUrn)), eq(null), any())).thenReturn(
ImmutableMap.of(
activePolicyUrn, new EntityResponse().setUrn(activePolicyUrn).setAspects(activeAspectMap),
inactivePolicyUrn, new EntityResponse().setUrn(inactivePolicyUrn).setAspects(inactiveAspectMap),
parentDomainPolicyUrn, new EntityResponse().setUrn(parentDomainPolicyUrn).setAspects(parentDomainPolicyAspectMap),
childDomainPolicyUrn, new EntityResponse().setUrn(childDomainPolicyUrn).setAspects(childDomainPolicyAspectMap)
)
);
final ScrollResult policySearchResult1 = new ScrollResult()
.setScrollId("1")
.setNumEntities(4)
.setEntities(
new SearchEntityArray(
ImmutableList.of(new SearchEntity().setEntity(activePolicyUrn))));

final ScrollResult policySearchResult2 = new ScrollResult()
.setScrollId("2")
.setNumEntities(4)
.setEntities(
new SearchEntityArray(
ImmutableList.of(new SearchEntity().setEntity(inactivePolicyUrn))));

final ScrollResult policySearchResult3 = new ScrollResult()
.setScrollId("3")
.setNumEntities(4)
.setEntities(
new SearchEntityArray(
ImmutableList.of(new SearchEntity().setEntity(parentDomainPolicyUrn))));

final ScrollResult policySearchResult4 = new ScrollResult()
.setNumEntities(4)
.setEntities(
new SearchEntityArray(
ImmutableList.of(
new SearchEntity().setEntity(childDomainPolicyUrn))));

when(_entityClient.scrollAcrossEntities(eq(List.of("dataHubPolicy")), eq(""), isNull(), any(), isNull(),
anyInt(), eq(new SearchFlags().setFulltext(true).setSkipAggregates(true).setSkipHighlighting(true).setSkipCache(true)), any()))
.thenReturn(policySearchResult1)
.thenReturn(policySearchResult2)
.thenReturn(policySearchResult3)
.thenReturn(policySearchResult4);

when(_entityClient.batchGetV2(eq(POLICY_ENTITY_NAME), any(), eq(null), any())).thenAnswer(args -> {
Set<Urn> inputUrns = args.getArgument(1);
Urn urn = inputUrns.stream().findFirst().get();

switch (urn.toString()) {
case "urn:li:dataHubPolicy:0":
return Map.of(activePolicyUrn, new EntityResponse().setUrn(activePolicyUrn).setAspects(activeAspectMap));
case "urn:li:dataHubPolicy:1":
return Map.of(inactivePolicyUrn, new EntityResponse().setUrn(inactivePolicyUrn).setAspects(inactiveAspectMap));
case "urn:li:dataHubPolicy:2":
return Map.of(parentDomainPolicyUrn, new EntityResponse().setUrn(parentDomainPolicyUrn).setAspects(parentDomainPolicyAspectMap));
case "urn:li:dataHubPolicy:3":
return Map.of(childDomainPolicyUrn, new EntityResponse().setUrn(childDomainPolicyUrn).setAspects(childDomainPolicyAspectMap));
default:
throw new IllegalStateException();
}
});

final List<Urn> userUrns = ImmutableList.of(Urn.createFromString("urn:li:corpuser:user3"), Urn.createFromString("urn:li:corpuser:user4"));
final List<Urn> groupUrns = ImmutableList.of(Urn.createFromString("urn:li:corpGroup:group3"), Urn.createFromString("urn:li:corpGroup:group4"));
Expand Down Expand Up @@ -146,7 +177,8 @@ childDomainPolicyUrn, new EntityResponse().setUrn(childDomainPolicyUrn).setAspec
_entityClient,
10,
10,
DataHubAuthorizer.AuthorizationMode.DEFAULT
DataHubAuthorizer.AuthorizationMode.DEFAULT,
1 // force pagination logic
);
_dataHubAuthorizer.init(Collections.emptyMap(), createAuthorizerContext(systemAuthentication, _entityClient));
_dataHubAuthorizer.invalidateCache();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ authorization:
defaultAuthorizer:
enabled: ${AUTH_POLICIES_ENABLED:true}
cacheRefreshIntervalSecs: ${POLICY_CACHE_REFRESH_INTERVAL_SECONDS:120}
cachePolicyFetchSize: ${POLICY_CACHE_FETCH_SIZE:1000}
# Enables authorization of reads, writes, and deletes on REST APIs. Defaults to false for backwards compatibility, but should become true down the road
restApiAuthorization: ${REST_API_AUTHORIZATION_ENABLED:false}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ public class DataHubAuthorizerFactory {
@Value("${authorization.defaultAuthorizer.cacheRefreshIntervalSecs}")
private Integer policyCacheRefreshIntervalSeconds;

@Value("${authorization.defaultAuthorizer.cachePolicyFetchSize}")
private Integer policyCacheFetchSize;

@Value("${authorization.defaultAuthorizer.enabled:true}")
private Boolean policiesEnabled;

Expand All @@ -44,6 +47,6 @@ protected DataHubAuthorizer getInstance() {
: DataHubAuthorizer.AuthorizationMode.ALLOW_ALL;

return new DataHubAuthorizer(systemAuthentication, entityClient, 10,
policyCacheRefreshIntervalSeconds, mode);
policyCacheRefreshIntervalSeconds, mode, policyCacheFetchSize);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ public SearchResult searchAcrossEntities(@Nonnull List<String> entities, @Nonnul
*/
@Nonnull
ScrollResult scrollAcrossEntities(@Nonnull List<String> entities, @Nonnull String input,
@Nullable Filter filter, @Nullable String scrollId, @Nonnull String keepAlive, int count, @Nullable SearchFlags searchFlags,
@Nullable Filter filter, @Nullable String scrollId, @Nullable String keepAlive, int count, @Nullable SearchFlags searchFlags,
@Nonnull Authentication authentication)
throws RemoteInvocationException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,11 +482,11 @@ public SearchResult searchAcrossEntities(@Nonnull List<String> entities, @Nonnul
@Nonnull
@Override
public ScrollResult scrollAcrossEntities(@Nonnull List<String> entities, @Nonnull String input,
@Nullable Filter filter, @Nullable String scrollId, @Nonnull String keepAlive, int count,
@Nullable Filter filter, @Nullable String scrollId, @Nullable String keepAlive, int count,
@Nullable SearchFlags searchFlags, @Nonnull Authentication authentication)
throws RemoteInvocationException {
final EntitiesDoScrollAcrossEntitiesRequestBuilder requestBuilder =
ENTITIES_REQUEST_BUILDERS.actionScrollAcrossEntities().inputParam(input).countParam(count).keepAliveParam(keepAlive);
ENTITIES_REQUEST_BUILDERS.actionScrollAcrossEntities().inputParam(input).countParam(count);

if (entities != null) {
requestBuilder.entitiesParam(new StringArray(entities));
Expand All @@ -500,6 +500,9 @@ public ScrollResult scrollAcrossEntities(@Nonnull List<String> entities, @Nonnul
if (searchFlags != null) {
requestBuilder.searchFlagsParam(searchFlags);
}
if (keepAlive != null) {
requestBuilder.keepAliveParam(keepAlive);
}

return sendClientRequest(requestBuilder, authentication).getEntity();
}
Expand Down
Loading