Skip to content

Commit

Permalink
feat(neo4j): neo4j pagination as per v2 scrollApi for related entities (
Browse files Browse the repository at this point in the history
  • Loading branch information
deepgarg-visa committed May 19, 2024
1 parent f5a252c commit a44a549
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.linkedin.common.UrnArrayArray;
import com.linkedin.common.urn.Urn;
import com.linkedin.metadata.aspect.models.graph.Edge;
import com.linkedin.metadata.aspect.models.graph.RelatedEntities;
import com.linkedin.metadata.aspect.models.graph.RelatedEntitiesScrollResult;
import com.linkedin.metadata.aspect.models.graph.RelatedEntity;
import com.linkedin.metadata.graph.EntityLineageResult;
Expand All @@ -28,6 +29,7 @@
import com.linkedin.metadata.query.filter.RelationshipDirection;
import com.linkedin.metadata.query.filter.RelationshipFilter;
import com.linkedin.metadata.query.filter.SortCriterion;
import com.linkedin.metadata.search.elasticsearch.query.request.SearchAfterWrapper;
import com.linkedin.metadata.utils.metrics.MetricUtils;
import com.linkedin.util.Pair;
import io.opentelemetry.extension.annotations.WithSpan;
Expand All @@ -38,6 +40,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.StringJoiner;
Expand Down Expand Up @@ -905,6 +908,99 @@ public RelatedEntitiesScrollResult scrollRelatedEntities(
int count,
@Nullable Long startTimeMillis,
@Nullable Long endTimeMillis) {
throw new IllegalArgumentException("Not implemented");

if (sourceTypes != null && sourceTypes.isEmpty()
|| destinationTypes != null && destinationTypes.isEmpty()) {
return new RelatedEntitiesScrollResult(0, 0, null, Collections.emptyList());
}

final String srcCriteria = filterToCriteria(sourceEntityFilter).trim();
final String destCriteria = filterToCriteria(destinationEntityFilter).trim();
final String edgeCriteria = relationshipFilterToCriteria(relationshipFilter);

final RelationshipDirection relationshipDirection = relationshipFilter.getDirection();
String srcNodeLabel = "";
// Create a URN from the String. Only proceed if srcCriteria is not null or empty
if (srcCriteria != null && !srcCriteria.isEmpty()) {
final String urnValue =
sourceEntityFilter.getOr().get(0).getAnd().get(0).getValue().toString();
try {
final Urn urn = Urn.createFromString(urnValue);
srcNodeLabel = urn.getEntityType();
} catch (URISyntaxException e) {
log.error("Failed to parse URN: {} ", urnValue, e);
}
}
String matchTemplate = "MATCH (src:%s %s)-[r%s %s]-(dest %s)%s";
if (relationshipDirection == RelationshipDirection.INCOMING) {
matchTemplate = "MATCH (src:%s %s)<-[r%s %s]-(dest %s)%s";
} else if (relationshipDirection == RelationshipDirection.OUTGOING) {
matchTemplate = "MATCH (src:%s %s)-[r%s %s]->(dest %s)%s";
}

final String returnNodes =
String.format(
"RETURN dest, src, type(r)"); // Return both related entity and the relationship type.
final String returnCount = "RETURN count(*)"; // For getting the total results.

String relationshipTypeFilter = "";
if (!relationshipTypes.isEmpty()) {
relationshipTypeFilter = ":" + StringUtils.join(relationshipTypes, "|");
}

String whereClause = computeEntityTypeWhereClause(sourceTypes, destinationTypes);

// Build Statement strings
String baseStatementString =
String.format(
matchTemplate,
srcNodeLabel,
srcCriteria,
relationshipTypeFilter,
edgeCriteria,
destCriteria,
whereClause);

log.info(baseStatementString);

final String resultStatementString =
String.format("%s %s SKIP $offset LIMIT $count", baseStatementString, returnNodes);
final String countStatementString = String.format("%s %s", baseStatementString, returnCount);

int offset = 0;
if (Objects.nonNull(scrollId)) {
offset = Integer.valueOf(SearchAfterWrapper.fromScrollId(scrollId).getPitId().toString());
}

// Build Statements
final Statement resultStatement =
new Statement(resultStatementString, ImmutableMap.of("offset", offset, "count", count));
final Statement countStatement = new Statement(countStatementString, Collections.emptyMap());

// Execute Queries
final List<RelatedEntities> relatedEntities =
runQuery(resultStatement)
.list(
record ->
new RelatedEntities(
record.values().get(2).asString(), // Relationship Type
record.values().get(0).asNode().get("urn").asString(),
record.values().get(1).asNode().get("urn").asString(),
relationshipDirection,
null));
final int totalCount = runQuery(countStatement).single().get(0).asInt();
log.info("Total Related Entities: {}", totalCount);
// return new RelatedEntitiesResult(0, relatedEntities.size(), totalCount, relatedEntities);
String nextScrollId = null;
if (relatedEntities.size() == count) {
String pitId = Integer.toString(offset + count);
nextScrollId = new SearchAfterWrapper(null, pitId, 0L).toScrollId();
}
return RelatedEntitiesScrollResult.builder()
.entities(relatedEntities)
.pageSize(relatedEntities.size())
.numResults(totalCount)
.scrollId(nextScrollId)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import java.util.Arrays;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
Expand Down Expand Up @@ -75,16 +75,22 @@ public void handleChangeEvent(

Urn urn = entityChangeEvent.getEntityUrn();
log.info("Business Attribute update hook invoked for urn :" + urn);

fetchRelatedEntities(opContext, urn, batch -> processBatch(opContext, batch), null, 0);
fetchRelatedEntities(
opContext,
urn,
(batch, batchNumber) -> processBatch(opContext, batch, batchNumber),
null,
0,
1);
}

private void fetchRelatedEntities(
@NonNull final OperationContext opContext,
@NonNull final Urn urn,
@NonNull final Consumer<RelatedEntitiesScrollResult> resultConsumer,
@NonNull final BiConsumer<RelatedEntitiesScrollResult, Integer> resultConsumer,
@Nullable String scrollId,
int consumedEntityCount) {
int consumedEntityCount,
int batchNumber) {
GraphRetriever graph = opContext.getRetrieverContext().get().getGraphRetriever();

RelatedEntitiesScrollResult result =
Expand All @@ -100,22 +106,21 @@ private void fetchRelatedEntities(
getRelatedEntitiesBatchSize,
null,
null);
resultConsumer.accept(result);

resultConsumer.accept(result, batchNumber);
consumedEntityCount = consumedEntityCount + result.getEntities().size();
if (result.getScrollId() != null && consumedEntityCount < relatedEntitiesCount) {
batchNumber = batchNumber + 1;
fetchRelatedEntities(
opContext,
urn,
resultConsumer,
result.getScrollId(),
consumedEntityCount + result.getEntities().size());
opContext, urn, resultConsumer, result.getScrollId(), consumedEntityCount, batchNumber);
}
}

private void processBatch(
@NonNull OperationContext opContext, @NonNull RelatedEntitiesScrollResult batch) {
@NonNull OperationContext opContext,
@NonNull RelatedEntitiesScrollResult batch,
int batchNumber) {
AspectRetriever aspectRetriever = opContext.getRetrieverContext().get().getAspectRetriever();

log.info("BA Update Batch {} started", batchNumber);
Set<Urn> entityUrns =
batch.getEntities().stream()
.map(RelatedEntity::getUrn)
Expand Down Expand Up @@ -147,5 +152,6 @@ private void processBatch(
null,
null));
});
log.info("BA Update Batch {} completed", batchNumber);
}
}

0 comments on commit a44a549

Please sign in to comment.