diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 71e816ce6..c385c45e2 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -24,6 +24,7 @@ import io.substrait.relation.Set; import io.substrait.relation.Sort; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.type.ImmutableType; import io.substrait.type.NamedStruct; import io.substrait.type.Type; @@ -218,6 +219,30 @@ private NamedScan namedScan( return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build(); } + public NestedLoopJoin nestedLoopJoin( + Function conditionFn, + NestedLoopJoin.JoinType joinType, + Rel left, + Rel right) { + return nestedLoopJoin(conditionFn, joinType, Optional.empty(), left, right); + } + + private NestedLoopJoin nestedLoopJoin( + Function conditionFn, + NestedLoopJoin.JoinType joinType, + Optional remap, + Rel left, + Rel right) { + var condition = conditionFn.apply(new JoinInput(left, right)); + return NestedLoopJoin.builder() + .left(left) + .right(right) + .condition(condition) + .joinType(joinType) + .remap(remap) + .build(); + } + public Project project(Function> expressionsFn, Rel input) { return project(expressionsFn, Optional.empty(), input); } @@ -286,6 +311,16 @@ public List fieldReferences(Rel input, int... indexes) { .collect(java.util.stream.Collectors.toList()); } + public FieldReference fieldReference(List inputs, int index) { + return ImmutableFieldReference.newInputRelReference(index, inputs); + } + + public List fieldReferences(List inputs, int... indexes) { + return Arrays.stream(indexes) + .mapToObj(index -> fieldReference(inputs, index)) + .collect(java.util.stream.Collectors.toList()); + } + public Expression cast(Expression input, Type type) { return Cast.builder() .input(input) diff --git a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java index 645f692e2..52a70bf33 100644 --- a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java +++ b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java @@ -1,6 +1,7 @@ package io.substrait.relation; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.NestedLoopJoin; public abstract class AbstractRelVisitor implements RelVisitor { @@ -31,6 +32,11 @@ public OUTPUT visit(Join join) throws EXCEPTION { return visitFallback(join); } + @Override + public OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION { + return visitFallback(nestedLoopJoin); + } + @Override public OUTPUT visit(Set set) throws EXCEPTION { return visitFallback(set); diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index bb479b83b..05c595419 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -17,6 +17,7 @@ import io.substrait.proto.FilterRel; import io.substrait.proto.HashJoinRel; import io.substrait.proto.JoinRel; +import io.substrait.proto.NestedLoopJoinRel; import io.substrait.proto.ProjectRel; import io.substrait.proto.ReadRel; import io.substrait.proto.SetRel; @@ -27,6 +28,7 @@ import io.substrait.relation.files.ImmutableFileFormat; import io.substrait.relation.files.ImmutableFileOrFiles; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.type.ImmutableNamedStruct; import io.substrait.type.NamedStruct; import io.substrait.type.Type; @@ -77,6 +79,9 @@ public Rel from(io.substrait.proto.Rel rel) { case JOIN -> { return newJoin(rel.getJoin()); } + case NESTED_LOOP_JOIN -> { + return newNestedLoopJoin(rel.getNestedLoopJoin()); + } case SET -> { return newSet(rel.getSet()); } @@ -532,6 +537,33 @@ private Rel newHashJoin(HashJoinRel rel) { return builder.build(); } + private NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) { + Rel left = from(rel.getLeft()); + Rel right = from(rel.getRight()); + Type.Struct leftStruct = left.getRecordType(); + Type.Struct rightStruct = right.getRecordType(); + Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); + var converter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); + var builder = + NestedLoopJoin.builder() + .left(left) + .right(right) + .condition( + // defaults to true (aka cartesian join) if the join expression is missing + rel.hasExpression() + ? converter.from(rel.getExpression()) + : Expression.BoolLiteral.builder().value(true).build()) + .joinType(NestedLoopJoin.JoinType.fromProto(rel.getType())); + + builder + .commonExtension(optionalAdvancedExtension(rel.getCommon())) + .remap(optionalRelmap(rel.getCommon())); + if (rel.hasAdvancedExtension()) { + builder.extension(advancedExtension(rel.getAdvancedExtension())); + } + return builder.build(); + } + private static Optional optionalRelmap(io.substrait.proto.RelCommon relCommon) { return Optional.ofNullable( relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null); diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java index 0dddfbd98..e67a6e112 100644 --- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java @@ -8,6 +8,8 @@ import io.substrait.expression.ImmutableFieldReference; import io.substrait.relation.physical.HashJoin; import io.substrait.relation.physical.ImmutableHashJoin; +import io.substrait.relation.physical.ImmutableNestedLoopJoin; +import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.type.Type; import java.util.ArrayList; import java.util.List; @@ -120,6 +122,23 @@ public Optional visit(Join join) throws RuntimeException { .build()); } + @Override + public Optional visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException { + var left = nestedLoopJoin.getLeft().accept(this); + var right = nestedLoopJoin.getRight().accept(this); + var condition = visitExpression(nestedLoopJoin.getCondition()); + if (allEmpty(left, right, condition)) { + return Optional.empty(); + } + return Optional.of( + ImmutableNestedLoopJoin.builder() + .from(nestedLoopJoin) + .left(left.orElse(nestedLoopJoin.getLeft())) + .right(right.orElse(nestedLoopJoin.getRight())) + .condition(condition.orElse(nestedLoopJoin.getCondition())) + .build()); + } + @Override public Optional visit(Set set) throws RuntimeException { return transformList(set.getInputs(), t -> t.accept(this)) diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 566ae5c3a..2ab4c0527 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -15,6 +15,7 @@ import io.substrait.proto.FilterRel; import io.substrait.proto.HashJoinRel; import io.substrait.proto.JoinRel; +import io.substrait.proto.NestedLoopJoinRel; import io.substrait.proto.ProjectRel; import io.substrait.proto.ReadRel; import io.substrait.proto.Rel; @@ -24,6 +25,7 @@ import io.substrait.proto.SortRel; import io.substrait.relation.files.FileOrFiles; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.type.proto.TypeProtoConverter; import java.util.Collection; import java.util.List; @@ -179,6 +181,20 @@ public Rel visit(Join join) throws RuntimeException { return Rel.newBuilder().setJoin(builder).build(); } + @Override + public Rel visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException { + var builder = + NestedLoopJoinRel.newBuilder() + .setCommon(common(nestedLoopJoin)) + .setLeft(toProto(nestedLoopJoin.getLeft())) + .setRight(toProto(nestedLoopJoin.getRight())) + .setExpression(toProto(nestedLoopJoin.getCondition())) + .setType(nestedLoopJoin.getJoinType().toProto()); + + nestedLoopJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); + return Rel.newBuilder().setNestedLoopJoin(builder).build(); + } + @Override public Rel visit(Set set) throws RuntimeException { var builder = SetRel.newBuilder().setCommon(common(set)).setOp(set.getSetOp().toProto()); diff --git a/core/src/main/java/io/substrait/relation/RelVisitor.java b/core/src/main/java/io/substrait/relation/RelVisitor.java index e8e78aaf7..38b70816c 100644 --- a/core/src/main/java/io/substrait/relation/RelVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelVisitor.java @@ -1,6 +1,7 @@ package io.substrait.relation; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.NestedLoopJoin; public interface RelVisitor { OUTPUT visit(Aggregate aggregate) throws EXCEPTION; @@ -13,6 +14,8 @@ public interface RelVisitor { OUTPUT visit(Join join) throws EXCEPTION; + OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION; + OUTPUT visit(Set set) throws EXCEPTION; OUTPUT visit(NamedScan namedScan) throws EXCEPTION; diff --git a/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java b/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java new file mode 100644 index 000000000..722fdb471 --- /dev/null +++ b/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java @@ -0,0 +1,79 @@ +package io.substrait.relation.physical; + +import io.substrait.expression.Expression; +import io.substrait.proto.NestedLoopJoinRel; +import io.substrait.relation.BiRel; +import io.substrait.relation.HasExtension; +import io.substrait.relation.RelVisitor; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.stream.Stream; +import org.immutables.value.Value; + +@Value.Immutable +public abstract class NestedLoopJoin extends BiRel implements HasExtension { + + public abstract Expression getCondition(); + + public abstract JoinType getJoinType(); + + public static enum JoinType { + UNKNOWN(NestedLoopJoinRel.JoinType.JOIN_TYPE_UNSPECIFIED), + INNER(NestedLoopJoinRel.JoinType.JOIN_TYPE_INNER), + OUTER(NestedLoopJoinRel.JoinType.JOIN_TYPE_OUTER), + LEFT(NestedLoopJoinRel.JoinType.JOIN_TYPE_LEFT), + RIGHT(NestedLoopJoinRel.JoinType.JOIN_TYPE_RIGHT), + LEFT_SEMI(NestedLoopJoinRel.JoinType.JOIN_TYPE_LEFT_SEMI), + RIGHT_SEMI(NestedLoopJoinRel.JoinType.JOIN_TYPE_RIGHT_SEMI), + LEFT_ANTI(NestedLoopJoinRel.JoinType.JOIN_TYPE_LEFT_ANTI), + RIGHT_ANTI(NestedLoopJoinRel.JoinType.JOIN_TYPE_RIGHT_ANTI); + + private NestedLoopJoinRel.JoinType proto; + + JoinType(NestedLoopJoinRel.JoinType proto) { + this.proto = proto; + } + + public NestedLoopJoinRel.JoinType toProto() { + return proto; + } + + public static JoinType fromProto(NestedLoopJoinRel.JoinType proto) { + for (var v : values()) { + if (v.proto == proto) { + return v; + } + } + + throw new IllegalArgumentException("Unknown type: " + proto); + } + } + + @Override + protected Type.Struct deriveRecordType() { + Stream leftTypes = + switch (getJoinType()) { + case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() + .map(TypeCreator::asNullable); + case RIGHT_ANTI, RIGHT_SEMI -> Stream.empty(); + default -> getLeft().getRecordType().fields().stream(); + }; + Stream rightTypes = + switch (getJoinType()) { + case LEFT, OUTER -> getRight().getRecordType().fields().stream() + .map(TypeCreator::asNullable); + case LEFT_ANTI, LEFT_SEMI -> Stream.empty(); + default -> getRight().getRecordType().fields().stream(); + }; + return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); + } + + @Override + public O accept(RelVisitor visitor) throws E { + return visitor.visit(this); + } + + public static ImmutableNestedLoopJoin.Builder builder() { + return ImmutableNestedLoopJoin.builder(); + } +} diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java index 5417625bf..b076f03c1 100644 --- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java @@ -23,6 +23,7 @@ import io.substrait.relation.Sort; import io.substrait.relation.VirtualTableScan; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.relation.utils.StringHolder; import io.substrait.relation.utils.StringHolderHandlingProtoRelConverter; import io.substrait.type.NamedStruct; @@ -186,6 +187,19 @@ void hashJoin() { verifyRoundTrip(relWithoutKeys); } + @Test + void nestedLoopJoin() { + Rel rel = + NestedLoopJoin.builder() + .from( + b.nestedLoopJoin( + __ -> b.bool(true), NestedLoopJoin.JoinType.INNER, commonTable, commonTable)) + .commonExtension(commonExtension) + .extension(relExtension) + .build(); + verifyRoundTrip(rel); + } + @Test void project() { Rel rel = diff --git a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java index 9b0156dc9..8ae8a7da5 100644 --- a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java @@ -3,6 +3,7 @@ import io.substrait.TestBase; import io.substrait.relation.Rel; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.NestedLoopJoin; import java.util.Arrays; import java.util.List; import org.junit.jupiter.api.Test; @@ -31,4 +32,19 @@ void hashJoin() { .build(); verifyRoundTrip(relWithoutKeys); } + + @Test + void nestedLoopJoin() { + List inputRels = Arrays.asList(leftTable, rightTable); + Rel rel = + NestedLoopJoin.builder() + .from( + b.nestedLoopJoin( + __ -> b.equal(b.fieldReference(inputRels, 0), b.fieldReference(inputRels, 5)), + NestedLoopJoin.JoinType.INNER, + leftTable, + rightTable)) + .build(); + verifyRoundTrip(rel); + } }