From fbc4e3d1e330eb703cf09b7aa834a1dca592b188 Mon Sep 17 00:00:00 2001
From: David Leifker <david.leifker@acryl.io>
Date: Fri, 20 Sep 2024 14:55:23 -0500
Subject: [PATCH] test(graphql): fix searchFlags in searchAcrossLineage

* added test
---
 .../search/ScrollAcrossLineageResolver.java   |  12 +-
 .../ScrollAcrossLineageResolverTest.java      | 155 ++++++++++++++++++
 2 files changed, 159 insertions(+), 8 deletions(-)
 create mode 100644 datahub-graphql-core/src/test/java/com/linkedin/datahub/graphql/resolvers/search/ScrollAcrossLineageResolverTest.java

diff --git a/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/resolvers/search/ScrollAcrossLineageResolver.java b/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/resolvers/search/ScrollAcrossLineageResolver.java
index 2c058eb60a7ee3..fff1dfee7ef9c1 100644
--- a/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/resolvers/search/ScrollAcrossLineageResolver.java
+++ b/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/resolvers/search/ScrollAcrossLineageResolver.java
@@ -14,6 +14,7 @@
 import com.linkedin.datahub.graphql.generated.ScrollAcrossLineageResults;
 import com.linkedin.datahub.graphql.resolvers.ResolverUtils;
 import com.linkedin.datahub.graphql.types.common.mappers.LineageFlagsInputMapper;
+import com.linkedin.datahub.graphql.types.common.mappers.SearchFlagsInputMapper;
 import com.linkedin.datahub.graphql.types.entitytype.EntityTypeMapper;
 import com.linkedin.datahub.graphql.types.mappers.UrnScrollAcrossLineageResultsMapper;
 import com.linkedin.entity.client.EntityClient;
@@ -89,7 +90,6 @@ public CompletableFuture<ScrollAcrossLineageResults> get(DataFetchingEnvironment
     if (lineageFlags.getEndTimeMillis() == null && endTimeMillis != null) {
       lineageFlags.setEndTimeMillis(endTimeMillis);
     }
-    ;
 
     com.linkedin.metadata.graph.LineageDirection resolvedDirection =
         com.linkedin.metadata.graph.LineageDirection.valueOf(lineageDirection.toString());
@@ -107,17 +107,13 @@ public CompletableFuture<ScrollAcrossLineageResults> get(DataFetchingEnvironment
                 count);
 
             final SearchFlags searchFlags;
-            final com.linkedin.datahub.graphql.generated.SearchFlags inputFlags =
-                input.getSearchFlags();
+            com.linkedin.datahub.graphql.generated.SearchFlags inputFlags = input.getSearchFlags();
             if (inputFlags != null) {
-              searchFlags =
-                  new SearchFlags()
-                      .setSkipCache(inputFlags.getSkipCache())
-                      .setFulltext(inputFlags.getFulltext())
-                      .setMaxAggValues(inputFlags.getMaxAggValues());
+              searchFlags = SearchFlagsInputMapper.INSTANCE.apply(context, inputFlags);
             } else {
               searchFlags = null;
             }
+
             return UrnScrollAcrossLineageResultsMapper.map(
                 context,
                 _entityClient.scrollAcrossLineage(
diff --git a/datahub-graphql-core/src/test/java/com/linkedin/datahub/graphql/resolvers/search/ScrollAcrossLineageResolverTest.java b/datahub-graphql-core/src/test/java/com/linkedin/datahub/graphql/resolvers/search/ScrollAcrossLineageResolverTest.java
new file mode 100644
index 00000000000000..a12f593253b533
--- /dev/null
+++ b/datahub-graphql-core/src/test/java/com/linkedin/datahub/graphql/resolvers/search/ScrollAcrossLineageResolverTest.java
@@ -0,0 +1,155 @@
+package com.linkedin.datahub.graphql.resolvers.search;
+
+import static com.linkedin.datahub.graphql.TestUtils.getMockAllowContext;
+import static org.mockito.ArgumentMatchers.nullable;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyList;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertFalse;
+import static org.testng.Assert.assertTrue;
+
+import com.datahub.authentication.Authentication;
+import com.linkedin.common.UrnArrayArray;
+import com.linkedin.common.urn.UrnUtils;
+import com.linkedin.data.schema.annotation.PathSpecBasedSchemaAnnotationVisitor;
+import com.linkedin.datahub.graphql.QueryContext;
+import com.linkedin.datahub.graphql.generated.EntityType;
+import com.linkedin.datahub.graphql.generated.LineageDirection;
+import com.linkedin.datahub.graphql.generated.ScrollAcrossLineageInput;
+import com.linkedin.datahub.graphql.generated.ScrollAcrossLineageResults;
+import com.linkedin.datahub.graphql.generated.SearchAcrossLineageResult;
+import com.linkedin.datahub.graphql.generated.SearchFlags;
+import com.linkedin.entity.client.EntityClient;
+import com.linkedin.metadata.models.registry.ConfigEntityRegistry;
+import com.linkedin.metadata.models.registry.EntityRegistry;
+import com.linkedin.metadata.search.AggregationMetadataArray;
+import com.linkedin.metadata.search.LineageScrollResult;
+import com.linkedin.metadata.search.LineageSearchEntity;
+import com.linkedin.metadata.search.LineageSearchEntityArray;
+import com.linkedin.metadata.search.MatchedFieldArray;
+import com.linkedin.metadata.search.SearchResultMetadata;
+import graphql.schema.DataFetchingEnvironment;
+import io.datahubproject.metadata.context.OperationContext;
+import java.io.InputStream;
+import java.util.Collections;
+import java.util.List;
+import org.mockito.ArgumentCaptor;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.BeforeTest;
+import org.testng.annotations.Test;
+
+public class ScrollAcrossLineageResolverTest {
+  private static final String SOURCE_URN_STRING =
+      "urn:li:dataset:(urn:li:dataPlatform:foo,bar,PROD)";
+  private static final String TARGET_URN_STRING =
+      "urn:li:dataset:(urn:li:dataPlatform:foo,baz,PROD)";
+  private static final String QUERY = "";
+  private static final int START = 0;
+  private static final int COUNT = 10;
+  private static final Long START_TIMESTAMP_MILLIS = 0L;
+  private static final Long END_TIMESTAMP_MILLIS = 1000L;
+  private EntityClient _entityClient;
+  private DataFetchingEnvironment _dataFetchingEnvironment;
+  private Authentication _authentication;
+  private ScrollAcrossLineageResolver _resolver;
+
+  @BeforeTest
+  public void disableAssert() {
+    PathSpecBasedSchemaAnnotationVisitor.class
+        .getClassLoader()
+        .setClassAssertionStatus(PathSpecBasedSchemaAnnotationVisitor.class.getName(), false);
+  }
+
+  @BeforeMethod
+  public void setupTest() {
+    _entityClient = mock(EntityClient.class);
+    _dataFetchingEnvironment = mock(DataFetchingEnvironment.class);
+    _authentication = mock(Authentication.class);
+    _resolver = new ScrollAcrossLineageResolver(_entityClient);
+  }
+
+  @Test
+  public void testAllEntitiesInitialization() {
+    InputStream inputStream = ClassLoader.getSystemResourceAsStream("entity-registry.yml");
+    EntityRegistry entityRegistry = new ConfigEntityRegistry(inputStream);
+    SearchAcrossLineageResolver resolver =
+        new SearchAcrossLineageResolver(_entityClient, entityRegistry);
+    assertTrue(resolver._allEntities.contains("dataset"));
+    assertTrue(resolver._allEntities.contains("dataFlow"));
+    // Test for case sensitivity
+    assertFalse(resolver._allEntities.contains("dataflow"));
+  }
+
+  @Test
+  public void testSearchAcrossLineage() throws Exception {
+    final QueryContext mockContext = getMockAllowContext();
+    when(mockContext.getAuthentication()).thenReturn(_authentication);
+
+    when(_dataFetchingEnvironment.getContext()).thenReturn(mockContext);
+
+    final SearchFlags searchFlags = new SearchFlags();
+    searchFlags.setFulltext(true);
+
+    final ScrollAcrossLineageInput input = new ScrollAcrossLineageInput();
+    input.setCount(COUNT);
+    input.setDirection(LineageDirection.DOWNSTREAM);
+    input.setOrFilters(Collections.emptyList());
+    input.setQuery(QUERY);
+    input.setTypes(Collections.emptyList());
+    input.setStartTimeMillis(START_TIMESTAMP_MILLIS);
+    input.setEndTimeMillis(END_TIMESTAMP_MILLIS);
+    input.setUrn(SOURCE_URN_STRING);
+    input.setSearchFlags(searchFlags);
+    when(_dataFetchingEnvironment.getArgument(eq("input"))).thenReturn(input);
+
+    final LineageScrollResult lineageSearchResult = new LineageScrollResult();
+    lineageSearchResult.setNumEntities(1);
+    lineageSearchResult.setPageSize(10);
+
+    final SearchResultMetadata searchResultMetadata = new SearchResultMetadata();
+    searchResultMetadata.setAggregations(new AggregationMetadataArray());
+    lineageSearchResult.setMetadata(searchResultMetadata);
+
+    final LineageSearchEntity lineageSearchEntity = new LineageSearchEntity();
+    lineageSearchEntity.setEntity(UrnUtils.getUrn(TARGET_URN_STRING));
+    lineageSearchEntity.setScore(15.0);
+    lineageSearchEntity.setDegree(1);
+    lineageSearchEntity.setMatchedFields(new MatchedFieldArray());
+    lineageSearchEntity.setPaths(new UrnArrayArray());
+    lineageSearchResult.setEntities(new LineageSearchEntityArray(lineageSearchEntity));
+    ArgumentCaptor<OperationContext> opContext = ArgumentCaptor.forClass(OperationContext.class);
+
+    when(_entityClient.scrollAcrossLineage(
+            opContext.capture(),
+            eq(UrnUtils.getUrn(SOURCE_URN_STRING)),
+            eq(com.linkedin.metadata.graph.LineageDirection.DOWNSTREAM),
+            anyList(),
+            eq(QUERY),
+            eq(null),
+            any(),
+            eq(null),
+            nullable(String.class),
+            nullable(String.class),
+            eq(COUNT)))
+        .thenReturn(lineageSearchResult);
+
+    final ScrollAcrossLineageResults results = _resolver.get(_dataFetchingEnvironment).join();
+    assertEquals(results.getCount(), 10);
+    assertEquals(results.getTotal(), 1);
+    assertEquals(
+        opContext.getValue().getSearchContext().getLineageFlags().getStartTimeMillis(),
+        START_TIMESTAMP_MILLIS);
+    assertEquals(
+        opContext.getValue().getSearchContext().getLineageFlags().getEndTimeMillis(),
+        END_TIMESTAMP_MILLIS);
+
+    final List<SearchAcrossLineageResult> entities = results.getSearchResults();
+    assertEquals(entities.size(), 1);
+    final SearchAcrossLineageResult entity = entities.get(0);
+    assertEquals(entity.getEntity().getUrn(), TARGET_URN_STRING);
+    assertEquals(entity.getEntity().getType(), EntityType.DATASET);
+  }
+}