Patch summary (free text ahead of the first "diff --git" header; patch tools skip it).

Adds AtMostOneItemVariableReferenceIterator and wires it into the compiler:
  * RuntimeIteratorVisitor.visitVariableReference now creates an
    AtMostOneItemVariableReferenceIterator when the variable's static type is
    the empty sequence or has arity One or OneOrZero, and keeps creating a
    VariableReferenceIterator for every other arity.
  * CountClauseSparkIterator's constructor takes the count variable's Name
    directly instead of a RuntimeIterator that it downcast to
    VariableReferenceIterator. The visitor now performs the downcast itself,
    to AtMostOneItemVariableReferenceIterator.
    NOTE(review): that cast only holds while the count variable's static type
    falls into the at-most-one branch above; confirm against CountClause.
  * StaticContextVisitor.visitGroupByClause gains a TODO about widening the
    cardinality of non-grouping variables.
  * In the new class, Map and List carry explicit type arguments
    (Map<Name, DynamicContext.VariableDependency>, List<Item>); with a raw
    List, "return items.get(0)" does not type-check against the Item return
    type.
  * New test GroupbyClause20.jq groups by a float-valued key.

diff --git a/src/main/java/org/rumbledb/compiler/RuntimeIteratorVisitor.java b/src/main/java/org/rumbledb/compiler/RuntimeIteratorVisitor.java
index 26895d6090..51962288bf 100644
--- a/src/main/java/org/rumbledb/compiler/RuntimeIteratorVisitor.java
+++ b/src/main/java/org/rumbledb/compiler/RuntimeIteratorVisitor.java
@@ -126,6 +126,7 @@
 import org.rumbledb.runtime.typing.InstanceOfIterator;
 import org.rumbledb.runtime.typing.TreatIterator;
 import org.rumbledb.runtime.primary.ArrayRuntimeIterator;
+import org.rumbledb.runtime.primary.AtMostOneItemVariableReferenceIterator;
 import org.rumbledb.runtime.primary.BooleanRuntimeIterator;
 import org.rumbledb.runtime.primary.ContextExpressionIterator;
 import org.rumbledb.runtime.primary.DecimalRuntimeIterator;
@@ -136,6 +137,7 @@
 import org.rumbledb.runtime.primary.StringRuntimeIterator;
 import org.rumbledb.runtime.primary.VariableReferenceIterator;
 import org.rumbledb.types.SequenceType;
+import org.rumbledb.types.SequenceType.Arity;
 
 import java.util.ArrayList;
 import java.util.LinkedHashMap;
@@ -308,9 +310,11 @@ private RuntimeTupleIterator visitFlowrClause(
                 clause.getMetadata()
             );
         } else if (clause instanceof CountClause) {
+            RuntimeIterator variable = this.visit(((CountClause) clause).getCountVariable(), argument);
+            Name variableName = ((AtMostOneItemVariableReferenceIterator) variable).getVariableName();
             return new CountClauseSparkIterator(
                 previousIterator,
-                this.visit(((CountClause) clause).getCountVariable(), argument),
+                variableName,
                 clause.getHighestExecutionMode(this.visitorConfig),
                 clause.getMetadata()
             );
@@ -320,12 +324,26 @@ private RuntimeTupleIterator visitFlowrClause(
 
     @Override
     public RuntimeIterator visitVariableReference(VariableReferenceExpression expression, RuntimeIterator argument) {
-        RuntimeIterator runtimeIterator = new VariableReferenceIterator(
-            expression.getVariableName(),
-            expression.getType(),
-            expression.getHighestExecutionMode(this.visitorConfig),
-            expression.getMetadata()
-        );
+        RuntimeIterator runtimeIterator = null;
+        if (
+            expression.getType().isEmptySequence()
+                || expression.getType().getArity().equals(Arity.One)
+                || expression.getType().getArity().equals(Arity.OneOrZero)
+        ) {
+            runtimeIterator = new AtMostOneItemVariableReferenceIterator(
+                expression.getVariableName(),
+                expression.getType(),
+                expression.getHighestExecutionMode(this.visitorConfig),
+                expression.getMetadata()
+            );
+        } else {
+            runtimeIterator = new VariableReferenceIterator(
+                expression.getVariableName(),
+                expression.getType(),
+                expression.getHighestExecutionMode(this.visitorConfig),
+                expression.getMetadata()
+            );
+        }
         runtimeIterator.setStaticContext(expression.getStaticContext());
         return runtimeIterator;
     }
diff --git a/src/main/java/org/rumbledb/compiler/StaticContextVisitor.java b/src/main/java/org/rumbledb/compiler/StaticContextVisitor.java
index ebbf1cf8e3..c080321649 100644
--- a/src/main/java/org/rumbledb/compiler/StaticContextVisitor.java
+++ b/src/main/java/org/rumbledb/compiler/StaticContextVisitor.java
@@ -313,6 +313,7 @@ public StaticContext visitGroupByClause(GroupByClause clause, StaticContext argu
                 );
             }
         }
+        // TODO set cardinalities to * for all non-grouping input tuple variables
         clause.initHighestExecutionMode(this.visitorConfig);
         return groupByClauseContext;
     }
diff --git a/src/main/java/org/rumbledb/runtime/flwor/clauses/CountClauseSparkIterator.java b/src/main/java/org/rumbledb/runtime/flwor/clauses/CountClauseSparkIterator.java
index c697172b05..7fe6fad943 100644
--- a/src/main/java/org/rumbledb/runtime/flwor/clauses/CountClauseSparkIterator.java
+++ b/src/main/java/org/rumbledb/runtime/flwor/clauses/CountClauseSparkIterator.java
@@ -32,11 +32,9 @@
 import org.rumbledb.exceptions.OurBadException;
 import org.rumbledb.expressions.ExecutionMode;
 import org.rumbledb.items.ItemFactory;
-import org.rumbledb.runtime.RuntimeIterator;
 import org.rumbledb.runtime.RuntimeTupleIterator;
 import org.rumbledb.runtime.flwor.FlworDataFrameUtils;
 import org.rumbledb.runtime.flwor.udfs.LongSerializeUDF;
-import org.rumbledb.runtime.primary.VariableReferenceIterator;
 import sparksoniq.jsoniq.tuple.FlworTuple;
 
 
@@ -57,12 +55,12 @@ public class CountClauseSparkIterator extends RuntimeTupleIterator {
 
     public CountClauseSparkIterator(
             RuntimeTupleIterator child,
-            RuntimeIterator variableReference,
+            Name variableName,
             ExecutionMode executionMode,
             ExceptionMetadata iteratorMetadata
     ) {
         super(child, executionMode, iteratorMetadata);
-        this.variableName = ((VariableReferenceIterator) variableReference).getVariableName();
+        this.variableName = variableName;
         this.currentCountIndex = 1; // indices start at 1 in JSONiq
     }
 
diff --git a/src/main/java/org/rumbledb/runtime/primary/AtMostOneItemVariableReferenceIterator.java b/src/main/java/org/rumbledb/runtime/primary/AtMostOneItemVariableReferenceIterator.java
new file mode 100644
index 0000000000..b0029e06bf
--- /dev/null
+++ b/src/main/java/org/rumbledb/runtime/primary/AtMostOneItemVariableReferenceIterator.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Authors: Stefan Irimescu, Can Berker Cikis
+ *
+ */
+
+package org.rumbledb.runtime.primary;
+
+import org.rumbledb.api.Item;
+import org.rumbledb.context.DynamicContext;
+import org.rumbledb.context.Name;
+import org.rumbledb.exceptions.ExceptionMetadata;
+import org.rumbledb.expressions.ExecutionMode;
+import org.rumbledb.runtime.AtMostOneItemLocalRuntimeIterator;
+import org.rumbledb.types.SequenceType;
+
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+
+public class AtMostOneItemVariableReferenceIterator extends AtMostOneItemLocalRuntimeIterator {
+
+
+    private static final long serialVersionUID = 1L;
+    private SequenceType sequence;
+    private Name variableName;
+
+    public AtMostOneItemVariableReferenceIterator(
+            Name variableName,
+            SequenceType seq,
+            ExecutionMode executionMode,
+            ExceptionMetadata iteratorMetadata
+    ) {
+        super(null, executionMode, iteratorMetadata);
+        this.variableName = variableName;
+        this.sequence = seq;
+    }
+
+    public SequenceType getSequence() {
+        return this.sequence;
+    }
+
+    public Name getVariableName() {
+        return this.variableName;
+    }
+
+    public Map<Name, DynamicContext.VariableDependency> getVariableDependencies() {
+        Map<Name, DynamicContext.VariableDependency> result = new TreeMap<>();
+        result.put(this.variableName, DynamicContext.VariableDependency.FULL);
+        return result;
+    }
+
+    @Override
+    public Item materializeFirstItemOrNull(DynamicContext context) {
+        List<Item> items = context.getVariableValues()
+            .getLocalVariableValue(
+                this.variableName,
+                getMetadata()
+            );
+        if (items.isEmpty()) {
+            return null;
+        }
+        return items.get(0);
+    }
+}
diff --git a/src/test/resources/test_files/runtime-spark/DataFrames/GroupbyClause20.jq b/src/test/resources/test_files/runtime-spark/DataFrames/GroupbyClause20.jq
new file mode 100644
index 0000000000..5304ea1d9c
--- /dev/null
+++ b/src/test/resources/test_files/runtime-spark/DataFrames/GroupbyClause20.jq
@@ -0,0 +1,6 @@
+(:JIQS: ShouldRun; Output="({ "j" : 0, "i" : [ 2, 4, 6, 8, 10 ] }, { "j" : 1, "i" : [ 1, 3, 5, 7, 9 ] })" :)
+for $i in parallelize(1 to 10)
+let $j := float($i mod 2)
+group by $j
+order by $j
+return { "j" : $j, "i" : [ $i ] }