Skip to content

Commit

Permalink
Added relationships APIs to V3. Added these generic APIs to V3 swagge…
Browse files Browse the repository at this point in the history
…r doc. (#10939)
  • Loading branch information
ajoymajumdar authored Jul 18, 2024
1 parent a7ae99c commit 8266b02
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 224 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@
import io.swagger.v3.oas.annotations.OpenAPIDefinition;
import io.swagger.v3.oas.annotations.info.Info;
import io.swagger.v3.oas.annotations.servers.Server;
import io.swagger.v3.oas.models.Components;
import io.swagger.v3.oas.models.OpenAPI;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.springdoc.core.models.GroupedOpenApi;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
Expand All @@ -38,8 +43,6 @@ public class SpringWebConfig implements WebMvcConfigurer {
private static final Set<String> V1_PACKAGES = Set.of("io.datahubproject.openapi.v1");
private static final Set<String> V2_PACKAGES = Set.of("io.datahubproject.openapi.v2");
private static final Set<String> V3_PACKAGES = Set.of("io.datahubproject.openapi.v3");
private static final Set<String> SCHEMA_REGISTRY_PACKAGES =
Set.of("io.datahubproject.openapi.schema.registry");

private static final Set<String> OPENLINEAGE_PACKAGES =
Set.of("io.datahubproject.openapi.openlineage");
Expand Down Expand Up @@ -74,14 +77,31 @@ public void addFormatters(FormatterRegistry registry) {
public GroupedOpenApi v3OpenApiGroup(final EntityRegistry entityRegistry) {
return GroupedOpenApi.builder()
.group("10-openapi-v3")
.displayName("DataHub Entities v3 (OpenAPI)")
.displayName("DataHub v3 (OpenAPI)")
.addOpenApiCustomizer(
openApi -> {
OpenAPI v3OpenApi = OpenAPIV3Generator.generateOpenApiSpec(entityRegistry);
openApi.setInfo(v3OpenApi.getInfo());
openApi.setTags(Collections.emptyList());
openApi.setPaths(v3OpenApi.getPaths());
openApi.setComponents(v3OpenApi.getComponents());
openApi.getPaths().putAll(v3OpenApi.getPaths());
// Merge components. Swagger does not provide append method to add components.
final Components components = new Components();
final Components oComponents = openApi.getComponents();
final Components v3Components = v3OpenApi.getComponents();
components
.callbacks(concat(oComponents::getCallbacks, v3Components::getCallbacks))
.examples(concat(oComponents::getExamples, v3Components::getExamples))
.extensions(concat(oComponents::getExtensions, v3Components::getExtensions))
.headers(concat(oComponents::getHeaders, v3Components::getHeaders))
.links(concat(oComponents::getLinks, v3Components::getLinks))
.parameters(concat(oComponents::getParameters, v3Components::getParameters))
.requestBodies(
concat(oComponents::getRequestBodies, v3Components::getRequestBodies))
.responses(concat(oComponents::getResponses, v3Components::getResponses))
.schemas(concat(oComponents::getSchemas, v3Components::getSchemas))
.securitySchemes(
concat(oComponents::getSecuritySchemes, v3Components::getSecuritySchemes));
openApi.setComponents(components);
})
.packagesToScan(V3_PACKAGES.toArray(String[]::new))
.build();
Expand Down Expand Up @@ -122,4 +142,14 @@ public GroupedOpenApi openlineageOpenApiGroup() {
.packagesToScan(OPENLINEAGE_PACKAGES.toArray(String[]::new))
.build();
}

/** Concatenates two maps. */
private <K, V> Map<K, V> concat(Supplier<Map<K, V>> a, Supplier<Map<K, V>> b) {
return a.get() == null
? b.get()
: b.get() == null
? a.get()
: Stream.concat(a.get().entrySet().stream(), b.get().entrySet().stream())
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
package io.datahubproject.openapi.controller;

import static com.linkedin.metadata.authorization.ApiGroup.RELATIONSHIP;
import static com.linkedin.metadata.authorization.ApiOperation.READ;

import com.datahub.authentication.Authentication;
import com.datahub.authentication.AuthenticationContext;
import com.datahub.authorization.AuthUtil;
import com.datahub.authorization.AuthorizerChain;
import com.linkedin.common.urn.Urn;
import com.linkedin.common.urn.UrnUtils;
import com.linkedin.metadata.aspect.models.graph.Edge;
import com.linkedin.metadata.aspect.models.graph.RelatedEntities;
import com.linkedin.metadata.aspect.models.graph.RelatedEntitiesScrollResult;
import com.linkedin.metadata.graph.elastic.ElasticSearchGraphService;
import com.linkedin.metadata.models.registry.EntityRegistry;
import com.linkedin.metadata.query.filter.RelationshipDirection;
import com.linkedin.metadata.query.filter.RelationshipFilter;
import com.linkedin.metadata.search.utils.QueryUtils;
import io.datahubproject.openapi.exception.UnauthorizedException;
import io.datahubproject.openapi.models.GenericScrollResult;
import io.datahubproject.openapi.v2.models.GenericRelationship;
import io.swagger.v3.oas.annotations.Operation;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;

public abstract class GenericRelationshipController {

@Autowired private EntityRegistry entityRegistry;
@Autowired private ElasticSearchGraphService graphService;
@Autowired private AuthorizerChain authorizationChain;

/**
* Returns relationship edges by type
*
* @param relationshipType the relationship type
* @param count number of results
* @param scrollId scrolling id
* @return list of relation edges
*/
@GetMapping(value = "/{relationshipType}", produces = MediaType.APPLICATION_JSON_VALUE)
@Operation(summary = "Scroll relationships of the given type.")
public ResponseEntity<GenericScrollResult<GenericRelationship>> getRelationshipsByType(
@PathVariable("relationshipType") String relationshipType,
@RequestParam(value = "count", defaultValue = "10") Integer count,
@RequestParam(value = "scrollId", required = false) String scrollId) {

Authentication authentication = AuthenticationContext.getAuthentication();
if (!AuthUtil.isAPIAuthorized(authentication, authorizationChain, RELATIONSHIP, READ)) {
throw new UnauthorizedException(
authentication.getActor().toUrnStr()
+ " is unauthorized to "
+ READ
+ " "
+ RELATIONSHIP);
}

RelatedEntitiesScrollResult result =
graphService.scrollRelatedEntities(
null,
null,
null,
null,
List.of(relationshipType),
new RelationshipFilter().setDirection(RelationshipDirection.UNDIRECTED),
Edge.EDGE_SORT_CRITERION,
scrollId,
count,
null,
null);

if (!AuthUtil.isAPIAuthorizedUrns(
authentication,
authorizationChain,
RELATIONSHIP,
READ,
result.getEntities().stream()
.flatMap(
edge ->
Stream.of(
UrnUtils.getUrn(edge.getSourceUrn()),
UrnUtils.getUrn(edge.getDestinationUrn())))
.collect(Collectors.toSet()))) {
throw new UnauthorizedException(
authentication.getActor().toUrnStr()
+ " is unauthorized to "
+ READ
+ " "
+ RELATIONSHIP);
}

return ResponseEntity.ok(
GenericScrollResult.<GenericRelationship>builder()
.results(toGenericRelationships(result.getEntities()))
.scrollId(result.getScrollId())
.build());
}

/**
* Returns edges for a given urn
*
* @param relationshipTypes types of edges
* @param direction direction of the edges
* @param count number of results
* @param scrollId scroll id
* @return urn edges
*/
@GetMapping(value = "/{entityName}/{entityUrn}", produces = MediaType.APPLICATION_JSON_VALUE)
@Operation(summary = "Scroll relationships from a given entity.")
public ResponseEntity<GenericScrollResult<GenericRelationship>> getRelationshipsByEntity(
@PathVariable("entityName") String entityName,
@PathVariable("entityUrn") String entityUrn,
@RequestParam(value = "relationshipType[]", required = false, defaultValue = "*")
String[] relationshipTypes,
@RequestParam(value = "direction", defaultValue = "OUTGOING") String direction,
@RequestParam(value = "count", defaultValue = "10") Integer count,
@RequestParam(value = "scrollId", required = false) String scrollId) {

final RelatedEntitiesScrollResult result;

Authentication authentication = AuthenticationContext.getAuthentication();
if (!AuthUtil.isAPIAuthorizedUrns(
authentication,
authorizationChain,
RELATIONSHIP,
READ,
List.of(UrnUtils.getUrn(entityUrn)))) {
throw new UnauthorizedException(
authentication.getActor().toUrnStr()
+ " is unauthorized to "
+ READ
+ " "
+ RELATIONSHIP);
}

switch (RelationshipDirection.valueOf(direction.toUpperCase())) {
case INCOMING -> result =
graphService.scrollRelatedEntities(
null,
null,
null,
null,
relationshipTypes.length > 0 && !relationshipTypes[0].equals("*")
? Arrays.stream(relationshipTypes).toList()
: List.of(),
new RelationshipFilter()
.setDirection(RelationshipDirection.UNDIRECTED)
.setOr(QueryUtils.newFilter("destination.urn", entityUrn).getOr()),
Edge.EDGE_SORT_CRITERION,
scrollId,
count,
null,
null);
case OUTGOING -> result =
graphService.scrollRelatedEntities(
null,
null,
null,
null,
relationshipTypes.length > 0 && !relationshipTypes[0].equals("*")
? Arrays.stream(relationshipTypes).toList()
: List.of(),
new RelationshipFilter()
.setDirection(RelationshipDirection.UNDIRECTED)
.setOr(QueryUtils.newFilter("source.urn", entityUrn).getOr()),
Edge.EDGE_SORT_CRITERION,
scrollId,
count,
null,
null);
default -> throw new IllegalArgumentException("Direction must be INCOMING or OUTGOING");
}

if (!AuthUtil.isAPIAuthorizedUrns(
authentication,
authorizationChain,
RELATIONSHIP,
READ,
result.getEntities().stream()
.flatMap(
edge ->
Stream.of(
UrnUtils.getUrn(edge.getSourceUrn()),
UrnUtils.getUrn(edge.getDestinationUrn())))
.collect(Collectors.toSet()))) {
throw new UnauthorizedException(
authentication.getActor().toUrnStr()
+ " is unauthorized to "
+ READ
+ " "
+ RELATIONSHIP);
}

return ResponseEntity.ok(
GenericScrollResult.<GenericRelationship>builder()
.results(toGenericRelationships(result.getEntities()))
.scrollId(result.getScrollId())
.build());
}

private List<GenericRelationship> toGenericRelationships(List<RelatedEntities> relatedEntities) {
return relatedEntities.stream()
.map(
result -> {
Urn source = UrnUtils.getUrn(result.getSourceUrn());
Urn dest = UrnUtils.getUrn(result.getDestinationUrn());
return GenericRelationship.builder()
.relationshipType(result.getRelationshipType())
.source(GenericRelationship.GenericNode.fromUrn(source))
.destination(GenericRelationship.GenericNode.fromUrn(dest))
.build();
})
.collect(Collectors.toList());
}
}
Loading

0 comments on commit 8266b02

Please sign in to comment.