Merge pull request #52 from GoogleCloudPlatform/main
fix: eager source row fetching logic (GoogleCloudPlatform#2071)
taherkl authored Jan 8, 2025
2 parents ef5ae8d + 9d39b07 commit c6442ba
Showing 4 changed files with 172 additions and 43 deletions.
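
What changed, in short: the template used to materialize a source's rows up front only when no active target of that source defined transformations (ModelUtils.targetsHaveTransforms scanned the whole spec), so a source feeding both transformed and transform-free targets through a SQL-pushdown-capable provider left nullableSourceBeamRows null for the steps that needed it. The fix computes each source's active targets once, skips sources with none, and materializes the shared rows as soon as any target lacks transformations. A minimal before/after sketch of the predicate, condensed from the first file's hunks (variable names as in the diff):

    // Before: fetch shared source rows only when NO target defines transformations
    boolean targetsHaveTransforms =
        ModelUtils.targetsHaveTransforms(importSpecification, source);
    if (!targetsHaveTransforms || !provider.supportsSqlPushDown()) {
      // query source rows once and reuse them across this source's targets
    }

    // After: fetch them as soon as AT LEAST ONE active target defines no
    // transformations, or the provider cannot push SQL down
    if (!provider.supportsSqlPushDown()
        || activeSourceTargets.stream()
            .anyMatch(target -> !ModelUtils.targetHasTransforms(target))) {
      // query source rows once and reuse them across this source's targets
    }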
GoogleCloudToNeo4j.java
@@ -81,7 +81,6 @@
import org.neo4j.importer.v1.targets.RelationshipTarget;
import org.neo4j.importer.v1.targets.Target;
import org.neo4j.importer.v1.targets.TargetType;
import org.neo4j.importer.v1.targets.Targets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -275,6 +274,10 @@ public void run() {
Entry::getKey, mapping(Entry::getValue, Collectors.<PCollection<?>>toList())));
var sourceRows = new ArrayList<PCollection<?>>(importSpecification.getSources().size());
var targetRows = new HashMap<TargetType, List<PCollection<?>>>(targetCount());
var allActiveTargets =
importSpecification.getTargets().getAll().stream()
.filter(Target::isActive)
.collect(toList());
var allActiveNodeTargets =
importSpecification.getTargets().getNodes().stream()
.filter(Target::isActive)
@@ -283,40 +286,42 @@ public void run() {
////////////////////////////
// Process sources
for (var source : importSpecification.getSources()) {
String sourceName = source.getName();
var activeSourceTargets =
allActiveTargets.stream()
.filter(target -> target.getSource().equals(sourceName))
.collect(toList());
if (activeSourceTargets.isEmpty()) {
return;
}

// get provider implementation for source
Provider provider = ProviderFactory.of(source, targetSequence);
provider.configure(optionsParams);
PCollection<Row> sourceMetadata =
pipeline.apply(
String.format("Metadata for source %s", source.getName()), provider.queryMetadata());
String.format("Metadata for source %s", sourceName), provider.queryMetadata());
sourceRows.add(sourceMetadata);
Schema sourceBeamSchema = sourceMetadata.getSchema();
processingQueue.addToQueue(
ArtifactType.source, false, source.getName(), defaultActionContext, sourceMetadata);
PCollection<Row> nullableSourceBeamRows = null;
ArtifactType.source, false, sourceName, defaultActionContext, sourceMetadata);

////////////////////////////
// Optimization: if single source query, reuse this PCollection rather than write it again
boolean targetsHaveTransforms = ModelUtils.targetsHaveTransforms(importSpecification, source);
if (!targetsHaveTransforms || !provider.supportsSqlPushDown()) {
// Optimization: if some of the current source's targets either
// - do not alter the source query (i.e. define no transformations)
// - or the source provider does not support SQL pushdown
// then the source PCollection can be defined here and reused across all the relevant targets
PCollection<Row> nullableSourceBeamRows = null;
if (!provider.supportsSqlPushDown()
|| activeSourceTargets.stream()
.anyMatch(target -> !ModelUtils.targetHasTransforms(target))) {
nullableSourceBeamRows =
pipeline
.apply("Query " + source.getName(), provider.querySourceBeamRows(sourceBeamSchema))
.apply("Query " + sourceName, provider.querySourceBeamRows(sourceBeamSchema))
.setRowSchema(sourceBeamSchema);
}

String sourceName = source.getName();

////////////////////////////
// Optimization: if we're not mixing nodes and edges, then run in parallel
// For relationship updates, max workers should be max 2. This parameter is job configurable.

////////////////////////////
// No optimization possible so write nodes then edges.
// Write node targets
List<NodeTarget> nodeTargets =
getActiveTargetsBySourceAndType(importSpecification, sourceName, TargetType.NODE);
List<NodeTarget> nodeTargets = getTargetsByType(activeSourceTargets, TargetType.NODE);
for (NodeTarget target : nodeTargets) {
TargetQuerySpec targetQuerySpec =
new TargetQuerySpecBuilder()
@@ -327,7 +332,7 @@ String nodeStepDescription =
String nodeStepDescription =
targetSequence.getSequenceNumber(target)
+ ": "
+ source.getName()
+ sourceName
+ "->"
+ target.getName()
+ " nodes";
@@ -371,7 +376,7 @@ public void run() {
////////////////////////////
// Write relationship targets
List<RelationshipTarget> relationshipTargets =
getActiveTargetsBySourceAndType(importSpecification, sourceName, TargetType.RELATIONSHIP);
getTargetsByType(activeSourceTargets, TargetType.RELATIONSHIP);
for (var target : relationshipTargets) {
var targetQuerySpec =
new TargetQuerySpecBuilder()
@@ -383,14 +388,14 @@ public void run() {
.endNodeTarget(
findNodeTargetByName(allActiveNodeTargets, target.getEndNodeReference()))
.build();
PCollection<Row> preInsertBeamRows;
String relationshipStepDescription =
targetSequence.getSequenceNumber(target)
+ ": "
+ source.getName()
+ sourceName
+ "->"
+ target.getName()
+ " edges";
PCollection<Row> preInsertBeamRows;
if (ModelUtils.targetHasTransforms(target)) {
preInsertBeamRows =
pipeline.apply(
@@ -439,12 +444,12 @@ public void run() {
////////////////////////////
// Custom query targets
List<CustomQueryTarget> customQueryTargets =
getActiveTargetsBySourceAndType(importSpecification, sourceName, TargetType.QUERY);
getTargetsByType(activeSourceTargets, TargetType.QUERY);
for (Target target : customQueryTargets) {
String customQueryStepDescription =
targetSequence.getSequenceNumber(target)
+ ": "
+ source.getName()
+ sourceName
+ "->"
+ target.getName()
+ " (custom query)";
@@ -455,6 +460,8 @@ public void run() {
processingQueue.waitOnCollections(
target.getDependencies(), customQueryStepDescription));

// note: nullableSourceBeamRows is guaranteed to be non-null here since custom query targets
// cannot define source transformations
PCollection<Row> blockingReturn =
nullableSourceBeamRows
.apply(
@@ -581,15 +588,10 @@ private static NodeTarget findNodeTargetByName(List<NodeTarget> nodes, String re
}

@SuppressWarnings("unchecked")
private <T extends Target> List<T> getActiveTargetsBySourceAndType(
ImportSpecification importSpecification, String sourceName, TargetType targetType) {
Targets targets = importSpecification.getTargets();
return targets.getAll().stream()
.filter(
target ->
target.getTargetType() == targetType
&& target.isActive()
&& sourceName.equals(target.getSource()))
private <T extends Target> List<T> getTargetsByType(
List<Target> activeSourceTargets, TargetType targetType) {
return activeSourceTargets.stream()
.filter(target -> target.getTargetType() == targetType)
.map(target -> (T) target)
.collect(toList());
}
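The guarantee noted in the custom-query branch above follows from ModelUtils (next file): targetHasTransforms returns false for TargetType.QUERY, so any source feeding a custom query target has at least one transform-free active target, the anyMatch branch fires, and nullableSourceBeamRows is materialized before the custom-query step applies it. Sketched, with the method body abridged:

    // ModelUtils (retained method): custom query targets never count as
    // transformed, so they always force the shared source rows to be fetched.
    public static boolean targetHasTransforms(Target target) {
      if (target.getTargetType() == TargetType.QUERY) {
        return false;
      }
      // ... entity targets: true only if source transformations are defined
    }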
ModelUtils.java
@@ -21,7 +21,6 @@
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -38,8 +37,6 @@
import net.sf.jsqlparser.statement.select.PlainSelect;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.commons.lang3.StringUtils;
import org.neo4j.importer.v1.ImportSpecification;
import org.neo4j.importer.v1.sources.Source;
import org.neo4j.importer.v1.targets.Aggregation;
import org.neo4j.importer.v1.targets.EntityTarget;
import org.neo4j.importer.v1.targets.NodeTarget;
@@ -58,12 +55,6 @@ public class ModelUtils {
private static final Pattern variablePattern = Pattern.compile("(\\$([a-zA-Z0-9_]+))");
private static final Logger LOG = LoggerFactory.getLogger(ModelUtils.class);

public static boolean targetsHaveTransforms(ImportSpecification jobSpec, Source source) {
return jobSpec.getTargets().getAll().stream()
.filter(target -> target.isActive() && Objects.equals(target.getSource(), source.getName()))
.anyMatch(ModelUtils::targetHasTransforms);
}

public static boolean targetHasTransforms(Target target) {
if (target.getTargetType() == TargetType.QUERY) {
return false;
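With the pipeline now filtering each source's active targets once and checking targetHasTransforms per target, the spec-wide targetsHaveTransforms(ImportSpecification, Source) overload has no remaining callers, which is why this hunk also drops the Objects, ImportSpecification, and Source imports.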
SyntheticFieldsIT.java
@@ -0,0 +1,91 @@
/*
* Copyright (C) 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package com.google.cloud.teleport.v2.neo4j.templates;

import static com.google.cloud.teleport.v2.neo4j.templates.Connections.jsonBasicPayload;
import static com.google.cloud.teleport.v2.neo4j.templates.Resources.contentOf;
import static org.apache.beam.it.truthmatchers.PipelineAsserts.assertThatResult;

import com.google.cloud.teleport.metadata.TemplateIntegrationTest;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import org.apache.beam.it.common.PipelineLauncher.LaunchConfig;
import org.apache.beam.it.common.PipelineLauncher.LaunchInfo;
import org.apache.beam.it.common.PipelineOperator.Result;
import org.apache.beam.it.common.TestProperties;
import org.apache.beam.it.common.utils.ResourceManagerUtils;
import org.apache.beam.it.gcp.TemplateTestBase;
import org.apache.beam.it.neo4j.Neo4jResourceManager;
import org.apache.beam.it.neo4j.conditions.Neo4jQueryCheck;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@Category(TemplateIntegrationTest.class)
@TemplateIntegrationTest(GoogleCloudToNeo4j.class)
@RunWith(JUnit4.class)
public class SyntheticFieldsIT extends TemplateTestBase {

private Neo4jResourceManager neo4jClient;

@Before
public void setup() {
neo4jClient =
Neo4jResourceManager.builder(testName)
.setAdminPassword("letmein!")
.setHost(TestProperties.hostIp())
.build();
}

@After
public void tearDown() {
ResourceManagerUtils.cleanResources(neo4jClient);
}

@Test
// TODO: generate bigquery data set once import-spec supports value interpolation
public void importsStackoverflowUsers() throws IOException {
String spec = contentOf("/testing-specs/synthetic-fields/spec.yml");
gcsClient.createArtifact("spec.yml", spec);
gcsClient.createArtifact("neo4j-connection.json", jsonBasicPayload(neo4jClient));

LaunchConfig.Builder options =
LaunchConfig.builder(testName, specPath)
.addParameter("jobSpecUri", getGcsPath("spec.yml"))
.addParameter("neo4jConnectionUri", getGcsPath("neo4j-connection.json"));
LaunchInfo info = launchTemplate(options);

Result result =
pipelineOperator()
.waitForCondition(
createConfig(info),
Neo4jQueryCheck.builder(neo4jClient)
.setQuery("MATCH (u:User) RETURN count(u) AS count")
.setExpectedResult(List.of(Map.of("count", 10L)))
.build(),
Neo4jQueryCheck.builder(neo4jClient)
.setQuery(
"MATCH (l:Letter) WITH DISTINCT toUpper(l.char) AS char ORDER BY char ASC RETURN collect(char) AS chars")
.setExpectedResult(
List.of(Map.of("chars", List.of("A", "C", "G", "I", "J", "T", "W"))))
.build());
assertThatResult(result).meetsConditions();
}
}
testing-specs/synthetic-fields/spec.yml
@@ -0,0 +1,45 @@
version: '1'
sources:
- type: bigquery
name: so_users
# once value interpolation is supported by import-spec, this public data set query
# will be replaced by a query against a generated test bigquery data set
query: |-
SELECT id, display_name
FROM
`bigquery-public-data.stackoverflow.users`
ORDER BY id ASC
LIMIT 10
targets:
nodes:
- name: users
source: so_users
write_mode: merge
labels: [User]
source_transformations:
aggregations:
- expression: max(id)
field_name: max_id
properties:
- source_field: id
target_property: id
- source_field: display_name
target_property: name
- source_field: max_id
target_property: max_id
schema:
key_constraints:
- name: key_user_id
label: User
properties: [id]
queries:
# here we just need a custom query from the same source as another node/rel target that defines transformations
- name: user_name_starts_with
depends_on:
- users
source: so_users
query: |-
UNWIND $rows AS row
MATCH (user:User {id: row.id})
MERGE (letter:Letter {char: left(user.name, 1)})
CREATE (user)-[:NAME_STARTS_WITH]->(letter)
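
This spec exercises exactly the fixed scenario: the users node target defines a source transformation (the max(id) aggregation, surfacing a synthetic max_id field on every row), while the dependent user_name_starts_with custom query reuses the same so_users source and, by definition, transforms nothing. Under the old predicate this mix meant the shared source rows were never materialized, leaving the custom-query step without input; the integration test asserts both the ten User nodes and the Letter nodes derived from each user's first initial.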
