diff --git a/.asf.yaml b/.asf.yaml index 22042b355b2fa..3935a525ff3c4 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -31,6 +31,8 @@ github: merge: false squash: true rebase: true + ghp_branch: master + ghp_path: /docs notifications: pullrequests: reviews@spark.apache.org diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 2b459e4c73bbb..43ac6b50052ae 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -1112,6 +1112,10 @@ jobs: with: distribution: zulu java-version: ${{ inputs.java }} + - name: Install R + run: | + sudo apt update + sudo apt-get install r-base - name: Start Minikube uses: medyagh/setup-minikube@v0.0.18 with: diff --git a/.github/workflows/build_maven_java21_macos14.yml b/.github/workflows/build_maven_java21_macos15.yml similarity index 92% rename from .github/workflows/build_maven_java21_macos14.yml rename to .github/workflows/build_maven_java21_macos15.yml index fb5e609f4eae0..cc6d0ea4e90da 100644 --- a/.github/workflows/build_maven_java21_macos14.yml +++ b/.github/workflows/build_maven_java21_macos15.yml @@ -17,7 +17,7 @@ # under the License. # -name: "Build / Maven (master, Scala 2.13, Hadoop 3, JDK 21, macos-14)" +name: "Build / Maven (master, Scala 2.13, Hadoop 3, JDK 21, MacOS-15)" on: schedule: @@ -32,7 +32,7 @@ jobs: if: github.repository == 'apache/spark' with: java: 21 - os: macos-14 + os: macos-15 envs: >- { "OBJC_DISABLE_INITIALIZE_FORK_SAFETY": "YES" diff --git a/.github/workflows/build_python_3.13.yml b/.github/workflows/build_python_3.13.yml new file mode 100644 index 0000000000000..6f67cf383584f --- /dev/null +++ b/.github/workflows/build_python_3.13.yml @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +name: "Build / Python-only (master, Python 3.13)" + +on: + schedule: + - cron: '0 20 * * *' + +jobs: + run-build: + permissions: + packages: write + name: Run + uses: ./.github/workflows/build_and_test.yml + if: github.repository == 'apache/spark' + with: + java: 17 + branch: master + hadoop: hadoop3 + envs: >- + { + "PYTHON_TO_TEST": "python3.13" + } + jobs: >- + { + "pyspark": "true", + "pyspark-pandas": "true" + } diff --git a/.github/workflows/build_python_connect.yml b/.github/workflows/build_python_connect.yml index 3ac1a0117e41b..f668d813ef26e 100644 --- a/.github/workflows/build_python_connect.yml +++ b/.github/workflows/build_python_connect.yml @@ -71,7 +71,7 @@ jobs: python packaging/connect/setup.py sdist cd dist pip install pyspark*connect-*.tar.gz - pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' 'graphviz==0.20.3' torch torchvision torcheval deepspeed unittest-xml-reporting + pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' 'graphviz==0.20.3' torch torchvision torcheval deepspeed unittest-xml-reporting 'plotly>=4.8' - name: Run tests env: SPARK_TESTING: 1 diff --git a/.github/workflows/maven_test.yml b/.github/workflows/maven_test.yml index 82b72bd7e91d2..dd089d665d6e3 100644 --- a/.github/workflows/maven_test.yml +++ b/.github/workflows/maven_test.yml @@ -40,7 +40,7 @@ on: description: OS to run this build. required: false type: string - default: ubuntu-22.04 + default: ubuntu-latest envs: description: Additional environment variables to set when running the tests. Should be in JSON format. required: false diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml index f10dadf315a1b..f78f7895a183f 100644 --- a/.github/workflows/pages.yml +++ b/.github/workflows/pages.yml @@ -26,7 +26,7 @@ on: concurrency: group: 'docs preview' - cancel-in-progress: true + cancel-in-progress: false jobs: docs: @@ -35,6 +35,8 @@ jobs: permissions: id-token: write pages: write + environment: + name: github-pages # https://github.com/actions/deploy-pages/issues/271 env: SPARK_TESTING: 1 # Reduce some noise in the logs RELEASE_VERSION: 'In-Progress' @@ -56,7 +58,12 @@ jobs: architecture: x64 cache: 'pip' - name: Install Python dependencies - run: pip install --upgrade -r dev/requirements.txt + run: | + pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ + ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow 'pandas==2.2.3' 'plotly>=4.8' 'docutils<0.18.0' \ + 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \ + 'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ + 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' - name: Install Ruby for documentation generation uses: ruby/setup-ruby@v1 with: diff --git a/.nojekyll b/.nojekyll new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/assembly/pom.xml b/assembly/pom.xml index 01bd324efc118..17bb81fa023ba 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -117,6 +117,12 @@ org.apache.spark spark-connect-client-jvm_${scala.binary.version} ${project.version} + + + 
org.apache.spark + spark-connect-shims_${scala.binary.version} + + provided diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java index 13a9d89f4705c..7f8d6c58aec7e 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java @@ -255,7 +255,8 @@ public Iterator iterator() { iteratorTracker.add(new WeakReference<>(it)); return it; } catch (Exception e) { - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } } }; diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java index 69757fdc65d68..29ed37ffa44e5 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java @@ -127,7 +127,7 @@ public boolean hasNext() { try { close(); } catch (IOException ioe) { - throw Throwables.propagate(ioe); + throw new RuntimeException(ioe); } } return next != null; @@ -151,7 +151,8 @@ public T next() { next = null; return ret; } catch (Exception e) { - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } } diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java index dc7ad0be5c007..4bc2b233fe12d 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java @@ -287,7 +287,8 @@ public Iterator iterator() { iteratorTracker.add(new WeakReference<>(it)); return it; } catch (Exception e) { - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } } }; diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java index a98b0482e35cc..e350ddc2d445a 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java @@ -113,7 +113,7 @@ public boolean hasNext() { try { close(); } catch (IOException ioe) { - throw Throwables.propagate(ioe); + throw new RuntimeException(ioe); } } return next != null; @@ -137,7 +137,8 @@ public T next() { next = null; return ret; } catch (Exception e) { - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } } diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index cdb5bd72158a1..cbe4836b58da5 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -194,6 +194,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 4c144a73a9299..a9df47645d36f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -290,9 +290,11 @@ 
public void onFailure(Throwable e) { try { return result.get(timeoutMs, TimeUnit.MILLISECONDS); } catch (ExecutionException e) { - throw Throwables.propagate(e.getCause()); + Throwables.throwIfUnchecked(e.getCause()); + throw new RuntimeException(e.getCause()); } catch (Exception e) { - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index e1f19f956cc0a..d64b8c8f838e9 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -342,7 +342,8 @@ public void operationComplete(final Future handshakeFuture) { logger.error("Exception while bootstrapping client after {} ms", e, MDC.of(LogKeys.BOOTSTRAP_TIME$.MODULE$, bootstrapTimeMs)); client.close(); - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } long postBootstrap = System.nanoTime(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java index 08e2c084fe67b..2e9ccd0e0ad21 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java @@ -22,7 +22,6 @@ import java.security.GeneralSecurityException; import java.util.concurrent.TimeoutException; -import com.google.common.base.Throwables; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; @@ -80,7 +79,7 @@ public void doBootstrap(TransportClient client, Channel channel) { doSparkAuth(client, channel); client.setClientId(appId); } catch (GeneralSecurityException | IOException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } catch (RuntimeException e) { // There isn't a good exception that can be caught here to know whether it's really // OK to switch back to SASL (because the server doesn't speak the new protocol). 
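A minimal, self-contained sketch (helper class and `Callable`-based shape are illustrative, not Spark code) of the replacement idiom this diff applies for the deprecated Guava `Throwables.propagate(e)`: rethrow unchecked throwables as-is via `Throwables.throwIfUnchecked`, and wrap everything else in `RuntimeException` so the original cause is preserved.

```java
import java.util.concurrent.Callable;

import com.google.common.base.Throwables;

// Illustrative helper only: mirrors the pattern used at the changed call sites.
final class PropagateMigrationSketch {
  static <T> T callUnchecked(Callable<T> action) {
    try {
      return action.call();
    } catch (Exception e) {
      Throwables.throwIfUnchecked(e);  // rethrows e if it is a RuntimeException (or Error)
      throw new RuntimeException(e);   // otherwise wrap, keeping the original cause
    }
  }

  public static void main(String[] args) {
    System.out.println(callUnchecked(() -> 21 + 21)); // prints 42
  }
}
```

Call sites where only checked exceptions can reach the handler (e.g. IOException, SaslException, RocksDBException in the surrounding hunks) skip the `throwIfUnchecked` step and wrap the exception directly.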
So diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index 65367743e24f9..087e3d21e22bb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -132,7 +132,8 @@ protected boolean doAuthChallenge( try { engine.close(); } catch (Exception e) { - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java index 355c552720185..33494aee4444d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -17,32 +17,12 @@ package org.apache.spark.network.crypto; -import com.google.common.annotations.VisibleForTesting; -import com.google.crypto.tink.subtle.Hex; -import com.google.crypto.tink.subtle.Hkdf; import io.netty.channel.Channel; -import javax.crypto.spec.SecretKeySpec; import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.security.GeneralSecurityException; interface TransportCipher { String getKeyId() throws GeneralSecurityException; void addToChannel(Channel channel) throws IOException, GeneralSecurityException; } - -class TransportCipherUtil { - /* - * This method is used for testing to verify key derivation. - */ - @VisibleForTesting - static String getKeyId(SecretKeySpec key) throws GeneralSecurityException { - byte[] keyIdBytes = Hkdf.computeHkdf("HmacSha256", - key.getEncoded(), - null, - "keyID".getBytes(StandardCharsets.UTF_8), - 32); - return Hex.encode(keyIdBytes); - } -} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipherUtil.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipherUtil.java new file mode 100644 index 0000000000000..1df2732f240cc --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipherUtil.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.crypto; + +import java.nio.charset.StandardCharsets; +import java.security.GeneralSecurityException; +import javax.crypto.spec.SecretKeySpec; + +import com.google.common.annotations.VisibleForTesting; +import com.google.crypto.tink.subtle.Hex; +import com.google.crypto.tink.subtle.Hkdf; + +class TransportCipherUtil { + /** + * This method is used for testing to verify key derivation. + */ + @VisibleForTesting + static String getKeyId(SecretKeySpec key) throws GeneralSecurityException { + byte[] keyIdBytes = Hkdf.computeHkdf("HmacSha256", + key.getEncoded(), + null, + "keyID".getBytes(StandardCharsets.UTF_8), + 32); + return Hex.encode(keyIdBytes); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java index 3600c1045dbf4..a61b1c3c0c416 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -29,7 +29,6 @@ import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; import org.apache.spark.internal.SparkLogger; @@ -62,7 +61,7 @@ public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder, bool this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM, saslProps, new ClientCallbackHandler()); } catch (SaslException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } @@ -72,7 +71,7 @@ public synchronized byte[] firstToken() { try { return saslClient.evaluateChallenge(new byte[0]); } catch (SaslException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } else { return new byte[0]; @@ -98,7 +97,7 @@ public synchronized byte[] response(byte[] token) { try { return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0]; } catch (SaslException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java index b897650afe832..f32fd5145c7c5 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -31,7 +31,6 @@ import java.util.Map; import com.google.common.base.Preconditions; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -94,7 +93,7 @@ public SparkSaslServer( this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, saslProps, new DigestCallbackHandler()); } catch (SaslException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } @@ -119,7 +118,7 @@ public synchronized byte[] response(byte[] token) { try { return saslServer != null ? 
saslServer.evaluateResponse(token) : new byte[0]; } catch (SaslException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/shuffledb/LevelDBIterator.java b/common/network-common/src/main/java/org/apache/spark/network/shuffledb/LevelDBIterator.java index 5796e34a6f05e..2ac549775449a 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/shuffledb/LevelDBIterator.java +++ b/common/network-common/src/main/java/org/apache/spark/network/shuffledb/LevelDBIterator.java @@ -17,8 +17,6 @@ package org.apache.spark.network.shuffledb; -import com.google.common.base.Throwables; - import java.io.IOException; import java.util.Map; import java.util.NoSuchElementException; @@ -47,7 +45,7 @@ public boolean hasNext() { try { close(); } catch (IOException ioe) { - throw Throwables.propagate(ioe); + throw new RuntimeException(ioe); } } return next != null; diff --git a/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDB.java b/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDB.java index d33895d6c2d62..2737ab8ed754c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDB.java +++ b/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDB.java @@ -19,7 +19,6 @@ import java.io.IOException; -import com.google.common.base.Throwables; import org.rocksdb.RocksDBException; /** @@ -37,7 +36,7 @@ public void put(byte[] key, byte[] value) { try { db.put(key, value); } catch (RocksDBException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } @@ -46,7 +45,7 @@ public byte[] get(byte[] key) { try { return db.get(key); } catch (RocksDBException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } @@ -55,7 +54,7 @@ public void delete(byte[] key) { try { db.delete(key); } catch (RocksDBException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDBIterator.java b/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDBIterator.java index 78562f91a4b75..829a7ded6330b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDBIterator.java +++ b/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDBIterator.java @@ -22,7 +22,6 @@ import java.util.Map; import java.util.NoSuchElementException; -import com.google.common.base.Throwables; import org.rocksdb.RocksIterator; /** @@ -52,7 +51,7 @@ public boolean hasNext() { try { close(); } catch (IOException ioe) { - throw Throwables.propagate(ioe); + throw new RuntimeException(ioe); } } return next != null; diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 0f7036ef746cc..49e6e08476151 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -113,6 +113,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + commons-io commons-io diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index a5ef9847859a7..cf15301273303 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -104,6 +104,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.scalacheck scalacheck_${scala.binary.version} diff --git 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index 5ed3048fb72b3..d67697eaea38b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -109,7 +109,7 @@ private static int lowercaseMatchLengthFrom( } // Compare the characters in the target and pattern strings. int matchLength = 0, codePointBuffer = -1, targetCodePoint, patternCodePoint; - while (targetIterator.hasNext() && patternIterator.hasNext()) { + while ((targetIterator.hasNext() || codePointBuffer != -1) && patternIterator.hasNext()) { if (codePointBuffer != -1) { targetCodePoint = codePointBuffer; codePointBuffer = -1; @@ -211,7 +211,7 @@ private static int lowercaseMatchLengthUntil( } // Compare the characters in the target and pattern strings. int matchLength = 0, codePointBuffer = -1, targetCodePoint, patternCodePoint; - while (targetIterator.hasNext() && patternIterator.hasNext()) { + while ((targetIterator.hasNext() || codePointBuffer != -1) && patternIterator.hasNext()) { if (codePointBuffer != -1) { targetCodePoint = codePointBuffer; codePointBuffer = -1; @@ -1363,9 +1363,9 @@ public static UTF8String trimRight( public static UTF8String[] splitSQL(final UTF8String input, final UTF8String delim, final int limit, final int collationId) { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + if (CollationFactory.fetchCollation(collationId).isUtf8BinaryType) { return input.split(delim, limit); - } else if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) { + } else if (CollationFactory.fetchCollation(collationId).isUtf8LcaseType) { return lowercaseSplitSQL(input, delim, limit); } else { return icuSplitSQL(input, delim, limit, collationId); diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 4b88e15e8ed72..50bb93465921e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -99,7 +99,8 @@ public record CollationMeta( String icuVersion, String padAttribute, boolean accentSensitivity, - boolean caseSensitivity) { } + boolean caseSensitivity, + String spaceTrimming) { } /** * Entry encapsulating all information about a collation. @@ -153,6 +154,24 @@ public static class Collation { */ public final boolean supportsLowercaseEquality; + /** + * Support for Space Trimming implies that that based on specifier (for now only right trim) + * leading, trailing or both spaces are removed from the input string before comparison. + */ + public final boolean supportsSpaceTrimming; + + /** + * Is Utf8 binary type as indicator if collation base type is UTF8 binary. Note currently only + * collations Utf8_Binary and Utf8_Binary_RTRIM are considered as Utf8 binary type. + */ + public final boolean isUtf8BinaryType; + + /** + * Is Utf8 lcase type as indicator if collation base type is UTF8 lcase. Note currently only + * collations Utf8_Lcase and Utf8_Lcase_RTRIM are considered as Utf8 Lcase type. 
+ */ + public final boolean isUtf8LcaseType; + public Collation( String collationName, String provider, @@ -160,31 +179,27 @@ public Collation( Comparator comparator, String version, ToLongFunction hashFunction, - boolean supportsBinaryEquality, - boolean supportsBinaryOrdering, - boolean supportsLowercaseEquality) { + BiFunction equalsFunction, + boolean isUtf8BinaryType, + boolean isUtf8LcaseType, + boolean supportsSpaceTrimming) { this.collationName = collationName; this.provider = provider; this.collator = collator; this.comparator = comparator; this.version = version; this.hashFunction = hashFunction; - this.supportsBinaryEquality = supportsBinaryEquality; - this.supportsBinaryOrdering = supportsBinaryOrdering; - this.supportsLowercaseEquality = supportsLowercaseEquality; - - // De Morgan's Law to check supportsBinaryOrdering => supportsBinaryEquality - assert(!supportsBinaryOrdering || supportsBinaryEquality); + this.isUtf8BinaryType = isUtf8BinaryType; + this.isUtf8LcaseType = isUtf8LcaseType; + this.equalsFunction = equalsFunction; + this.supportsSpaceTrimming = supportsSpaceTrimming; + this.supportsBinaryEquality = !supportsSpaceTrimming && isUtf8BinaryType; + this.supportsBinaryOrdering = !supportsSpaceTrimming && isUtf8BinaryType; + this.supportsLowercaseEquality = !supportsSpaceTrimming && isUtf8LcaseType; // No Collation can simultaneously support binary equality and lowercase equality assert(!supportsBinaryEquality || !supportsLowercaseEquality); assert(SUPPORTED_PROVIDERS.contains(provider)); - - if (supportsBinaryEquality) { - this.equalsFunction = UTF8String::equals; - } else { - this.equalsFunction = (s1, s2) -> this.comparator.compare(s1, s2) == 0; - } } /** @@ -199,7 +214,8 @@ public Collation( * bit 29: 0 for UTF8_BINARY, 1 for ICU collations. * bit 28-24: Reserved. * bit 23-22: Reserved for version. - * bit 21-18: Reserved for space trimming. + * bit 21-19 Zeros, reserved for future trimmings. + * bit 18 0 = none, 1 = right trim. * bit 17-0: Depend on collation family. * --- * INDETERMINATE collation ID binary layout: @@ -214,7 +230,8 @@ public Collation( * UTF8_BINARY collation ID binary layout: * bit 31-24: Zeroes. * bit 23-22: Zeroes, reserved for version. - * bit 21-18: Zeroes, reserved for space trimming. + * bit 21-19 Zeros, reserved for future trimmings. + * bit 18 0 = none, 1 = right trim. * bit 17-3: Zeroes. * bit 2: 0, reserved for accent sensitivity. * bit 1: 0, reserved for uppercase and case-insensitive. @@ -225,7 +242,8 @@ public Collation( * bit 29: 1 * bit 28-24: Zeroes. * bit 23-22: Zeroes, reserved for version. - * bit 21-18: Zeroes, reserved for space trimming. + * bit 21-18: Reserved for space trimming. + * 0000 = none, 0001 = right trim. Bits 21-19 remain reserved and fixed to 0. * bit 17: 0 = case-sensitive, 1 = case-insensitive. * bit 16: 0 = accent-sensitive, 1 = accent-insensitive. * bit 15-14: Zeroes, reserved for punctuation sensitivity. @@ -233,14 +251,20 @@ public Collation( * bit 11-0: Locale ID as specified in `ICULocaleToId` mapping. 
* --- * Some illustrative examples of collation name to ID mapping: - * - UTF8_BINARY -> 0 - * - UTF8_LCASE -> 1 - * - UNICODE -> 0x20000000 - * - UNICODE_AI -> 0x20010000 - * - UNICODE_CI -> 0x20020000 - * - UNICODE_CI_AI -> 0x20030000 - * - af -> 0x20000001 - * - af_CI_AI -> 0x20030001 + * - UTF8_BINARY -> 0 + * - UTF8_BINARY_RTRIM -> 0x00040000 + * - UTF8_LCASE -> 1 + * - UTF8_LCASE_RTRIM -> 0x00040001 + * - UNICODE -> 0x20000000 + * - UNICODE_AI -> 0x20010000 + * - UNICODE_CI -> 0x20020000 + * - UNICODE_RTRIM -> 0x20040000 + * - UNICODE_CI_AI -> 0x20030000 + * - UNICODE_CI_RTRIM -> 0x20060000 + * - UNICODE_AI_RTRIM -> 0x20050000 + * - UNICODE_CI_AI_RTRIM-> 0x20070000 + * - af -> 0x20000001 + * - af_CI_AI -> 0x20030001 */ private abstract static class CollationSpec { @@ -259,6 +283,14 @@ protected enum ImplementationProvider { UTF8_BINARY, ICU } + /** + * Bit 18 in collation ID having value 0 for none and 1 for right trimming. + * Bits 21, 20, 19 remained reserved (and fixed to 0) for future use. + */ + protected enum SpaceTrimming { + NONE, RTRIM + } + /** * Offset in binary collation ID layout. */ @@ -279,6 +311,17 @@ protected enum ImplementationProvider { */ protected static final int IMPLEMENTATION_PROVIDER_MASK = 0b1; + + /** + * Offset in binary collation ID layout. + */ + protected static final int SPACE_TRIMMING_OFFSET = 18; + + /** + * Bitmask corresponding to width in bits in binary collation ID layout. + */ + protected static final int SPACE_TRIMMING_MASK = 0b1; + private static final int INDETERMINATE_COLLATION_ID = -1; /** @@ -303,6 +346,28 @@ private static DefinitionOrigin getDefinitionOrigin(int collationId) { DEFINITION_ORIGIN_OFFSET, DEFINITION_ORIGIN_MASK)]; } + /** + * Utility function to retrieve `SpaceTrimming` enum instance from collation ID. + */ + protected static SpaceTrimming getSpaceTrimming(int collationId) { + return SpaceTrimming.values()[SpecifierUtils.getSpecValue(collationId, + SPACE_TRIMMING_OFFSET, SPACE_TRIMMING_MASK)]; + } + + protected static UTF8String applyTrimmingPolicy(UTF8String s, int collationId) { + return applyTrimmingPolicy(s, getSpaceTrimming(collationId)); + } + + /** + * Utility function to trim spaces when collation uses space trimming. + */ + protected static UTF8String applyTrimmingPolicy(UTF8String s, SpaceTrimming spaceTrimming) { + if(spaceTrimming == SpaceTrimming.RTRIM){ + return s.trimRight(); + } + return s; // No trimming. + } + /** * Main entry point for retrieving `Collation` instance from collation ID. 
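A hedged, standalone illustration of the RTRIM trimming policy described in this hunk (class and method names are made up; `UTF8String` is Spark's `org.apache.spark.unsafe.types.UTF8String`): under a `*_RTRIM` collation only trailing spaces are stripped before comparison, so values differing only in trailing blanks compare equal while leading spaces still count.

```java
import org.apache.spark.unsafe.types.UTF8String;

// Rough sketch of applyTrimmingPolicy(s, SpaceTrimming.RTRIM): trailing spaces
// are removed, leading spaces are kept.
final class RtrimPolicySketch {
  static UTF8String rtrim(UTF8String s) {
    return s.trimRight();
  }

  public static void main(String[] args) {
    UTF8String padded = UTF8String.fromString("abc   ");
    UTF8String plain = UTF8String.fromString("abc");
    UTF8String leading = UTF8String.fromString("  abc");
    System.out.println(rtrim(padded).equals(plain));   // true: trailing spaces ignored
    System.out.println(rtrim(leading).equals(plain));  // false: leading spaces preserved
  }
}
```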
*/ @@ -358,6 +423,8 @@ private static int collationNameToId(String collationName) throws SparkException protected abstract CollationMeta buildCollationMeta(); + protected abstract String normalizedCollationName(); + static List listCollations() { return Stream.concat( CollationSpecUTF8.listCollations().stream(), @@ -398,97 +465,201 @@ private enum CaseSensitivity { private static final String UTF8_LCASE_COLLATION_NAME = "UTF8_LCASE"; private static final int UTF8_BINARY_COLLATION_ID = - new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED).collationId; + new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED, SpaceTrimming.NONE).collationId; private static final int UTF8_LCASE_COLLATION_ID = - new CollationSpecUTF8(CaseSensitivity.LCASE).collationId; + new CollationSpecUTF8(CaseSensitivity.LCASE, SpaceTrimming.NONE).collationId; protected static Collation UTF8_BINARY_COLLATION = - new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED).buildCollation(); + new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED, SpaceTrimming.NONE).buildCollation(); protected static Collation UTF8_LCASE_COLLATION = - new CollationSpecUTF8(CaseSensitivity.LCASE).buildCollation(); + new CollationSpecUTF8(CaseSensitivity.LCASE, SpaceTrimming.NONE).buildCollation(); + private final CaseSensitivity caseSensitivity; + private final SpaceTrimming spaceTrimming; private final int collationId; - private CollationSpecUTF8(CaseSensitivity caseSensitivity) { - this.collationId = + private CollationSpecUTF8( + CaseSensitivity caseSensitivity, + SpaceTrimming spaceTrimming) { + this.caseSensitivity = caseSensitivity; + this.spaceTrimming = spaceTrimming; + + int collationId = SpecifierUtils.setSpecValue(0, CASE_SENSITIVITY_OFFSET, caseSensitivity); + this.collationId = + SpecifierUtils.setSpecValue(collationId, SPACE_TRIMMING_OFFSET, spaceTrimming); } private static int collationNameToId(String originalName, String collationName) throws SparkException { - if (UTF8_BINARY_COLLATION.collationName.equals(collationName)) { - return UTF8_BINARY_COLLATION_ID; - } else if (UTF8_LCASE_COLLATION.collationName.equals(collationName)) { - return UTF8_LCASE_COLLATION_ID; + + int baseId; + String collationNamePrefix; + + if (collationName.startsWith(UTF8_BINARY_COLLATION.collationName)) { + baseId = UTF8_BINARY_COLLATION_ID; + collationNamePrefix = UTF8_BINARY_COLLATION.collationName; + } else if (collationName.startsWith(UTF8_LCASE_COLLATION.collationName)) { + baseId = UTF8_LCASE_COLLATION_ID; + collationNamePrefix = UTF8_LCASE_COLLATION.collationName; } else { // Throw exception with original (before case conversion) collation name. throw collationInvalidNameException(originalName); } + + String remainingSpecifiers = collationName.substring(collationNamePrefix.length()); + if(remainingSpecifiers.isEmpty()) { + return baseId; + } + if(!remainingSpecifiers.startsWith("_")){ + throw collationInvalidNameException(originalName); + } + + SpaceTrimming spaceTrimming = SpaceTrimming.NONE; + String remainingSpec = remainingSpecifiers.substring(1); + if (remainingSpec.equals("RTRIM")) { + spaceTrimming = SpaceTrimming.RTRIM; + } else { + throw collationInvalidNameException(originalName); + } + + return SpecifierUtils.setSpecValue(baseId, SPACE_TRIMMING_OFFSET, spaceTrimming); } private static CollationSpecUTF8 fromCollationId(int collationId) { // Extract case sensitivity from collation ID. 
int caseConversionOrdinal = SpecifierUtils.getSpecValue(collationId, CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK); - // Verify only case sensitivity bits were set settable in UTF8_BINARY family of collations. - assert (SpecifierUtils.removeSpec(collationId, - CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK) == 0); - return new CollationSpecUTF8(CaseSensitivity.values()[caseConversionOrdinal]); + // Extract space trimming from collation ID. + int spaceTrimmingOrdinal = getSpaceTrimming(collationId).ordinal(); + assert(isValidCollationId(collationId)); + return new CollationSpecUTF8( + CaseSensitivity.values()[caseConversionOrdinal], + SpaceTrimming.values()[spaceTrimmingOrdinal]); + } + + private static boolean isValidCollationId(int collationId) { + collationId = SpecifierUtils.removeSpec( + collationId, + SPACE_TRIMMING_OFFSET, + SPACE_TRIMMING_MASK); + collationId = SpecifierUtils.removeSpec( + collationId, + CASE_SENSITIVITY_OFFSET, + CASE_SENSITIVITY_MASK); + return collationId == 0; } @Override protected Collation buildCollation() { - if (collationId == UTF8_BINARY_COLLATION_ID) { + if (caseSensitivity == CaseSensitivity.UNSPECIFIED) { + Comparator comparator; + ToLongFunction hashFunction; + BiFunction equalsFunction; + boolean supportsSpaceTrimming = spaceTrimming != SpaceTrimming.NONE; + + if (spaceTrimming == SpaceTrimming.NONE) { + comparator = UTF8String::binaryCompare; + hashFunction = s -> (long) s.hashCode(); + equalsFunction = UTF8String::equals; + } else { + comparator = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).binaryCompare( + applyTrimmingPolicy(s2, spaceTrimming)); + hashFunction = s -> (long) applyTrimmingPolicy(s, spaceTrimming).hashCode(); + equalsFunction = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).equals( + applyTrimmingPolicy(s2, spaceTrimming)); + } + return new Collation( - UTF8_BINARY_COLLATION_NAME, + normalizedCollationName(), PROVIDER_SPARK, null, - UTF8String::binaryCompare, + comparator, "1.0", - s -> (long) s.hashCode(), - /* supportsBinaryEquality = */ true, - /* supportsBinaryOrdering = */ true, - /* supportsLowercaseEquality = */ false); + hashFunction, + equalsFunction, + /* isUtf8BinaryType = */ true, + /* isUtf8LcaseType = */ false, + spaceTrimming != SpaceTrimming.NONE); } else { + Comparator comparator; + ToLongFunction hashFunction; + + if (spaceTrimming == SpaceTrimming.NONE) { + comparator = CollationAwareUTF8String::compareLowerCase; + hashFunction = s -> + (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(); + } else { + comparator = (s1, s2) -> CollationAwareUTF8String.compareLowerCase( + applyTrimmingPolicy(s1, spaceTrimming), + applyTrimmingPolicy(s2, spaceTrimming)); + hashFunction = s -> (long) CollationAwareUTF8String.lowerCaseCodePoints( + applyTrimmingPolicy(s, spaceTrimming)).hashCode(); + } + return new Collation( - UTF8_LCASE_COLLATION_NAME, + normalizedCollationName(), PROVIDER_SPARK, null, - CollationAwareUTF8String::compareLowerCase, + comparator, "1.0", - s -> (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(), - /* supportsBinaryEquality = */ false, - /* supportsBinaryOrdering = */ false, - /* supportsLowercaseEquality = */ true); + hashFunction, + (s1, s2) -> comparator.compare(s1, s2) == 0, + /* isUtf8BinaryType = */ false, + /* isUtf8LcaseType = */ true, + spaceTrimming != SpaceTrimming.NONE); } } @Override protected CollationMeta buildCollationMeta() { - if (collationId == UTF8_BINARY_COLLATION_ID) { + if (caseSensitivity == CaseSensitivity.UNSPECIFIED) { return new 
CollationMeta( CATALOG, SCHEMA, - UTF8_BINARY_COLLATION_NAME, + normalizedCollationName(), /* language = */ null, /* country = */ null, /* icuVersion = */ null, COLLATION_PAD_ATTRIBUTE, /* accentSensitivity = */ true, - /* caseSensitivity = */ true); + /* caseSensitivity = */ true, + spaceTrimming.toString()); } else { return new CollationMeta( CATALOG, SCHEMA, - UTF8_LCASE_COLLATION_NAME, + normalizedCollationName(), /* language = */ null, /* country = */ null, /* icuVersion = */ null, COLLATION_PAD_ATTRIBUTE, /* accentSensitivity = */ true, - /* caseSensitivity = */ false); + /* caseSensitivity = */ false, + spaceTrimming.toString()); } } + /** + * Compute normalized collation name. Components of collation name are given in order: + * - Base collation name (UTF8_BINARY or UTF8_LCASE) + * - Optional space trimming when non-default preceded by underscore + * Examples: UTF8_BINARY, UTF8_BINARY_LCASE_LTRIM, UTF8_BINARY_TRIM. + */ + @Override + protected String normalizedCollationName() { + StringBuilder builder = new StringBuilder(); + if(caseSensitivity == CaseSensitivity.UNSPECIFIED){ + builder.append(UTF8_BINARY_COLLATION_NAME); + } else{ + builder.append(UTF8_LCASE_COLLATION_NAME); + } + if (spaceTrimming != SpaceTrimming.NONE) { + builder.append('_'); + builder.append(spaceTrimming.toString()); + } + return builder.toString(); + } + static List listCollations() { CollationIdentifier UTF8_BINARY_COLLATION_IDENT = new CollationIdentifier(PROVIDER_SPARK, UTF8_BINARY_COLLATION_NAME, "1.0"); @@ -620,21 +791,33 @@ private enum AccentSensitivity { } } - private static final int UNICODE_COLLATION_ID = - new CollationSpecICU("UNICODE", CaseSensitivity.CS, AccentSensitivity.AS).collationId; - private static final int UNICODE_CI_COLLATION_ID = - new CollationSpecICU("UNICODE", CaseSensitivity.CI, AccentSensitivity.AS).collationId; + private static final int UNICODE_COLLATION_ID = new CollationSpecICU( + "UNICODE", + CaseSensitivity.CS, + AccentSensitivity.AS, + SpaceTrimming.NONE).collationId; + + private static final int UNICODE_CI_COLLATION_ID = new CollationSpecICU( + "UNICODE", + CaseSensitivity.CI, + AccentSensitivity.AS, + SpaceTrimming.NONE).collationId; private final CaseSensitivity caseSensitivity; private final AccentSensitivity accentSensitivity; + private final SpaceTrimming spaceTrimming; private final String locale; private final int collationId; - private CollationSpecICU(String locale, CaseSensitivity caseSensitivity, - AccentSensitivity accentSensitivity) { + private CollationSpecICU( + String locale, + CaseSensitivity caseSensitivity, + AccentSensitivity accentSensitivity, + SpaceTrimming spaceTrimming) { this.locale = locale; this.caseSensitivity = caseSensitivity; this.accentSensitivity = accentSensitivity; + this.spaceTrimming = spaceTrimming; // Construct collation ID from locale, case-sensitivity and accent-sensitivity specifiers. int collationId = ICULocaleToId.get(locale); // Mandatory ICU implementation provider. 
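A small sketch of the collation ID bit layout documented in the Javadoc above (class and constant names are illustrative, not Spark identifiers): bit 29 marks ICU collations, the new bit 18 carries the RTRIM flag, and bits 17/16 carry case/accent insensitivity. The printed values reproduce the UTF8_BINARY_RTRIM (0x00040000) and UNICODE_CI_RTRIM (0x20060000) entries from the mapping list.

```java
// Illustrative-only constants mirroring the documented collation ID layout.
final class CollationIdLayoutSketch {
  static final int ICU_PROVIDER = 1 << 29;        // bit 29: 1 for ICU collations
  static final int RTRIM = 1 << 18;               // bit 18: 1 = right-trim space trimming
  static final int CASE_INSENSITIVE = 1 << 17;    // bit 17 (ICU collation family)
  static final int ACCENT_INSENSITIVE = 1 << 16;  // bit 16 (ICU collation family)

  public static void main(String[] args) {
    int utf8BinaryRtrim = RTRIM;                                  // UTF8_BINARY_RTRIM
    int unicodeCiRtrim = ICU_PROVIDER | CASE_INSENSITIVE | RTRIM; // UNICODE_CI_RTRIM
    System.out.printf("0x%08X 0x%08X%n", utf8BinaryRtrim, unicodeCiRtrim);
    // prints 0x00040000 0x20060000, matching the Javadoc examples
  }
}
```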
@@ -644,6 +827,8 @@ private CollationSpecICU(String locale, CaseSensitivity caseSensitivity, caseSensitivity); collationId = SpecifierUtils.setSpecValue(collationId, ACCENT_SENSITIVITY_OFFSET, accentSensitivity); + collationId = SpecifierUtils.setSpecValue(collationId, SPACE_TRIMMING_OFFSET, + spaceTrimming); this.collationId = collationId; } @@ -661,58 +846,86 @@ private static int collationNameToId( } if (lastPos == -1) { throw collationInvalidNameException(originalName); - } else { - String locale = collationName.substring(0, lastPos); - int collationId = ICULocaleToId.get(ICULocaleMapUppercase.get(locale)); - - // Try all combinations of AS/AI and CS/CI. - CaseSensitivity caseSensitivity; - AccentSensitivity accentSensitivity; - if (collationName.equals(locale) || - collationName.equals(locale + "_AS") || - collationName.equals(locale + "_CS") || - collationName.equals(locale + "_AS_CS") || - collationName.equals(locale + "_CS_AS") - ) { - caseSensitivity = CaseSensitivity.CS; - accentSensitivity = AccentSensitivity.AS; - } else if (collationName.equals(locale + "_CI") || - collationName.equals(locale + "_AS_CI") || - collationName.equals(locale + "_CI_AS")) { - caseSensitivity = CaseSensitivity.CI; - accentSensitivity = AccentSensitivity.AS; - } else if (collationName.equals(locale + "_AI") || - collationName.equals(locale + "_CS_AI") || - collationName.equals(locale + "_AI_CS")) { - caseSensitivity = CaseSensitivity.CS; - accentSensitivity = AccentSensitivity.AI; - } else if (collationName.equals(locale + "_AI_CI") || - collationName.equals(locale + "_CI_AI")) { - caseSensitivity = CaseSensitivity.CI; - accentSensitivity = AccentSensitivity.AI; - } else { - throw collationInvalidNameException(originalName); - } + } + String locale = collationName.substring(0, lastPos); + int collationId = ICULocaleToId.get(ICULocaleMapUppercase.get(locale)); + collationId = SpecifierUtils.setSpecValue(collationId, + IMPLEMENTATION_PROVIDER_OFFSET, ImplementationProvider.ICU); - // Build collation ID from computed specifiers. - collationId = SpecifierUtils.setSpecValue(collationId, - IMPLEMENTATION_PROVIDER_OFFSET, ImplementationProvider.ICU); - collationId = SpecifierUtils.setSpecValue(collationId, - CASE_SENSITIVITY_OFFSET, caseSensitivity); - collationId = SpecifierUtils.setSpecValue(collationId, - ACCENT_SENSITIVITY_OFFSET, accentSensitivity); + // No other specifiers present. + if(collationName.equals(locale)){ return collationId; } + if(collationName.charAt(locale.length()) != '_'){ + throw collationInvalidNameException(originalName); + } + // Extract remaining specifiers and trim "_" separator. + String remainingSpecifiers = collationName.substring(lastPos + 1); + + // Initialize default specifier flags. + // Case sensitive, accent sensitive, no space trimming. 
+ boolean isCaseSpecifierSet = false; + boolean isAccentSpecifierSet = false; + boolean isSpaceTrimmingSpecifierSet = false; + CaseSensitivity caseSensitivity = CaseSensitivity.CS; + AccentSensitivity accentSensitivity = AccentSensitivity.AS; + SpaceTrimming spaceTrimming = SpaceTrimming.NONE; + + String[] specifiers = remainingSpecifiers.split("_"); + + // Iterate through specifiers and set corresponding flags + for (String specifier : specifiers) { + switch (specifier) { + case "CI": + case "CS": + if (isCaseSpecifierSet) { + throw collationInvalidNameException(originalName); + } + caseSensitivity = CaseSensitivity.valueOf(specifier); + isCaseSpecifierSet = true; + break; + case "AI": + case "AS": + if (isAccentSpecifierSet) { + throw collationInvalidNameException(originalName); + } + accentSensitivity = AccentSensitivity.valueOf(specifier); + isAccentSpecifierSet = true; + break; + case "RTRIM": + if (isSpaceTrimmingSpecifierSet) { + throw collationInvalidNameException(originalName); + } + spaceTrimming = SpaceTrimming.valueOf(specifier); + isSpaceTrimmingSpecifierSet = true; + break; + default: + throw collationInvalidNameException(originalName); + } + } + + // Build collation ID from computed specifiers. + collationId = SpecifierUtils.setSpecValue(collationId, + CASE_SENSITIVITY_OFFSET, caseSensitivity); + collationId = SpecifierUtils.setSpecValue(collationId, + ACCENT_SENSITIVITY_OFFSET, accentSensitivity); + collationId = SpecifierUtils.setSpecValue(collationId, + SPACE_TRIMMING_OFFSET, spaceTrimming); + return collationId; } private static CollationSpecICU fromCollationId(int collationId) { // Parse specifiers from collation ID. + int spaceTrimmingOrdinal = SpecifierUtils.getSpecValue(collationId, + SPACE_TRIMMING_OFFSET, SPACE_TRIMMING_MASK); int caseSensitivityOrdinal = SpecifierUtils.getSpecValue(collationId, CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK); int accentSensitivityOrdinal = SpecifierUtils.getSpecValue(collationId, ACCENT_SENSITIVITY_OFFSET, ACCENT_SENSITIVITY_MASK); collationId = SpecifierUtils.removeSpec(collationId, IMPLEMENTATION_PROVIDER_OFFSET, IMPLEMENTATION_PROVIDER_MASK); + collationId = SpecifierUtils.removeSpec(collationId, + SPACE_TRIMMING_OFFSET, SPACE_TRIMMING_MASK); collationId = SpecifierUtils.removeSpec(collationId, CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK); collationId = SpecifierUtils.removeSpec(collationId, @@ -723,8 +936,9 @@ private static CollationSpecICU fromCollationId(int collationId) { assert(localeId >= 0 && localeId < ICULocaleNames.length); CaseSensitivity caseSensitivity = CaseSensitivity.values()[caseSensitivityOrdinal]; AccentSensitivity accentSensitivity = AccentSensitivity.values()[accentSensitivityOrdinal]; + SpaceTrimming spaceTrimming = SpaceTrimming.values()[spaceTrimmingOrdinal]; String locale = ICULocaleNames[localeId]; - return new CollationSpecICU(locale, caseSensitivity, accentSensitivity); + return new CollationSpecICU(locale, caseSensitivity, accentSensitivity, spaceTrimming); } @Override @@ -751,16 +965,34 @@ protected Collation buildCollation() { Collator collator = Collator.getInstance(resultLocale); // Freeze ICU collator to ensure thread safety. 
collator.freeze(); + + Comparator comparator; + ToLongFunction hashFunction; + + if (spaceTrimming == SpaceTrimming.NONE) { + hashFunction = s -> (long) collator.getCollationKey( + s.toValidString()).hashCode(); + comparator = (s1, s2) -> + collator.compare(s1.toValidString(), s2.toValidString()); + } else { + comparator = (s1, s2) -> collator.compare( + applyTrimmingPolicy(s1, spaceTrimming).toValidString(), + applyTrimmingPolicy(s2, spaceTrimming).toValidString()); + hashFunction = s -> (long) collator.getCollationKey( + applyTrimmingPolicy(s, spaceTrimming).toValidString()).hashCode(); + } + return new Collation( - collationName(), + normalizedCollationName(), PROVIDER_ICU, collator, - (s1, s2) -> collator.compare(s1.toValidString(), s2.toValidString()), + comparator, ICU_COLLATOR_VERSION, - s -> (long) collator.getCollationKey(s.toValidString()).hashCode(), - /* supportsBinaryEquality = */ false, - /* supportsBinaryOrdering = */ false, - /* supportsLowercaseEquality = */ false); + hashFunction, + (s1, s2) -> comparator.compare(s1, s2) == 0, + /* isUtf8BinaryType = */ false, + /* isUtf8LcaseType = */ false, + spaceTrimming != SpaceTrimming.NONE); } @Override @@ -768,13 +1000,14 @@ protected CollationMeta buildCollationMeta() { return new CollationMeta( CATALOG, SCHEMA, - collationName(), + normalizedCollationName(), ICULocaleMap.get(locale).getDisplayLanguage(), ICULocaleMap.get(locale).getDisplayCountry(), VersionInfo.ICU_VERSION.toString(), COLLATION_PAD_ATTRIBUTE, + accentSensitivity == AccentSensitivity.AS, caseSensitivity == CaseSensitivity.CS, - accentSensitivity == AccentSensitivity.AS); + spaceTrimming.toString()); } /** @@ -782,9 +1015,11 @@ protected CollationMeta buildCollationMeta() { * - Locale name * - Optional case sensitivity when non-default preceded by underscore * - Optional accent sensitivity when non-default preceded by underscore - * Examples: en, en_USA_CI_AI, sr_Cyrl_SRB_AI. + * - Optional space trimming when non-default preceded by underscore + * Examples: en, en_USA_CI_LTRIM, en_USA_CI_AI, en_USA_CI_AI_TRIM, sr_Cyrl_SRB_AI. 
*/ - private String collationName() { + @Override + protected String normalizedCollationName() { StringBuilder builder = new StringBuilder(); builder.append(locale); if (caseSensitivity != CaseSensitivity.CS) { @@ -795,20 +1030,21 @@ private String collationName() { builder.append('_'); builder.append(accentSensitivity.toString()); } + if(spaceTrimming != SpaceTrimming.NONE) { + builder.append('_'); + builder.append(spaceTrimming.toString()); + } return builder.toString(); } private static List allCollationNames() { List collationNames = new ArrayList<>(); - for (String locale: ICULocaleToId.keySet()) { - // CaseSensitivity.CS + AccentSensitivity.AS - collationNames.add(locale); - // CaseSensitivity.CS + AccentSensitivity.AI - collationNames.add(locale + "_AI"); - // CaseSensitivity.CI + AccentSensitivity.AS - collationNames.add(locale + "_CI"); - // CaseSensitivity.CI + AccentSensitivity.AI - collationNames.add(locale + "_CI_AI"); + List caseAccentSpecifiers = Arrays.asList("", "_AI", "_CI", "_CI_AI"); + for (String locale : ICULocaleToId.keySet()) { + for (String caseAccent : caseAccentSpecifiers) { + String collationName = locale + caseAccent; + collationNames.add(collationName); + } } return collationNames.stream().sorted().toList(); } @@ -921,6 +1157,18 @@ public static int collationNameToId(String collationName) throws SparkException return Collation.CollationSpec.collationNameToId(collationName); } + /** + * Returns whether the ICU collation is not Case Sensitive Accent Insensitive + * for the given collation id. + * This method is used in expressions which do not support CS_AI collations. + */ + public static boolean isCaseSensitiveAndAccentInsensitive(int collationId) { + return Collation.CollationSpecICU.fromCollationId(collationId).caseSensitivity == + Collation.CollationSpecICU.CaseSensitivity.CS && + Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity == + Collation.CollationSpecICU.AccentSensitivity.AI; + } + public static void assertValidProvider(String provider) throws SparkException { if (!SUPPORTED_PROVIDERS.contains(provider.toLowerCase())) { Map params = Map.of( @@ -947,24 +1195,32 @@ public static String[] getICULocaleNames() { public static UTF8String getCollationKey(UTF8String input, int collationId) { Collation collation = fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.supportsSpaceTrimming) { + input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); + } + if (collation.isUtf8BinaryType) { return input; - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return CollationAwareUTF8String.lowerCaseCodePoints(input); } else { - CollationKey collationKey = collation.collator.getCollationKey(input.toValidString()); + CollationKey collationKey = collation.collator.getCollationKey( + input.toValidString()); return UTF8String.fromBytes(collationKey.toByteArray()); } } public static byte[] getCollationKeyBytes(UTF8String input, int collationId) { Collation collation = fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.supportsSpaceTrimming) { + input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); + } + if (collation.isUtf8BinaryType) { return input.getBytes(); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return CollationAwareUTF8String.lowerCaseCodePoints(input).getBytes(); } else { - return 
collation.collator.getCollationKey(input.toValidString()).toByteArray(); + return collation.collator.getCollationKey( + input.toValidString()).toByteArray(); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index f05d9e512568f..978b663cc25c9 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -37,9 +37,9 @@ public final class CollationSupport { public static class StringSplitSQL { public static UTF8String[] exec(final UTF8String s, final UTF8String d, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(s, d); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(s, d); } else { return execICU(s, d, collationId); @@ -48,9 +48,9 @@ public static UTF8String[] exec(final UTF8String s, final UTF8String d, final in public static String genCode(final String s, final String d, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringSplitSQL.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", s, d); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", s, d); } else { return String.format(expr + "ICU(%s, %s, %d)", s, d, collationId); @@ -71,9 +71,9 @@ public static UTF8String[] execICU(final UTF8String string, final UTF8String del public static class Contains { public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(l, r); } else { return execICU(l, r, collationId); @@ -82,9 +82,9 @@ public static boolean exec(final UTF8String l, final UTF8String r, final int col public static String genCode(final String l, final String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Contains.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", l, r); } else { return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); @@ -109,9 +109,9 @@ public static class StartsWith { public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(l, r); } else { return execICU(l, r, collationId); @@ 
-120,9 +120,9 @@ public static boolean exec(final UTF8String l, final UTF8String r, public static String genCode(final String l, final String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StartsWith.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", l, r); } else { return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); @@ -146,9 +146,9 @@ public static boolean execICU(final UTF8String l, final UTF8String r, public static class EndsWith { public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(l, r); } else { return execICU(l, r, collationId); @@ -157,9 +157,9 @@ public static boolean exec(final UTF8String l, final UTF8String r, final int col public static String genCode(final String l, final String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.EndsWith.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", l, r); } else { return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); @@ -184,9 +184,9 @@ public static boolean execICU(final UTF8String l, final UTF8String r, public static class Upper { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return useICU ? execBinaryICU(v) : execBinary(v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(v); } else { return execICU(v, collationId); @@ -195,10 +195,10 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Upper.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { String funcName = useICU ? 
"BinaryICU" : "Binary"; return String.format(expr + "%s(%s)", funcName, v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s)", v); } else { return String.format(expr + "ICU(%s, %d)", v, collationId); @@ -221,9 +221,9 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class Lower { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return useICU ? execBinaryICU(v) : execBinary(v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(v); } else { return execICU(v, collationId); @@ -232,10 +232,10 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Lower.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { String funcName = useICU ? "BinaryICU" : "Binary"; return String.format(expr + "%s(%s)", funcName, v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s)", v); } else { return String.format(expr + "ICU(%s, %d)", v, collationId); @@ -258,9 +258,9 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class InitCap { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return useICU ? execBinaryICU(v) : execBinary(v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(v); } else { return execICU(v, collationId); @@ -270,10 +270,10 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.InitCap.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { String funcName = useICU ? 
"BinaryICU" : "Binary"; return String.format(expr + "%s(%s)", funcName, v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s)", v); } else { return String.format(expr + "ICU(%s, %d)", v, collationId); @@ -296,7 +296,7 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class FindInSet { public static int exec(final UTF8String word, final UTF8String set, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(word, set); } else { return execCollationAware(word, set, collationId); @@ -305,7 +305,7 @@ public static int exec(final UTF8String word, final UTF8String set, final int co public static String genCode(final String word, final String set, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.FindInSet.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", word, set); } else { return String.format(expr + "CollationAware(%s, %s, %d)", word, set, collationId); @@ -324,9 +324,9 @@ public static class StringInstr { public static int exec(final UTF8String string, final UTF8String substring, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(string, substring); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(string, substring); } else { return execICU(string, substring, collationId); @@ -336,9 +336,9 @@ public static String genCode(final String string, final String substring, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringInstr.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", string, substring); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", string, substring); } else { return String.format(expr + "ICU(%s, %s, %d)", string, substring, collationId); @@ -360,9 +360,9 @@ public static class StringReplace { public static UTF8String exec(final UTF8String src, final UTF8String search, final UTF8String replace, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(src, search, replace); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(src, search, replace); } else { return execICU(src, search, replace, collationId); @@ -372,9 +372,9 @@ public static String genCode(final String src, final String search, final String final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringReplace.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s, %s)", src, search, replace); - } else if 
(collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s, %s)", src, search, replace); } else { return String.format(expr + "ICU(%s, %s, %s, %d)", src, search, replace, collationId); @@ -398,9 +398,9 @@ public static class StringLocate { public static int exec(final UTF8String string, final UTF8String substring, final int start, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(string, substring, start); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(string, substring, start); } else { return execICU(string, substring, start, collationId); @@ -410,9 +410,9 @@ public static String genCode(final String string, final String substring, final final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringLocate.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s, %d)", string, substring, start); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s, %d)", string, substring, start); } else { return String.format(expr + "ICU(%s, %s, %d, %d)", string, substring, start, collationId); @@ -436,9 +436,9 @@ public static class SubstringIndex { public static UTF8String exec(final UTF8String string, final UTF8String delimiter, final int count, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(string, delimiter, count); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(string, delimiter, count); } else { return execICU(string, delimiter, count, collationId); @@ -448,9 +448,9 @@ public static String genCode(final String string, final String delimiter, final String count, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.SubstringIndex.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s, %s)", string, delimiter, count); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s, %s)", string, delimiter, count); } else { return String.format(expr + "ICU(%s, %s, %s, %d)", string, delimiter, count, collationId); @@ -474,9 +474,9 @@ public static class StringTranslate { public static UTF8String exec(final UTF8String source, Map dict, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(source, dict); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(source, dict); } else { return execICU(source, dict, collationId); @@ -503,9 +503,9 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - 
if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(srcString, trimString); } else { return execICU(srcString, trimString, collationId); @@ -520,9 +520,9 @@ public static String genCode( final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrim.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); @@ -559,9 +559,9 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(srcString, trimString); } else { return execICU(srcString, trimString, collationId); @@ -576,9 +576,9 @@ public static String genCode( final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrimLeft.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); @@ -614,9 +614,9 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(srcString, trimString); } else { return execICU(srcString, trimString, collationId); @@ -631,9 +631,9 @@ public static String genCode( final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrimRight.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); @@ -669,7 +669,7 @@ public static UTF8String execICU( public static boolean supportsLowercaseRegex(final int collationId) { // for regex, only Unicode case-insensitive matching is possible, // so UTF8_LCASE is treated as UNICODE_CI in this context - return CollationFactory.fetchCollation(collationId).supportsLowercaseEquality; + return 
CollationFactory.fetchCollation(collationId).isUtf8LcaseType; } static final int lowercaseRegexFlags = Pattern.UNICODE_CASE | Pattern.CASE_INSENSITIVE; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 5719303a0dce8..a445cde52ad57 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -629,6 +629,8 @@ public void testStartsWith() throws SparkException { assertStartsWith("İonic", "Io", "UTF8_LCASE", false); assertStartsWith("İonic", "i\u0307o", "UTF8_LCASE", true); assertStartsWith("İonic", "İo", "UTF8_LCASE", true); + assertStartsWith("oİ", "oİ", "UTF8_LCASE", true); + assertStartsWith("oİ", "oi̇", "UTF8_LCASE", true); // Conditional case mapping (e.g. Greek sigmas). assertStartsWith("σ", "σ", "UTF8_BINARY", true); assertStartsWith("σ", "ς", "UTF8_BINARY", false); @@ -880,6 +882,8 @@ public void testEndsWith() throws SparkException { assertEndsWith("the İo", "Io", "UTF8_LCASE", false); assertEndsWith("the İo", "i\u0307o", "UTF8_LCASE", true); assertEndsWith("the İo", "İo", "UTF8_LCASE", true); + assertEndsWith("İo", "İo", "UTF8_LCASE", true); + assertEndsWith("İo", "i̇o", "UTF8_LCASE", true); // Conditional case mapping (e.g. Greek sigmas). assertEndsWith("σ", "σ", "UTF8_BINARY", true); assertEndsWith("σ", "ς", "UTF8_BINARY", false); diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index 321d1ccd700f2..df9af1579d4f1 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -38,22 +38,22 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig assert(UTF8_BINARY_COLLATION_ID == 0) val utf8Binary = fetchCollation(UTF8_BINARY_COLLATION_ID) assert(utf8Binary.collationName == "UTF8_BINARY") - assert(utf8Binary.supportsBinaryEquality) + assert(utf8Binary.isUtf8BinaryType) assert(UTF8_LCASE_COLLATION_ID == 1) val utf8Lcase = fetchCollation(UTF8_LCASE_COLLATION_ID) assert(utf8Lcase.collationName == "UTF8_LCASE") - assert(!utf8Lcase.supportsBinaryEquality) + assert(!utf8Lcase.isUtf8BinaryType) assert(UNICODE_COLLATION_ID == (1 << 29)) val unicode = fetchCollation(UNICODE_COLLATION_ID) assert(unicode.collationName == "UNICODE") - assert(!unicode.supportsBinaryEquality) + assert(!unicode.isUtf8BinaryType) assert(UNICODE_CI_COLLATION_ID == ((1 << 29) | (1 << 17))) val unicodeCi = fetchCollation(UNICODE_CI_COLLATION_ID) assert(unicodeCi.collationName == "UNICODE_CI") - assert(!unicodeCi.supportsBinaryEquality) + assert(!unicodeCi.isUtf8BinaryType) } test("UTF8_BINARY and ICU root locale collation names") { @@ -127,6 +127,11 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_BINARY", "aaa", "AAA", false), CollationTestCase("UTF8_BINARY", "aaa", "bbb", false), CollationTestCase("UTF8_BINARY", "å", "a\u030A", false), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "aaa", true), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "aaa ", true), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa ", true), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", " aaa ", false), + 
CollationTestCase("UTF8_BINARY_RTRIM", " ", " ", true), CollationTestCase("UTF8_LCASE", "aaa", "aaa", true), CollationTestCase("UTF8_LCASE", "aaa", "AAA", true), CollationTestCase("UTF8_LCASE", "aaa", "AaA", true), @@ -134,15 +139,30 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_LCASE", "aaa", "aa", false), CollationTestCase("UTF8_LCASE", "aaa", "bbb", false), CollationTestCase("UTF8_LCASE", "å", "a\u030A", false), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "AaA", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "AaA ", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AaA ", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", " AaA ", false), + CollationTestCase("UTF8_LCASE_RTRIM", " ", " ", true), CollationTestCase("UNICODE", "aaa", "aaa", true), CollationTestCase("UNICODE", "aaa", "AAA", false), CollationTestCase("UNICODE", "aaa", "bbb", false), CollationTestCase("UNICODE", "å", "a\u030A", true), + CollationTestCase("UNICODE_RTRIM", "aaa", "aaa", true), + CollationTestCase("UNICODE_RTRIM", "aaa", "aaa ", true), + CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", true), + CollationTestCase("UNICODE_RTRIM", "aaa", " aaa ", false), + CollationTestCase("UNICODE_RTRIM", " ", " ", true), CollationTestCase("UNICODE_CI", "aaa", "aaa", true), CollationTestCase("UNICODE_CI", "aaa", "AAA", true), CollationTestCase("UNICODE_CI", "aaa", "bbb", false), CollationTestCase("UNICODE_CI", "å", "a\u030A", true), - CollationTestCase("UNICODE_CI", "Å", "a\u030A", true) + CollationTestCase("UNICODE_CI", "Å", "a\u030A", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA ", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AaA ", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", " AaA ", false), + CollationTestCase("UNICODE_RTRIM", " ", " ", true) ) checks.foreach(testCase => { @@ -162,19 +182,48 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_BINARY", "aaa", "AAA", 1), CollationTestCase("UTF8_BINARY", "aaa", "bbb", -1), CollationTestCase("UTF8_BINARY", "aaa", "BBB", 1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa", 0), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa ", 0), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "bbb", -1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "BBB" , 1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "BBB " , 1), + CollationTestCase("UTF8_BINARY_RTRIM", " ", " " , 0), CollationTestCase("UTF8_LCASE", "aaa", "aaa", 0), CollationTestCase("UTF8_LCASE", "aaa", "AAA", 0), CollationTestCase("UTF8_LCASE", "aaa", "AaA", 0), CollationTestCase("UTF8_LCASE", "aaa", "AaA", 0), CollationTestCase("UTF8_LCASE", "aaa", "aa", 1), CollationTestCase("UTF8_LCASE", "aaa", "bbb", -1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AAA", 0), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AAA ", 0), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "bbb ", -1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa", 1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa ", 1), + CollationTestCase("UTF8_LCASE_RTRIM", " ", " ", 0), CollationTestCase("UNICODE", "aaa", "aaa", 0), CollationTestCase("UNICODE", "aaa", "AAA", -1), CollationTestCase("UNICODE", "aaa", "bbb", -1), CollationTestCase("UNICODE", "aaa", "BBB", -1), + 
CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa", 0), + CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", 0), + CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb", -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UNICODE_RTRIM", "aaa", "BBB" , -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "BBB " , -1), + CollationTestCase("UNICODE_RTRIM", " ", " ", 0), CollationTestCase("UNICODE_CI", "aaa", "aaa", 0), CollationTestCase("UNICODE_CI", "aaa", "AAA", 0), - CollationTestCase("UNICODE_CI", "aaa", "bbb", -1)) + CollationTestCase("UNICODE_CI", "aaa", "bbb", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA", 0), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA ", 0), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", "bbb ", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa", 1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa ", 1), + CollationTestCase("UNICODE_CI_RTRIM", " ", " ", 0) + ) checks.foreach(testCase => { val collation = fetchCollation(testCase.collationName) @@ -369,9 +418,9 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig 1 << 15, // UTF8_BINARY mandatory zero bit 15 breach. 1 << 16, // UTF8_BINARY mandatory zero bit 16 breach. 1 << 17, // UTF8_BINARY mandatory zero bit 17 breach. - 1 << 18, // UTF8_BINARY mandatory zero bit 18 breach. 1 << 19, // UTF8_BINARY mandatory zero bit 19 breach. 1 << 20, // UTF8_BINARY mandatory zero bit 20 breach. + 1 << 21, // UTF8_BINARY mandatory zero bit 21 breach. 1 << 23, // UTF8_BINARY mandatory zero bit 23 breach. 1 << 24, // UTF8_BINARY mandatory zero bit 24 breach. 1 << 25, // UTF8_BINARY mandatory zero bit 25 breach. @@ -382,7 +431,6 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig (1 << 29) | (1 << 13), // ICU mandatory zero bit 13 breach. (1 << 29) | (1 << 14), // ICU mandatory zero bit 14 breach. (1 << 29) | (1 << 15), // ICU mandatory zero bit 15 breach. - (1 << 29) | (1 << 18), // ICU mandatory zero bit 18 breach. (1 << 29) | (1 << 19), // ICU mandatory zero bit 19 breach. (1 << 29) | (1 << 20), // ICU mandatory zero bit 20 breach. (1 << 29) | (1 << 21), // ICU mandatory zero bit 21 breach. @@ -457,7 +505,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig val e = intercept[SparkException] { fetchCollation(collationName) } - assert(e.getErrorClass === "COLLATION_INVALID_NAME") + assert(e.getCondition === "COLLATION_INVALID_NAME") assert(e.getMessageParameters.asScala === Map( "collationName" -> collationName, "proposals" -> proposals)) } diff --git a/common/utils/src/main/java/org/apache/spark/SparkThrowable.java b/common/utils/src/main/java/org/apache/spark/SparkThrowable.java index e1235b2982ba0..39808f58b08ae 100644 --- a/common/utils/src/main/java/org/apache/spark/SparkThrowable.java +++ b/common/utils/src/main/java/org/apache/spark/SparkThrowable.java @@ -35,19 +35,29 @@ */ @Evolving public interface SparkThrowable { - // Succinct, human-readable, unique, and consistent representation of the error category - // If null, error class is not set - String getErrorClass(); + /** + * Succinct, human-readable, unique, and consistent representation of the error condition. + * If null, error condition is not set. + */ + String getCondition(); + + /** + * Succinct, human-readable, unique, and consistent representation of the error category. + * If null, error class is not set. 
+ * @deprecated Use {@link #getCondition()} instead. + */ + @Deprecated + default String getErrorClass() { return getCondition(); } // Portable error identifier across SQL engines // If null, error class or SQLSTATE is not set default String getSqlState() { - return SparkThrowableHelper.getSqlState(this.getErrorClass()); + return SparkThrowableHelper.getSqlState(this.getCondition()); } // True if this error is an internal error. default boolean isInternalError() { - return SparkThrowableHelper.isInternalError(this.getErrorClass()); + return SparkThrowableHelper.isInternalError(this.getCondition()); } default Map getMessageParameters() { diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 25dd676c4aff9..8b2a57d6da3dd 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1,4 +1,10 @@ { + "ADD_DEFAULT_UNSUPPORTED" : { + "message" : [ + "Failed to execute command because DEFAULT values are not supported when adding new columns to previously existing target data source with table provider: \"\"." + ], + "sqlState" : "42623" + }, "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION" : { "message" : [ "Non-deterministic expression should not appear in the arguments of an aggregate function." @@ -121,7 +127,7 @@ }, "BINARY_ARITHMETIC_OVERFLOW" : { "message" : [ - " caused overflow." + " caused overflow. Use to ignore overflow problem and return NULL." ], "sqlState" : "22003" }, @@ -612,6 +618,13 @@ ], "sqlState" : "42703" }, + "COLUMN_ORDINAL_OUT_OF_BOUNDS" : { + "message" : [ + "Column ordinal out of bounds. The number of columns in the table is , but the column ordinal is .", + "Attributes are the following: ." + ], + "sqlState" : "22003" + }, "COMPARATOR_RETURNS_NULL" : { "message" : [ "The comparator has returned a NULL for a comparison between and .", @@ -625,6 +638,11 @@ "Cannot process input data types for the expression: ." ], "subClass" : { + "BAD_INPUTS" : { + "message" : [ + "The input data types to must be valid, but found the input types ." + ] + }, "MISMATCHED_TYPES" : { "message" : [ "All input types must be the same except nullable, containsNull, valueContainsNull flags, but found the input types ." @@ -651,6 +669,16 @@ ], "sqlState" : "40000" }, + "CONFLICTING_DIRECTORY_STRUCTURES" : { + "message" : [ + "Conflicting directory structures detected.", + "Suspicious paths:", + "", + "If provided paths are partition directories, please set \"basePath\" in the options of the data source to specify the root directory of the table.", + "If there are multiple root directories, please load them separately and then union them." + ], + "sqlState" : "KD009" + }, "CONFLICTING_PARTITION_COLUMN_NAMES" : { "message" : [ "Conflicting partition column names detected:", @@ -1049,7 +1077,7 @@ "message" : [ "Encountered error when saving to external data source." ], - "sqlState" : "KD00F" + "sqlState" : "KD010" }, "DATA_SOURCE_NOT_EXIST" : { "message" : [ @@ -1096,6 +1124,12 @@ ], "sqlState" : "42608" }, + "DEFAULT_UNSUPPORTED" : { + "message" : [ + "Failed to execute command because DEFAULT values are not supported for target data source with table provider: \"\"." + ], + "sqlState" : "42623" + }, "DISTINCT_WINDOW_FUNCTION_UNSUPPORTED" : { "message" : [ "Distinct window functions are not supported: ." 
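With the SparkThrowable change above, getErrorClass() becomes a deprecated default that simply forwards to getCondition(), so existing callers keep working while new code reads the condition directly. A minimal caller-side sketch, not part of the patch; the query text is illustrative.

import org.apache.spark.SparkThrowable
import org.apache.spark.sql.SparkSession

def reportError(spark: SparkSession): Unit = {
  try {
    spark.sql("SELECT * FROM table_that_does_not_exist").collect()
  } catch {
    case t: SparkThrowable =>
      // Preferred after this patch; getErrorClass() now returns the same value.
      println(s"condition=${t.getCondition} sqlState=${t.getSqlState}")
      println(s"parameters=${t.getMessageParameters}")
  }
}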
@@ -1178,6 +1212,12 @@ ], "sqlState" : "42604" }, + "EMPTY_SCHEMA_NOT_SUPPORTED_FOR_DATASOURCE" : { + "message" : [ + "The datasource does not support writing empty or nested empty schemas. Please make sure the data schema has at least one or more column(s)." + ], + "sqlState" : "0A000" + }, "ENCODER_NOT_FOUND" : { "message" : [ "Not found an encoder of the type to Spark SQL internal representation.", @@ -1444,6 +1484,12 @@ ], "sqlState" : "2203G" }, + "FAILED_TO_LOAD_ROUTINE" : { + "message" : [ + "Failed to load routine ." + ], + "sqlState" : "38000" + }, "FAILED_TO_PARSE_TOO_COMPLEX" : { "message" : [ "The statement, including potential SQL functions and referenced views, was too complex to parse.", @@ -2497,6 +2543,11 @@ "Interval string does not match second-nano format of ss.nnnnnnnnn." ] }, + "TIMEZONE_INTERVAL_OUT_OF_RANGE" : { + "message" : [ + "The interval value must be in the range of [-18, +18] hours with second precision." + ] + }, "UNKNOWN_PARSING_ERROR" : { "message" : [ "Unknown error when parsing ." @@ -2548,6 +2599,13 @@ }, "sqlState" : "42K0K" }, + "INVALID_JAVA_IDENTIFIER_AS_FIELD_NAME" : { + "message" : [ + " is not a valid identifier of Java and cannot be used as field name", + "." + ], + "sqlState" : "46121" + }, "INVALID_JOIN_TYPE_FOR_JOINWITH" : { "message" : [ "Invalid join type in joinWith: ." @@ -3005,7 +3063,7 @@ }, "MULTI_PART_NAME" : { "message" : [ - " with multiple part function name() is not allowed." + " with multiple part name() is not allowed." ] }, "OPTION_IS_INVALID" : { @@ -3083,6 +3141,12 @@ ], "sqlState" : "42K0F" }, + "INVALID_TIMEZONE" : { + "message" : [ + "The timezone: is invalid. The timezone must be either a region-based zone ID or a zone offset. Region IDs must have the form 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in the format '(+|-)HH', '(+|-)HH:mm’ or '(+|-)HH:mm:ss', e.g '-08' , '+01:00' or '-13:33:33', and must be in the range from -18:00 to +18:00. 'Z' and 'UTC' are accepted as synonyms for '+00:00'." + ], + "sqlState" : "22009" + }, "INVALID_TIME_TRAVEL_SPEC" : { "message" : [ "Cannot specify both version and timestamp when time travelling the table." @@ -3772,6 +3836,18 @@ ], "sqlState" : "428FT" }, + "PARTITION_COLUMN_NOT_FOUND_IN_SCHEMA" : { + "message" : [ + "Partition column not found in schema . Please provide the existing column for partitioning." + ], + "sqlState" : "42000" + }, + "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY" : { + "message" : [ + "The expression must be inside 'partitionedBy'." + ], + "sqlState" : "42S23" + }, "PATH_ALREADY_EXISTS" : { "message" : [ "Path already exists. Set mode as \"overwrite\" to overwrite the existing path." @@ -3944,6 +4020,18 @@ ], "sqlState" : "22023" }, + "SCALAR_FUNCTION_NOT_COMPATIBLE" : { + "message" : [ + "ScalarFunction not overrides method 'produceResult(InternalRow)' with custom implementation." + ], + "sqlState" : "42K0O" + }, + "SCALAR_FUNCTION_NOT_FULLY_IMPLEMENTED" : { + "message" : [ + "ScalarFunction not implements or overrides method 'produceResult(InternalRow)'." + ], + "sqlState" : "42K0P" + }, "SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION" : { "message" : [ "The correlated scalar subquery '' is neither present in GROUP BY, nor in an aggregate function.", @@ -4445,6 +4533,12 @@ ], "sqlState" : "428EK" }, + "TRAILING_COMMA_IN_SELECT" : { + "message" : [ + "Trailing comma detected in SELECT clause. Remove the trailing comma before the FROM clause." 
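The two SCALAR_FUNCTION_* conditions added above (apparently replacing the legacy _LEGACY_ERROR_TEMP_3055 / _3146 entries removed a little further down) concern DataSource V2 scalar functions that provide neither a usable magic method nor a produceResult(InternalRow) override. A hedged sketch of a function that satisfies the check; the class and function names are invented for illustration.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.types.{DataType, IntegerType}

// Illustrative V2 FunctionCatalog UDF; overriding produceResult avoids the new
// SCALAR_FUNCTION_NOT_FULLY_IMPLEMENTED / SCALAR_FUNCTION_NOT_COMPATIBLE errors.
class PlusOne extends ScalarFunction[Integer] {
  override def inputTypes(): Array[DataType] = Array(IntegerType)
  override def resultType(): DataType = IntegerType
  override def name(): String = "plus_one"
  override def produceResult(input: InternalRow): Integer =
    Integer.valueOf(input.getInt(0) + 1)
}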
+ ], + "sqlState" : "42601" + }, "TRANSPOSE_EXCEED_ROW_LIMIT" : { "message" : [ "Number of rows exceeds the allowed limit of for TRANSPOSE. If this was intended, set to at least the current row count." @@ -4868,11 +4962,6 @@ "Catalog does not support ." ] }, - "COLLATION" : { - "message" : [ - "Collation is not yet supported." - ] - }, "COMBINATION_QUERY_RESULT_CLAUSES" : { "message" : [ "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY." @@ -5099,6 +5188,16 @@ "message" : [ "TRANSFORM with SERDE is only supported in hive mode." ] + }, + "TRIM_COLLATION" : { + "message" : [ + "TRIM specifier in the collation." + ] + }, + "UPDATE_COLUMN_NULLABILITY" : { + "message" : [ + "Update column nullability for MySQL and MS SQL Server." + ] } }, "sqlState" : "0A000" @@ -5627,11 +5726,6 @@ "Expected format is 'RESET' or 'RESET key'. If you want to include special characters in key, please use quotes, e.g., RESET `key`." ] }, - "_LEGACY_ERROR_TEMP_0044" : { - "message" : [ - "The interval value must be in the range of [-18, +18] hours with second precision." - ] - }, "_LEGACY_ERROR_TEMP_0045" : { "message" : [ "Invalid time zone displacement value." @@ -6049,11 +6143,6 @@ " is not a valid Spark SQL Data Source." ] }, - "_LEGACY_ERROR_TEMP_1136" : { - "message" : [ - "Cannot save interval data type into external storage." - ] - }, "_LEGACY_ERROR_TEMP_1137" : { "message" : [ "Unable to resolve given []." @@ -6079,11 +6168,6 @@ "Multiple sources found for (), please specify the fully qualified class name." ] }, - "_LEGACY_ERROR_TEMP_1142" : { - "message" : [ - "Datasource does not support writing empty or nested empty schemas. Please make sure the data schema has at least one or more column(s)." - ] - }, "_LEGACY_ERROR_TEMP_1143" : { "message" : [ "The data to be inserted needs to have the same number of columns as the target table: target table has column(s) but the inserted data has column(s), which contain partition column(s) having assigned constant values." @@ -6638,11 +6722,6 @@ "The pivot column has more than distinct values, this could indicate an error. If this was intended, set to at least the number of distinct values of the pivot column." ] }, - "_LEGACY_ERROR_TEMP_1325" : { - "message" : [ - "Cannot modify the value of a static config: ." - ] - }, "_LEGACY_ERROR_TEMP_1327" : { "message" : [ "Command execution is not supported in runner ." @@ -6673,21 +6752,6 @@ "Sinks cannot request distribution and ordering in continuous execution mode." ] }, - "_LEGACY_ERROR_TEMP_1344" : { - "message" : [ - "Invalid DEFAULT value for column : fails to parse as a valid literal value." - ] - }, - "_LEGACY_ERROR_TEMP_1345" : { - "message" : [ - "Failed to execute command because DEFAULT values are not supported for target data source with table provider: \"\"." - ] - }, - "_LEGACY_ERROR_TEMP_1346" : { - "message" : [ - "Failed to execute command because DEFAULT values are not supported when adding new columns to previously existing target data source with table provider: \"\"." - ] - }, "_LEGACY_ERROR_TEMP_2000" : { "message" : [ ". If necessary set to false to bypass this error." @@ -7007,7 +7071,7 @@ }, "_LEGACY_ERROR_TEMP_2097" : { "message" : [ - "Could not execute broadcast in secs. You can increase the timeout for broadcasts via or disable broadcast join by setting to -1." + "Could not execute broadcast in secs. You can increase the timeout for broadcasts via or disable broadcast join by setting to -1 or remove the broadcast hint if it exists in your code." 
] }, "_LEGACY_ERROR_TEMP_2098" : { @@ -7145,12 +7209,6 @@ "cannot have circular references in class, but got the circular reference of class ." ] }, - "_LEGACY_ERROR_TEMP_2140" : { - "message" : [ - "`` is not a valid identifier of Java and cannot be used as field name", - "." - ] - }, "_LEGACY_ERROR_TEMP_2144" : { "message" : [ "Unable to find constructor for . This could happen if is an interface, or a trait without companion object constructor." @@ -7639,11 +7697,6 @@ "comment on table is not supported." ] }, - "_LEGACY_ERROR_TEMP_2271" : { - "message" : [ - "UpdateColumnNullability is not supported." - ] - }, "_LEGACY_ERROR_TEMP_2272" : { "message" : [ "Rename column is only supported for MySQL version 8.0 and above." @@ -7908,11 +7961,6 @@ " is not currently supported" ] }, - "_LEGACY_ERROR_TEMP_3055" : { - "message" : [ - "ScalarFunction neither implement magic method nor override 'produceResult'" - ] - }, "_LEGACY_ERROR_TEMP_3056" : { "message" : [ "Unexpected row-level read relations (allow multiple = ): " @@ -8271,11 +8319,6 @@ "Partitions truncate is not supported" ] }, - "_LEGACY_ERROR_TEMP_3146" : { - "message" : [ - "Cannot find a compatible ScalarFunction#produceResult" - ] - }, "_LEGACY_ERROR_TEMP_3147" : { "message" : [ ": Batch scan are not supported" diff --git a/common/utils/src/main/resources/error/error-states.json b/common/utils/src/main/resources/error/error-states.json index edba6e1d43216..fb899e4eb207e 100644 --- a/common/utils/src/main/resources/error/error-states.json +++ b/common/utils/src/main/resources/error/error-states.json @@ -4631,6 +4631,18 @@ "standard": "N", "usedBy": ["Spark"] }, + "42K0O": { + "description": "ScalarFunction not overrides method 'produceResult(InternalRow)' with custom implementation.", + "origin": "Spark", + "standard": "N", + "usedBy": ["Spark"] + }, + "42K0P": { + "description": "ScalarFunction not implements or overrides method 'produceResult(InternalRow)'.", + "origin": "Spark", + "standard": "N", + "usedBy": ["Spark"] + }, "42KD0": { "description": "Ambiguous name reference.", "origin": "Databricks", @@ -4901,6 +4913,12 @@ "standard": "N", "usedBy": ["SQL Server"] }, + "42S23": { + "description": "Partition transform expression not in 'partitionedBy'", + "origin": "Spark", + "standard": "N", + "usedBy": ["Spark"] + }, "44000": { "description": "with check option violation", "origin": "SQL/Foundation", @@ -7417,7 +7435,7 @@ "standard": "N", "usedBy": ["Databricks"] }, - "KD00F": { + "KD010": { "description": "external data source failure", "origin": "Databricks", "standard": "N", diff --git a/common/utils/src/main/scala/org/apache/spark/SparkException.scala b/common/utils/src/main/scala/org/apache/spark/SparkException.scala index 398cb1fad6726..0c0a1902ee2a1 100644 --- a/common/utils/src/main/scala/org/apache/spark/SparkException.scala +++ b/common/utils/src/main/scala/org/apache/spark/SparkException.scala @@ -69,7 +69,7 @@ class SparkException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -179,7 +179,7 @@ private[spark] class SparkUpgradeException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull } /** @@ -212,7 +212,7 @@ private[spark] class 
SparkArithmeticException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -250,7 +250,7 @@ private[spark] class SparkUnsupportedOperationException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull } private[spark] object SparkUnsupportedOperationException { @@ -280,7 +280,7 @@ private[spark] class SparkClassNotFoundException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } /** @@ -296,7 +296,7 @@ private[spark] class SparkConcurrentModificationException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } /** @@ -306,8 +306,9 @@ private[spark] class SparkDateTimeException private( message: String, errorClass: Option[String], messageParameters: Map[String, String], - context: Array[QueryContext]) - extends DateTimeException(message) with SparkThrowable { + context: Array[QueryContext], + cause: Option[Throwable]) + extends DateTimeException(message, cause.orNull) with SparkThrowable { def this( errorClass: String, @@ -318,7 +319,23 @@ private[spark] class SparkDateTimeException private( SparkThrowableHelper.getMessage(errorClass, messageParameters, summary), Option(errorClass), messageParameters, - context + context, + cause = None + ) + } + + def this( + errorClass: String, + messageParameters: Map[String, String], + context: Array[QueryContext], + summary: String, + cause: Option[Throwable]) = { + this( + SparkThrowableHelper.getMessage(errorClass, messageParameters, summary), + Option(errorClass), + messageParameters, + context, + cause.orElse(None) ) } @@ -329,7 +346,7 @@ private[spark] class SparkDateTimeException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -345,7 +362,7 @@ private[spark] class SparkFileNotFoundException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } /** @@ -379,7 +396,7 @@ private[spark] class SparkNumberFormatException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -431,7 +448,7 @@ private[spark] class SparkIllegalArgumentException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -460,7 +477,7 @@ private[spark] class SparkRuntimeException private( override def 
getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -489,7 +506,7 @@ private[spark] class SparkPythonException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -507,7 +524,7 @@ private[spark] class SparkNoSuchElementException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass override def getQueryContext: Array[QueryContext] = context } @@ -524,7 +541,7 @@ private[spark] class SparkSecurityException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } /** @@ -558,7 +575,7 @@ private[spark] class SparkArrayIndexOutOfBoundsException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -574,7 +591,7 @@ private[spark] class SparkSQLException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } /** @@ -589,5 +606,5 @@ private[spark] class SparkSQLFeatureNotSupportedException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } diff --git a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala index 428c9d2a49351..b6c2b176de62b 100644 --- a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala +++ b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala @@ -81,7 +81,7 @@ private[spark] object SparkThrowableHelper { import ErrorMessageFormat._ format match { case PRETTY => e.getMessage - case MINIMAL | STANDARD if e.getErrorClass == null => + case MINIMAL | STANDARD if e.getCondition == null => toJsonString { generator => val g = generator.useDefaultPrettyPrinter() g.writeStartObject() @@ -92,7 +92,7 @@ private[spark] object SparkThrowableHelper { g.writeEndObject() } case MINIMAL | STANDARD => - val errorClass = e.getErrorClass + val errorClass = e.getCondition toJsonString { generator => val g = generator.useDefaultPrettyPrinter() g.writeStartObject() diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala index a7e4f186000b5..12d456a371d07 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala @@ -266,6 +266,7 @@ private[spark] object LogKeys { case object FEATURE_NAME extends LogKey case object FETCH_SIZE extends LogKey case object FIELD_NAME extends LogKey + case object 
FIELD_TYPE extends LogKey case object FILES extends LogKey case object FILE_ABSOLUTE_PATH extends LogKey case object FILE_END_OFFSET extends LogKey @@ -652,6 +653,7 @@ private[spark] object LogKeys { case object RECEIVER_IDS extends LogKey case object RECORDS extends LogKey case object RECOVERY_STATE extends LogKey + case object RECURSIVE_DEPTH extends LogKey case object REDACTED_STATEMENT extends LogKey case object REDUCE_ID extends LogKey case object REGEX extends LogKey diff --git a/common/utils/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala b/common/utils/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala index 259f4330224c9..1972ef05d8759 100644 --- a/common/utils/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala +++ b/common/utils/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala @@ -84,7 +84,7 @@ class StreamingQueryException private[sql]( s"""${classOf[StreamingQueryException].getName}: ${cause.getMessage} |$queryDebugString""".stripMargin - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava } diff --git a/common/utils/src/main/scala/org/apache/spark/util/JsonUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/JsonUtils.scala index 4d729adfbb7eb..f88f267727c11 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/JsonUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/JsonUtils.scala @@ -24,6 +24,7 @@ import com.fasterxml.jackson.core.{JsonEncoding, JsonGenerator} import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} import com.fasterxml.jackson.module.scala.DefaultScalaModule +import org.apache.spark.util.SparkErrorUtils.tryWithResource private[spark] trait JsonUtils { @@ -31,12 +32,12 @@ private[spark] trait JsonUtils { .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) def toJsonString(block: JsonGenerator => Unit): String = { - val baos = new ByteArrayOutputStream() - val generator = mapper.createGenerator(baos, JsonEncoding.UTF8) - block(generator) - generator.close() - baos.close() - new String(baos.toByteArray, StandardCharsets.UTF_8) + tryWithResource(new ByteArrayOutputStream()) { baos => + tryWithResource(mapper.createGenerator(baos, JsonEncoding.UTF8)) { generator => + block(generator) + } + new String(baos.toByteArray, StandardCharsets.UTF_8) + } } } diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index 7d80998d96eb1..0b85b208242cb 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -42,7 +42,8 @@ private[sql] case class AvroDataToCatalyst( val dt = SchemaConverters.toSqlType( expectedSchema, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType).dataType + avroOptions.stableIdPrefixForUnionType, + avroOptions.recursiveFieldMaxDepth).dataType parseMode match { // With PermissiveMode, the output Catalyst row might contain columns of null values for // corrupt records, even if some of the columns are not nullable in the user-provided schema. 
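The JsonUtils.toJsonString rewrite above wraps both the byte stream and the Jackson generator in tryWithResource, so they are closed even when the writer body throws, which the previous straight-line version did not guarantee. A roughly equivalent plain try/finally sketch using only public Jackson API; the helper name is illustrative and not part of the patch.

import java.io.ByteArrayOutputStream
import java.nio.charset.StandardCharsets
import com.fasterxml.jackson.core.{JsonEncoding, JsonFactory, JsonGenerator}

def toJsonStringSketch(block: JsonGenerator => Unit): String = {
  val baos = new ByteArrayOutputStream()
  try {
    val generator = new JsonFactory().createGenerator(baos, JsonEncoding.UTF8)
    try {
      block(generator)      // may throw; generator is still closed below
    } finally {
      generator.close()     // flushes buffered output into baos
    }
    new String(baos.toByteArray, StandardCharsets.UTF_8)
  } finally {
    baos.close()            // a no-op for ByteArrayOutputStream, kept for symmetry
  }
}

For example, toJsonStringSketch(g => { g.writeStartObject(); g.writeEndObject() }) returns "{}".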
@@ -69,7 +70,8 @@ private[sql] case class AvroDataToCatalyst( dataType, avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType) + avroOptions.stableIdPrefixForUnionType, + avroOptions.recursiveFieldMaxDepth) @transient private var decoder: BinaryDecoder = _ diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index 877c3f89e88c0..ac20614553ca2 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -51,14 +51,16 @@ private[sql] class AvroDeserializer( datetimeRebaseSpec: RebaseSpec, filters: StructFilters, useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String) { + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int) { def this( rootAvroType: Schema, rootCatalystType: DataType, datetimeRebaseMode: String, useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String) = { + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int) = { this( rootAvroType, rootCatalystType, @@ -66,7 +68,8 @@ private[sql] class AvroDeserializer( RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)), new NoopFilters, useStableIdForUnionType, - stableIdPrefixForUnionType) + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) } private lazy val decimalConversions = new DecimalConversion() @@ -128,7 +131,8 @@ private[sql] class AvroDeserializer( s"schema is incompatible (avroType = $avroType, sqlType = ${catalystType.sql})" val realDataType = SchemaConverters.toSqlType( - avroType, useStableIdForUnionType, stableIdPrefixForUnionType).dataType + avroType, useStableIdForUnionType, stableIdPrefixForUnionType, + recursiveFieldMaxDepth).dataType (avroType.getType, catalystType) match { case (NULL, NullType) => (updater, ordinal, _) => diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 372f24b54f5c4..264c3a1f48abe 100755 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -145,7 +145,8 @@ private[sql] class AvroFileFormat extends FileFormat datetimeRebaseMode, avroFilters, parsedOptions.useStableIdForUnionType, - parsedOptions.stableIdPrefixForUnionType) + parsedOptions.stableIdPrefixForUnionType, + parsedOptions.recursiveFieldMaxDepth) override val stopPosition = file.start + file.length override def hasNext: Boolean = hasNextRow diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala index 4332904339f19..e0c6ad3ee69d3 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailFastMode, ParseMode} +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf /** @@ -136,6 +137,15 @@ private[sql] class AvroOptions( val 
stableIdPrefixForUnionType: String = parameters .getOrElse(STABLE_ID_PREFIX_FOR_UNION_TYPE, "member_") + + val recursiveFieldMaxDepth: Int = + parameters.get(RECURSIVE_FIELD_MAX_DEPTH).map(_.toInt).getOrElse(-1) + + if (recursiveFieldMaxDepth > RECURSIVE_FIELD_MAX_DEPTH_LIMIT) { + throw QueryCompilationErrors.avroOptionsException( + RECURSIVE_FIELD_MAX_DEPTH, + s"Should not be greater than $RECURSIVE_FIELD_MAX_DEPTH_LIMIT.") + } } private[sql] object AvroOptions extends DataSourceOptions { @@ -170,4 +180,25 @@ // When STABLE_ID_FOR_UNION_TYPE is enabled, the option allows to configure the prefix for fields // of Avro Union type. val STABLE_ID_PREFIX_FOR_UNION_TYPE = newOption("stableIdentifierPrefixForUnionType") + + /** + * Adds support for recursive fields. If this option is not specified or is set to 0, recursive + * fields are not permitted. Setting it to 1 drops all recursive fields, 2 allows recursive + * fields to be recursed once, and 3 allows it to be recursed twice and so on, up to 15. + * Values larger than 15 are not allowed in order to avoid inadvertently creating very large + * schemas. If an avro message has depth beyond this limit, the Spark struct returned is + * truncated after the recursion limit. + * + * Examples: Consider an Avro schema with a recursive field: + * {"type" : "record", "name" : "Node", "fields" : [{"name": "Id", "type": "int"}, + * {"name": "Next", "type": ["null", "Node"]}]} + * The following lists the parsed schema with different values for this setting. + * 1: `struct<Id: int>` + * 2: `struct<Id: int, Next: struct<Id: int>>` + * 3: `struct<Id: int, Next: struct<Id: int, Next: struct<Id: int>>>` + * and so on. + */ + val RECURSIVE_FIELD_MAX_DEPTH = newOption("recursiveFieldMaxDepth") + + val RECURSIVE_FIELD_MAX_DEPTH_LIMIT: Int = 15 } diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala index 7cbc30f1fb3dc..594ebb4716c41 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -65,7 +65,8 @@ private[sql] object AvroUtils extends Logging { SchemaConverters.toSqlType( avroSchema, parsedOptions.useStableIdForUnionType, - parsedOptions.stableIdPrefixForUnionType).dataType match { + parsedOptions.stableIdPrefixForUnionType, + parsedOptions.recursiveFieldMaxDepth).dataType match { case t: StructType => Some(t) case _ => throw new RuntimeException( s"""Avro schema cannot be converted to a Spark SQL StructType: diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index b2285aa966ddb..1168a887abd8e 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -27,6 +27,10 @@ import org.apache.avro.LogicalTypes.{Date, Decimal, LocalTimestampMicros, LocalT import org.apache.avro.Schema.Type._ import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging +import org.apache.spark.internal.LogKeys.{FIELD_NAME, FIELD_TYPE, RECURSIVE_DEPTH} +import org.apache.spark.internal.MDC +import org.apache.spark.sql.avro.AvroOptions.RECURSIVE_FIELD_MAX_DEPTH_LIMIT import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.types._ import org.apache.spark.sql.types.Decimal.minBytesForPrecision @@ -36,7 +40,7 
@@ import org.apache.spark.sql.types.Decimal.minBytesForPrecision * versa. */ @DeveloperApi -object SchemaConverters { +object SchemaConverters extends Logging { private lazy val nullSchema = Schema.create(Schema.Type.NULL) /** @@ -48,14 +52,27 @@ object SchemaConverters { /** * Converts an Avro schema to a corresponding Spark SQL schema. - * + * + * @param avroSchema The Avro schema to convert. + * @param useStableIdForUnionType If true, Avro schema is deserialized into Spark SQL schema, + * and the Avro Union type is transformed into a structure where + * the field names remain consistent with their respective types. + * @param stableIdPrefixForUnionType The prefix to use to configure the prefix for fields of + * Avro Union type + * @param recursiveFieldMaxDepth The maximum depth to recursively process fields in Avro schema. + * -1 means not supported. * @since 4.0.0 */ def toSqlType( avroSchema: Schema, useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String): SchemaType = { - toSqlTypeHelper(avroSchema, Set.empty, useStableIdForUnionType, stableIdPrefixForUnionType) + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int = -1): SchemaType = { + val schema = toSqlTypeHelper(avroSchema, Map.empty, useStableIdForUnionType, + stableIdPrefixForUnionType, recursiveFieldMaxDepth) + // the top level record should never return null + assert(schema != null) + schema } /** * Converts an Avro schema to a corresponding Spark SQL schema. @@ -63,17 +80,17 @@ object SchemaConverters { * @since 2.4.0 */ def toSqlType(avroSchema: Schema): SchemaType = { - toSqlType(avroSchema, false, "") + toSqlType(avroSchema, false, "", -1) } @deprecated("using toSqlType(..., useStableIdForUnionType: Boolean) instead", "4.0.0") def toSqlType(avroSchema: Schema, options: Map[String, String]): SchemaType = { val avroOptions = AvroOptions(options) - toSqlTypeHelper( + toSqlType( avroSchema, - Set.empty, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType) + avroOptions.stableIdPrefixForUnionType, + avroOptions.recursiveFieldMaxDepth) } // The property specifies Catalyst type of the given field @@ -81,9 +98,10 @@ object SchemaConverters { private def toSqlTypeHelper( avroSchema: Schema, - existingRecordNames: Set[String], + existingRecordNames: Map[String, Int], useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String): SchemaType = { + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int): SchemaType = { avroSchema.getType match { case INT => avroSchema.getLogicalType match { case _: Date => SchemaType(DateType, nullable = false) @@ -128,62 +146,110 @@ object SchemaConverters { case NULL => SchemaType(NullType, nullable = true) case RECORD => - if (existingRecordNames.contains(avroSchema.getFullName)) { + val recursiveDepth: Int = existingRecordNames.getOrElse(avroSchema.getFullName, 0) + if (recursiveDepth > 0 && recursiveFieldMaxDepth <= 0) { throw new IncompatibleSchemaException(s""" - |Found recursive reference in Avro schema, which can not be processed by Spark: - |${avroSchema.toString(true)} + |Found recursive reference in Avro schema, which can not be processed by Spark by + | default: ${avroSchema.toString(true)}. Try setting the option `recursiveFieldMaxDepth` + | to 1 - $RECURSIVE_FIELD_MAX_DEPTH_LIMIT. 
""".stripMargin) - } - val newRecordNames = existingRecordNames + avroSchema.getFullName - val fields = avroSchema.getFields.asScala.map { f => - val schemaType = toSqlTypeHelper( - f.schema(), - newRecordNames, - useStableIdForUnionType, - stableIdPrefixForUnionType) - StructField(f.name, schemaType.dataType, schemaType.nullable) - } + } else if (recursiveDepth > 0 && recursiveDepth >= recursiveFieldMaxDepth) { + logInfo( + log"The field ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} is dropped at recursive depth " + + log"${MDC(RECURSIVE_DEPTH, recursiveDepth)}." + ) + null + } else { + val newRecordNames = + existingRecordNames + (avroSchema.getFullName -> (recursiveDepth + 1)) + val fields = avroSchema.getFields.asScala.map { f => + val schemaType = toSqlTypeHelper( + f.schema(), + newRecordNames, + useStableIdForUnionType, + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + null + } + else { + StructField(f.name, schemaType.dataType, schemaType.nullable) + } + }.filter(_ != null).toSeq - SchemaType(StructType(fields.toArray), nullable = false) + SchemaType(StructType(fields), nullable = false) + } case ARRAY => val schemaType = toSqlTypeHelper( avroSchema.getElementType, existingRecordNames, useStableIdForUnionType, - stableIdPrefixForUnionType) - SchemaType( - ArrayType(schemaType.dataType, containsNull = schemaType.nullable), - nullable = false) + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + logInfo( + log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does not have any " + + log"fields left likely due to recursive depth limit." + ) + null + } else { + SchemaType( + ArrayType(schemaType.dataType, containsNull = schemaType.nullable), + nullable = false) + } case MAP => val schemaType = toSqlTypeHelper(avroSchema.getValueType, - existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType) - SchemaType( - MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), - nullable = false) + existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + logInfo( + log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does not have any " + + log"fields left likely due to recursive depth limit." 
+ ) + null + } else { + SchemaType( + MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), + nullable = false) + } case UNION => if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { // In case of a union with null, eliminate it and make a recursive call val remainingUnionTypes = AvroUtils.nonNullUnionBranches(avroSchema) - if (remainingUnionTypes.size == 1) { - toSqlTypeHelper( - remainingUnionTypes.head, - existingRecordNames, - useStableIdForUnionType, - stableIdPrefixForUnionType).copy(nullable = true) + val remainingSchema = + if (remainingUnionTypes.size == 1) { + remainingUnionTypes.head + } else { + Schema.createUnion(remainingUnionTypes.asJava) + } + val schemaType = toSqlTypeHelper( + remainingSchema, + existingRecordNames, + useStableIdForUnionType, + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + + if (schemaType == null) { + logInfo( + log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does not have any " + + log"fields left likely due to recursive depth limit." + ) + null } else { - toSqlTypeHelper( - Schema.createUnion(remainingUnionTypes.asJava), - existingRecordNames, - useStableIdForUnionType, - stableIdPrefixForUnionType).copy(nullable = true) + schemaType.copy(nullable = true) } } else avroSchema.getTypes.asScala.map(_.getType).toSeq match { case Seq(t1) => toSqlTypeHelper(avroSchema.getTypes.get(0), - existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType) + existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType, + recursiveFieldMaxDepth) case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => SchemaType(LongType, nullable = false) case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => @@ -201,29 +267,33 @@ object SchemaConverters { s, existingRecordNames, useStableIdForUnionType, - stableIdPrefixForUnionType) - - val fieldName = if (useStableIdForUnionType) { - // Avro's field name may be case sensitive, so field names for two named type - // could be "a" and "A" and we need to distinguish them. In this case, we throw - // an exception. - // Stable id prefix can be empty so the name of the field can be just the type. - val tempFieldName = s"${stableIdPrefixForUnionType}${s.getName}" - if (!fieldNameSet.add(tempFieldName.toLowerCase(Locale.ROOT))) { - throw new IncompatibleSchemaException( - "Cannot generate stable identifier for Avro union type due to name " + - s"conflict of type name ${s.getName}") - } - tempFieldName + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + null } else { - s"member$i" - } + val fieldName = if (useStableIdForUnionType) { + // Avro's field name may be case sensitive, so field names for two named type + // could be "a" and "A" and we need to distinguish them. In this case, we throw + // an exception. + // Stable id prefix can be empty so the name of the field can be just the type. 
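// Standalone sketch (not the converter code itself) of the depth bookkeeping used by the
// branches above: each fully-qualified record name maps to how many times it has been entered
// on the current path, and a field is dropped once that count reaches the configured limit.
// Note the real RECORD branch additionally throws when recursion is found and the option is
// unset (recursiveFieldMaxDepth <= 0); that case is omitted here for brevity.
def enterRecord(
    fullName: String,
    seen: Map[String, Int],
    recursiveFieldMaxDepth: Int): Option[Map[String, Int]] = {
  val depth = seen.getOrElse(fullName, 0)
  if (depth > 0 && depth >= recursiveFieldMaxDepth) {
    None // caller drops the field, mirroring the null returned by toSqlTypeHelper
  } else {
    Some(seen + (fullName -> (depth + 1)))
  }
}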
+ val tempFieldName = s"${stableIdPrefixForUnionType}${s.getName}" + if (!fieldNameSet.add(tempFieldName.toLowerCase(Locale.ROOT))) { + throw new IncompatibleSchemaException( + "Cannot generate stable identifier for Avro union type due to name " + + s"conflict of type name ${s.getName}") + } + tempFieldName + } else { + s"member$i" + } - // All fields are nullable because only one of them is set at a time - StructField(fieldName, schemaType.dataType, nullable = true) - } + // All fields are nullable because only one of them is set at a time + StructField(fieldName, schemaType.dataType, nullable = true) + } + }.filter(_ != null).toSeq - SchemaType(StructType(fields.toArray), nullable = false) + SchemaType(StructType(fields), nullable = false) } case other => throw new IncompatibleSchemaException(s"Unsupported type $other") diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala deleted file mode 100755 index 828a609a10e9c..0000000000000 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.avro - -import scala.jdk.CollectionConverters._ - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.Column -import org.apache.spark.sql.internal.ExpressionUtils.{column, expression} - - -// scalastyle:off: object.name -object functions { -// scalastyle:on: object.name - - /** - * Converts a binary column of avro format into its corresponding catalyst value. The specified - * schema must match the read data, otherwise the behavior is undefined: it may fail or return - * arbitrary result. - * - * @param data the binary column. - * @param jsonFormatSchema the avro schema in JSON string format. - * - * @since 3.0.0 - */ - @Experimental - def from_avro( - data: Column, - jsonFormatSchema: String): Column = { - AvroDataToCatalyst(data, jsonFormatSchema, Map.empty) - } - - /** - * Converts a binary column of Avro format into its corresponding catalyst value. - * The specified schema must match actual schema of the read data, otherwise the behavior - * is undefined: it may fail or return arbitrary result. - * To deserialize the data with a compatible and evolved schema, the expected Avro schema can be - * set via the option avroSchema. - * - * @param data the binary column. - * @param jsonFormatSchema the avro schema in JSON string format. - * @param options options to control how the Avro record is parsed. 
- * - * @since 3.0.0 - */ - @Experimental - def from_avro( - data: Column, - jsonFormatSchema: String, - options: java.util.Map[String, String]): Column = { - AvroDataToCatalyst(data, jsonFormatSchema, options.asScala.toMap) - } - - /** - * Converts a column into binary of avro format. - * - * @param data the data column. - * - * @since 3.0.0 - */ - @Experimental - def to_avro(data: Column): Column = { - CatalystDataToAvro(data, None) - } - - /** - * Converts a column into binary of avro format. - * - * @param data the data column. - * @param jsonFormatSchema user-specified output avro schema in JSON string format. - * - * @since 3.0.0 - */ - @Experimental - def to_avro(data: Column, jsonFormatSchema: String): Column = { - CatalystDataToAvro(data, Some(jsonFormatSchema)) - } -} diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala index 1083c99160724..a13faf3b51560 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala @@ -105,7 +105,8 @@ case class AvroPartitionReaderFactory( datetimeRebaseMode, avroFilters, options.useStableIdForUnionType, - options.stableIdPrefixForUnionType) + options.stableIdPrefixForUnionType, + options.recursiveFieldMaxDepth) override val stopPosition = partitionedFile.start + partitionedFile.length override def next(): Boolean = hasNextRow diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala index 388347537a4d6..311eda3a1b6ae 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -291,7 +291,8 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite RebaseSpec(LegacyBehaviorPolicy.CORRECTED), filters, false, - "") + "", + -1) val deserialized = deserializer.deserialize(data) expected match { case None => assert(deserialized == None) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala index 47faaf7662a50..a7f7abadcf485 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.LocalTableScanExec import org.apache.spark.sql.functions.{col, lit, struct} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{BinaryType, StructType} +import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} class AvroFunctionsSuite extends QueryTest with SharedSparkSession { import testImplicits._ @@ -374,6 +374,37 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { } } + + test("roundtrip in to_avro and from_avro - recursive schema") { + val catalystSchema = + StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType))))))))) + + val 
avroSchema = s""" + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [ + | {"name": "Id", "type": "int"}, + | {"name": "Name", "type": ["null", "test_schema"]} + | ] + |} + """.stripMargin + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(2, Row(3, Row(4))), Row(1, null))), + catalystSchema).select(struct("Id", "Name").as("struct")) + + val avroStructDF = df.select(functions.to_avro($"struct", avroSchema).as("avro")) + checkAnswer(avroStructDF.select( + functions.from_avro($"avro", avroSchema, Map( + "recursiveFieldMaxDepth" -> "3").asJava)), df) + } + private def serialize(record: GenericRecord, avroSchema: String): Array[Byte] = { val schema = new Schema.Parser().parse(avroSchema) val datumWriter = new GenericDatumWriter[GenericRecord](schema) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala index 751ac275e048a..bb0858decdf8f 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala @@ -436,7 +436,7 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession { val ex = intercept[SparkException] { spark.read.format("avro").load(s"$dir.avro").collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[SparkArithmeticException], condition = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION", diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala index 9b3bb929a700d..c1ab96a63eb26 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala @@ -77,7 +77,8 @@ class AvroRowReaderSuite RebaseSpec(CORRECTED), new NoopFilters, false, - "") + "", + -1) override val stopPosition = fileSize override def hasNext: Boolean = hasNextRow diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala index cbcbc2e7e76a6..3643a95abe19c 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala @@ -228,7 +228,8 @@ object AvroSerdeSuite { RebaseSpec(CORRECTED), new NoopFilters, false, - "") + "", + -1) } /** diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 14ed6c43e4c0f..0df6a7c4bc90e 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -891,7 +891,7 @@ abstract class AvroSuite val ex = intercept[SparkException] { spark.read.schema("a DECIMAL(4, 3)").format("avro").load(path.toString).collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[AnalysisException], condition = "AVRO_INCOMPATIBLE_READ_TYPE", @@ -969,7 +969,7 @@ abstract class AvroSuite val ex = intercept[SparkException] { 
spark.read.schema(s"a $sqlType").format("avro").load(path.toString).collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[AnalysisException], condition = "AVRO_INCOMPATIBLE_READ_TYPE", @@ -1006,7 +1006,7 @@ abstract class AvroSuite val ex = intercept[SparkException] { spark.read.schema(s"a $sqlType").format("avro").load(path.toString).collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[AnalysisException], condition = "AVRO_INCOMPATIBLE_READ_TYPE", @@ -1515,7 +1515,7 @@ abstract class AvroSuite .write.format("avro").option("avroSchema", avroSchema) .save(s"$tempDir/${UUID.randomUUID()}") } - assert(ex.getErrorClass == "TASK_WRITE_FAILED") + assert(ex.getCondition == "TASK_WRITE_FAILED") assert(ex.getCause.isInstanceOf[java.lang.NullPointerException]) assert(ex.getCause.getMessage.contains( "null value for (non-nullable) string at test_schema.Name")) @@ -1673,8 +1673,12 @@ abstract class AvroSuite exception = intercept[AnalysisException] { sql("select interval 1 days").write.format("avro").mode("overwrite").save(tempDir) }, - condition = "_LEGACY_ERROR_TEMP_1136", - parameters = Map.empty + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + parameters = Map( + "format" -> "Avro", + "columnName" -> "`INTERVAL '1 days'`", + "columnType" -> "\"INTERVAL\"" + ) ) checkError( exception = intercept[AnalysisException] { @@ -2220,7 +2224,8 @@ abstract class AvroSuite } } - private def checkSchemaWithRecursiveLoop(avroSchema: String): Unit = { + private def checkSchemaWithRecursiveLoop(avroSchema: String, recursiveFieldMaxDepth: Int): + Unit = { val message = intercept[IncompatibleSchemaException] { SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema), false, "") }.getMessage @@ -2229,7 +2234,79 @@ abstract class AvroSuite } test("Detect recursive loop") { - checkSchemaWithRecursiveLoop(""" + for (recursiveFieldMaxDepth <- Seq(-1, 0)) { + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, // each element has a long + | {"name": "next", "type": ["null", "LongList"]} // optional next element + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields": [ + | { + | "name": "value", + | "type": { + | "type": "record", + | "name": "foo", + | "fields": [ + | { + | "name": "parent", + | "type": "LongList" + | } + | ] + | } + | } + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, + | {"name": "array", "type": {"type": "array", "items": "LongList"}} + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, + | {"name": "map", "type": {"type": "map", "values": "LongList"}} + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + } + } + + private def checkSparkSchemaEquals( + avroSchema: String, expectedSchema: StructType, recursiveFieldMaxDepth: Int): Unit = { + val sparkSchema = + SchemaConverters.toSqlType( + new Schema.Parser().parse(avroSchema), false, "", 
recursiveFieldMaxDepth).dataType + + assert(sparkSchema === expectedSchema) + } + + test("Translate recursive schema - union") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2238,9 +2315,57 @@ abstract class AvroSuite | {"name": "next", "type": ["null", "LongList"]} // optional next element | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = nonRecursiveFields.add("next", expectedSchema) + } + } + + test("Translate recursive schema - union - 2 non-null fields") { + val avroSchema = """ + |{ + | "type": "record", + | "name": "TreeNode", + | "fields": [ + | { + | "name": "name", + | "type": "string" + | }, + | { + | "name": "value", + | "type": [ + | "long" + | ] + | }, + | { + | "name": "children", + | "type": [ + | "null", + | { + | "type": "array", + | "items": "TreeNode" + | } + | ], + | "default": null + | } + | ] + |} + """.stripMargin + val nonRecursiveFields = new StructType().add("name", StringType, nullable = false) + .add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = nonRecursiveFields.add("children", + new ArrayType(expectedSchema, false), nullable = true) + } + } - checkSchemaWithRecursiveLoop(""" + test("Translate recursive schema - record") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2260,9 +2385,18 @@ abstract class AvroSuite | } | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", StructType(Seq()), nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = new StructType().add("value", + new StructType().add("parent", expectedSchema, nullable = false), nullable = false) + } + } - checkSchemaWithRecursiveLoop(""" + test("Translate recursive schema - array") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2271,9 +2405,18 @@ abstract class AvroSuite | {"name": "array", "type": {"type": "array", "items": "LongList"}} | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = + nonRecursiveFields.add("array", new ArrayType(expectedSchema, false), nullable = false) + } + } - checkSchemaWithRecursiveLoop(""" + test("Translate recursive schema - map") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2282,7 +2425,70 @@ abstract class AvroSuite | {"name": "map", "type": {"type": "map", "values": "LongList"}} | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = + nonRecursiveFields.add("map", + new MapType(StringType, expectedSchema, false), nullable = false) + } + } + + test("recursive schema integration test") { + val catalystSchema = + StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", 
StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", NullType))))))))) + + val avroSchema = s""" + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [ + | {"name": "Id", "type": "int"}, + | {"name": "Name", "type": ["null", "test_schema"]} + | ] + |} + """.stripMargin + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(2, Row(3, Row(4, null))), Row(1, null))), + catalystSchema) + + withTempPath { tempDir => + df.write.format("avro").save(tempDir.getPath) + + val exc = intercept[AnalysisException] { + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .option("recursiveFieldMaxDepth", 16) + .load(tempDir.getPath) + } + assert(exc.getMessage.contains("Should not be greater than 15.")) + + checkAnswer( + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .option("recursiveFieldMaxDepth", 10) + .load(tempDir.getPath), + df) + + checkAnswer( + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .option("recursiveFieldMaxDepth", 1) + .load(tempDir.getPath), + df.select("Id")) + } } test("log a warning of ignoreExtension deprecation") { @@ -2427,7 +2633,7 @@ abstract class AvroSuite val e = intercept[SparkException] { df.write.format("avro").option("avroSchema", avroSchema).save(path3_x) } - assert(e.getErrorClass == "TASK_WRITE_FAILED") + assert(e.getCondition == "TASK_WRITE_FAILED") assert(e.getCause.isInstanceOf[SparkUpgradeException]) } checkDefaultLegacyRead(oldPath) @@ -2682,7 +2888,7 @@ abstract class AvroSuite val e = intercept[SparkException] { df.write.format("avro").option("avroSchema", avroSchema).save(dir.getCanonicalPath) } - assert(e.getErrorClass == "TASK_WRITE_FAILED") + assert(e.getCondition == "TASK_WRITE_FAILED") val errMsg = e.getCause.asInstanceOf[SparkUpgradeException].getMessage assert(errMsg.contains("You may get a different result due to the upgrading")) } @@ -2693,7 +2899,7 @@ abstract class AvroSuite val e = intercept[SparkException] { df.write.format("avro").save(dir.getCanonicalPath) } - assert(e.getErrorClass == "TASK_WRITE_FAILED") + assert(e.getCondition == "TASK_WRITE_FAILED") val errMsg = e.getCause.asInstanceOf[SparkUpgradeException].getMessage assert(errMsg.contains("You may get a different result due to the upgrading")) } @@ -2777,7 +2983,7 @@ abstract class AvroSuite } test("SPARK-40667: validate Avro Options") { - assert(AvroOptions.getAllOptions.size == 11) + assert(AvroOptions.getAllOptions.size == 12) // Please add validation on any new Avro options here assert(AvroOptions.isValidOption("ignoreExtension")) assert(AvroOptions.isValidOption("mode")) @@ -2790,6 +2996,7 @@ abstract class AvroSuite assert(AvroOptions.isValidOption("datetimeRebaseMode")) assert(AvroOptions.isValidOption("enableStableIdentifiersForUnionType")) assert(AvroOptions.isValidOption("stableIdentifierPrefixForUnionType")) + assert(AvroOptions.isValidOption("recursiveFieldMaxDepth")) } test("SPARK-46633: read file with empty blocks") { diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml index be358f317481e..2fdb2d4bafe01 100644 --- a/connector/connect/client/jvm/pom.xml +++ b/connector/connect/client/jvm/pom.xml @@ -45,6 +45,11 @@ spark-sql-api_${scala.binary.version} ${project.version} + + org.apache.spark + spark-connect-shims_${scala.binary.version} + ${project.version} + org.apache.spark spark-sketch_${scala.binary.version} @@ -88,6 +93,13 @@ scalacheck_${scala.binary.version} test + + org.apache.spark + 
spark-sql-api_${scala.binary.version} + ${project.version} + tests + test + org.apache.spark spark-common-utils_${scala.binary.version} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index c06cbbc0cdb42..3777f82594aae 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -22,6 +22,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto.{NAReplace, Relation} import org.apache.spark.connect.proto.Expression.{Literal => GLiteral} import org.apache.spark.connect.proto.NAReplace.Replacement +import org.apache.spark.sql.connect.ConnectConversions._ /** * Functionality for working with missing data in `DataFrame`s. @@ -29,7 +30,7 @@ import org.apache.spark.connect.proto.NAReplace.Replacement * @since 3.4.0 */ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: Relation) - extends api.DataFrameNaFunctions[Dataset] { + extends api.DataFrameNaFunctions { import sparkSession.RichColumn override protected def drop(minNonNulls: Option[Int]): Dataset[Row] = diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index c3ee7030424eb..051d382c49773 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -22,7 +22,10 @@ import java.util.Properties import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable +import org.apache.spark.api.java.JavaRDD import org.apache.spark.connect.proto.Parse.ParseFormat +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.types.StructType @@ -33,8 +36,8 @@ import org.apache.spark.sql.types.StructType * @since 3.4.0 */ @Stable -class DataFrameReader private[sql] (sparkSession: SparkSession) - extends api.DataFrameReader[Dataset] { +class DataFrameReader private[sql] (sparkSession: SparkSession) extends api.DataFrameReader { + type DS[U] = Dataset[U] /** @inheritdoc */ override def format(source: String): this.type = super.format(source) @@ -139,6 +142,14 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) def json(jsonDataset: Dataset[String]): DataFrame = parse(jsonDataset, ParseFormat.PARSE_FORMAT_JSON) + /** @inheritdoc */ + override def json(jsonRDD: JavaRDD[String]): Dataset[Row] = + throwRddNotSupportedException() + + /** @inheritdoc */ + override def json(jsonRDD: RDD[String]): Dataset[Row] = + throwRddNotSupportedException() + /** @inheritdoc */ override def csv(path: String): DataFrame = super.csv(path) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 9f5ada0d7ec35..bb7cfa75a9ab9 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -22,6 +22,7 @@ import java.{lang => 
jl, util => ju} import org.apache.spark.connect.proto.{Relation, StatSampleBy} import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, PrimitiveDoubleEncoder} +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.functions.lit /** @@ -30,7 +31,7 @@ import org.apache.spark.sql.functions.lit * @since 3.4.0 */ final class DataFrameStatFunctions private[sql] (protected val df: DataFrame) - extends api.DataFrameStatFunctions[Dataset] { + extends api.DataFrameStatFunctions { private def root: Relation = df.plan.getRoot private val sparkSession: SparkSession = df.sparkSession diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 519193ebd9c74..adbfda9691508 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -26,12 +26,15 @@ import scala.util.control.NonFatal import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ import org.apache.spark.connect.proto +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.expressions.OrderUtils +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.client.SparkResult import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter} import org.apache.spark.sql.errors.DataTypeErrors.toSQLId @@ -134,15 +137,15 @@ class Dataset[T] private[sql] ( val sparkSession: SparkSession, @DeveloperApi val plan: proto.Plan, val encoder: Encoder[T]) - extends api.Dataset[T, Dataset] { - type RGD = RelationalGroupedDataset + extends api.Dataset[T] { + type DS[U] = Dataset[U] import sparkSession.RichColumn // Make sure we don't forget to set plan id. 
assert(plan.getRoot.getCommon.hasPlanId) - private[sql] val agnosticEncoder: AgnosticEncoder[T] = encoderFor(encoder) + private[sql] val agnosticEncoder: AgnosticEncoder[T] = agnosticEncoderFor(encoder) override def toString: String = { try { @@ -273,9 +276,7 @@ class Dataset[T] private[sql] ( df.withResult { result => assert(result.length == 1) assert(result.schema.size == 1) - // scalastyle:off println - println(result.toArray.head) - // scalastyle:on println + print(result.toArray.head) } } @@ -436,7 +437,7 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { - val encoder = encoderFor(c1.encoder) + val encoder = agnosticEncoderFor(c1.encoder) val col = if (encoder.schema == encoder.dataType) { functions.inline(functions.array(c1)) } else { @@ -451,7 +452,7 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - val encoder = ProductEncoder.tuple(columns.map(c => encoderFor(c.encoder))) + val encoder = ProductEncoder.tuple(columns.map(c => agnosticEncoderFor(c.encoder))) selectUntyped(encoder, columns) } @@ -523,27 +524,11 @@ class Dataset[T] private[sql] ( result(0) } - /** - * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given - * key `func`. - * - * @group typedrel - * @since 3.5.0 - */ + /** @inheritdoc */ def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { - KeyValueGroupedDatasetImpl[K, T](this, encoderFor[K], func) + KeyValueGroupedDatasetImpl[K, T](this, agnosticEncoderFor[K], func) } - /** - * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given - * key `func`. - * - * @group typedrel - * @since 3.5.0 - */ - def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = - groupByKey(ToScalaUDF(func))(encoder) - /** @inheritdoc */ @scala.annotation.varargs def rollup(cols: Column*): RelationalGroupedDataset = { @@ -896,7 +881,7 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ def mapPartitions[U: Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { - val outputEncoder = encoderFor[U] + val outputEncoder = agnosticEncoderFor[U] val udf = SparkUserDefinedFunction( function = func, inputEncoders = agnosticEncoder :: Nil, @@ -1050,12 +1035,7 @@ class Dataset[T] private[sql] ( new MergeIntoWriterImpl[T](table, this, condition) } - /** - * Interface for saving the content of the streaming Dataset out into external storage. 
- * - * @group basic - * @since 3.5.0 - */ + /** @inheritdoc */ def writeStream: DataStreamWriter[T] = { new DataStreamWriter[T](this) } @@ -1135,13 +1115,20 @@ class Dataset[T] private[sql] ( } /** @inheritdoc */ - protected def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = { + protected def checkpoint( + eager: Boolean, + reliableCheckpoint: Boolean, + storageLevel: Option[StorageLevel]): Dataset[T] = { sparkSession.newDataset(agnosticEncoder) { builder => val command = sparkSession.newCommand { builder => - builder.getCheckpointCommandBuilder + val checkpointBuilder = builder.getCheckpointCommandBuilder .setLocal(!reliableCheckpoint) .setEager(eager) .setRelation(this.plan.getRoot) + storageLevel.foreach { storageLevel => + checkpointBuilder.setStorageLevel( + StorageLevelProtoConverter.toConnectProtoType(storageLevel)) + } } val responseIter = sparkSession.execute(command) try { @@ -1324,6 +1311,10 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ override def localCheckpoint(eager: Boolean): Dataset[T] = super.localCheckpoint(eager) + /** @inheritdoc */ + override def localCheckpoint(eager: Boolean, storageLevel: StorageLevel): Dataset[T] = + super.localCheckpoint(eager, storageLevel) + /** @inheritdoc */ override def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = super.joinWith(other, condition) @@ -1479,4 +1470,16 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ @scala.annotation.varargs override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*) + + /** @inheritdoc */ + override def groupByKey[K]( + func: MapFunction[T, K], + encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = + super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]] + + /** @inheritdoc */ + override def rdd: RDD[T] = throwRddNotSupportedException() + + /** @inheritdoc */ + override def toJavaRDD: JavaRDD[T] = throwRddNotSupportedException() } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index aef7efb08a254..63b5f27c4745e 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -25,7 +25,8 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.api.java.function._ import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder} +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.common.UdfUtils import org.apache.spark.sql.expressions.SparkUserDefinedFunction import org.apache.spark.sql.functions.col @@ -40,8 +41,7 @@ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode * * @since 3.5.0 */ -class KeyValueGroupedDataset[K, V] private[sql] () - extends api.KeyValueGroupedDataset[K, V, Dataset] { +class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDataset[K, V] { type KVDS[KY, VL] = KeyValueGroupedDataset[KY, VL] private def unsupported(): Nothing = throw new UnsupportedOperationException() @@ -398,7 +398,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( new 
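// Sketch of the new localCheckpoint overload wired through the checkpoint command above;
// `df` stands for an arbitrary Connect Dataset and the storage level is just an example value.
import org.apache.spark.storage.StorageLevel

val snapshot = df.localCheckpoint(eager = true, storageLevel = StorageLevel.DISK_ONLY)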
KeyValueGroupedDatasetImpl[L, V, IK, IV]( sparkSession, plan, - encoderFor[L], + agnosticEncoderFor[L], ivEncoder, vEncoder, groupingExprs, @@ -412,7 +412,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( plan, kEncoder, ivEncoder, - encoderFor[W], + agnosticEncoderFor[W], groupingExprs, valueMapFunc .map(_.andThen(valueFunc)) @@ -430,7 +430,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = { // Apply mapValues changes to the udf val nf = UDFAdaptors.flatMapGroupsWithMappedValues(f, valueMapFunc) - val outputEncoder = encoderFor[U] + val outputEncoder = agnosticEncoderFor[U] sparkSession.newDataset[U](outputEncoder) { builder => builder.getGroupMapBuilder .setInput(plan.getRoot) @@ -446,7 +446,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( val otherImpl = other.asInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, Any]] // Apply mapValues changes to the udf val nf = UDFAdaptors.coGroupWithMappedValues(f, valueMapFunc, otherImpl.valueMapFunc) - val outputEncoder = encoderFor[R] + val outputEncoder = agnosticEncoderFor[R] sparkSession.newDataset[R](outputEncoder) { builder => builder.getCoGroupMapBuilder .setInput(plan.getRoot) @@ -461,7 +461,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( override protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { // TODO(SPARK-43415): For each column, apply the valueMap func first... - val rEnc = ProductEncoder.tuple(kEncoder +: columns.map(c => encoderFor(c.encoder))) + val rEnc = ProductEncoder.tuple(kEncoder +: columns.map(c => agnosticEncoderFor(c.encoder))) sparkSession.newDataset(rEnc) { builder => builder.getAggregateBuilder .setInput(plan.getRoot) @@ -501,7 +501,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( null } - val outputEncoder = encoderFor[U] + val outputEncoder = agnosticEncoderFor[U] val nf = UDFAdaptors.flatMapGroupsWithStateWithMappedValues(func, valueMapFunc) sparkSession.newDataset[U](outputEncoder) { builder => diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index ea13635fc2eaa..5bded40b0d132 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor +import org.apache.spark.sql.connect.ConnectConversions._ /** * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], @@ -39,8 +41,7 @@ class RelationalGroupedDataset private[sql] ( groupType: proto.Aggregate.GroupType, pivot: Option[proto.Aggregate.Pivot] = None, groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None) - extends api.RelationalGroupedDataset[Dataset] { - type RGD = RelationalGroupedDataset + extends api.RelationalGroupedDataset { import df.sparkSession.RichColumn protected def toDF(aggExprs: Seq[Column]): DataFrame = { @@ -82,7 +83,11 @@ class RelationalGroupedDataset private[sql] ( /** @inheritdoc */ def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = { - KeyValueGroupedDatasetImpl[K, T](df, encoderFor[K], encoderFor[T], groupingExprs) + KeyValueGroupedDatasetImpl[K, T]( + df, 
+ agnosticEncoderFor[K], + agnosticEncoderFor[T], + groupingExprs) } /** @inheritdoc */ diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 7799d395d5c6a..4690253da808b 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -16,283 +16,8 @@ */ package org.apache.spark.sql -import scala.collection.Map -import scala.language.implicitConversions -import scala.reflect.classTag -import scala.reflect.runtime.universe.TypeTag - -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders} -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ - -/** - * A collection of implicit methods for converting names and Symbols into [[Column]]s, and for - * converting common Scala objects into [[Dataset]]s. - * - * @since 3.4.0 - */ -abstract class SQLImplicits private[sql] (session: SparkSession) extends LowPrioritySQLImplicits { - - /** - * Converts $"col name" into a [[Column]]. - * - * @since 3.4.0 - */ - implicit class StringToColumn(val sc: StringContext) { - def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args: _*)) - } - } - - /** - * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. - * @since 3.4.0 - */ - implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) - - /** @since 3.4.0 */ - implicit val newIntEncoder: Encoder[Int] = PrimitiveIntEncoder - - /** @since 3.4.0 */ - implicit val newLongEncoder: Encoder[Long] = PrimitiveLongEncoder - - /** @since 3.4.0 */ - implicit val newDoubleEncoder: Encoder[Double] = PrimitiveDoubleEncoder - - /** @since 3.4.0 */ - implicit val newFloatEncoder: Encoder[Float] = PrimitiveFloatEncoder - - /** @since 3.4.0 */ - implicit val newByteEncoder: Encoder[Byte] = PrimitiveByteEncoder - - /** @since 3.4.0 */ - implicit val newShortEncoder: Encoder[Short] = PrimitiveShortEncoder - - /** @since 3.4.0 */ - implicit val newBooleanEncoder: Encoder[Boolean] = PrimitiveBooleanEncoder - - /** @since 3.4.0 */ - implicit val newStringEncoder: Encoder[String] = StringEncoder - - /** @since 3.4.0 */ - implicit val newJavaDecimalEncoder: Encoder[java.math.BigDecimal] = - AgnosticEncoders.DEFAULT_JAVA_DECIMAL_ENCODER - - /** @since 3.4.0 */ - implicit val newScalaDecimalEncoder: Encoder[scala.math.BigDecimal] = - AgnosticEncoders.DEFAULT_SCALA_DECIMAL_ENCODER - - /** @since 3.4.0 */ - implicit val newDateEncoder: Encoder[java.sql.Date] = AgnosticEncoders.STRICT_DATE_ENCODER - - /** @since 3.4.0 */ - implicit val newLocalDateEncoder: Encoder[java.time.LocalDate] = - AgnosticEncoders.STRICT_LOCAL_DATE_ENCODER - - /** @since 3.4.0 */ - implicit val newLocalDateTimeEncoder: Encoder[java.time.LocalDateTime] = - AgnosticEncoders.LocalDateTimeEncoder - - /** @since 3.4.0 */ - implicit val newTimeStampEncoder: Encoder[java.sql.Timestamp] = - AgnosticEncoders.STRICT_TIMESTAMP_ENCODER - - /** @since 3.4.0 */ - implicit val newInstantEncoder: Encoder[java.time.Instant] = - AgnosticEncoders.STRICT_INSTANT_ENCODER - - /** @since 3.4.0 */ - implicit val newDurationEncoder: Encoder[java.time.Duration] = DayTimeIntervalEncoder - - /** @since 3.4.0 */ - implicit val newPeriodEncoder: Encoder[java.time.Period] = YearMonthIntervalEncoder - - /** @since 3.4.0 */ - implicit def newJavaEnumEncoder[A 
<: java.lang.Enum[_]: TypeTag]: Encoder[A] = { - ScalaReflection.encoderFor[A] - } - - // Boxed primitives - - /** @since 3.4.0 */ - implicit val newBoxedIntEncoder: Encoder[java.lang.Integer] = BoxedIntEncoder - - /** @since 3.4.0 */ - implicit val newBoxedLongEncoder: Encoder[java.lang.Long] = BoxedLongEncoder - - /** @since 3.4.0 */ - implicit val newBoxedDoubleEncoder: Encoder[java.lang.Double] = BoxedDoubleEncoder - - /** @since 3.4.0 */ - implicit val newBoxedFloatEncoder: Encoder[java.lang.Float] = BoxedFloatEncoder - - /** @since 3.4.0 */ - implicit val newBoxedByteEncoder: Encoder[java.lang.Byte] = BoxedByteEncoder - - /** @since 3.4.0 */ - implicit val newBoxedShortEncoder: Encoder[java.lang.Short] = BoxedShortEncoder - - /** @since 3.4.0 */ - implicit val newBoxedBooleanEncoder: Encoder[java.lang.Boolean] = BoxedBooleanEncoder - - // Seqs - private def newSeqEncoder[E](elementEncoder: AgnosticEncoder[E]): AgnosticEncoder[Seq[E]] = { - IterableEncoder( - classTag[Seq[E]], - elementEncoder, - elementEncoder.nullable, - elementEncoder.lenientSerialization) - } - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newIntSeqEncoder: Encoder[Seq[Int]] = newSeqEncoder(PrimitiveIntEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newLongSeqEncoder: Encoder[Seq[Long]] = newSeqEncoder(PrimitiveLongEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newDoubleSeqEncoder: Encoder[Seq[Double]] = newSeqEncoder(PrimitiveDoubleEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newFloatSeqEncoder: Encoder[Seq[Float]] = newSeqEncoder(PrimitiveFloatEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newByteSeqEncoder: Encoder[Seq[Byte]] = newSeqEncoder(PrimitiveByteEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newShortSeqEncoder: Encoder[Seq[Short]] = newSeqEncoder(PrimitiveShortEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newBooleanSeqEncoder: Encoder[Seq[Boolean]] = newSeqEncoder(PrimitiveBooleanEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newStringSeqEncoder: Encoder[Seq[String]] = newSeqEncoder(StringEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newProductSeqEncoder[A <: Product: TypeTag]: Encoder[Seq[A]] = - newSeqEncoder(ScalaReflection.encoderFor[A]) - - /** @since 3.4.0 */ - implicit def newSequenceEncoder[T <: Seq[_]: TypeTag]: Encoder[T] = - ScalaReflection.encoderFor[T] - - // Maps - /** @since 3.4.0 */ - implicit def newMapEncoder[T <: Map[_, _]: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T] - - /** - * Notice that we serialize `Set` to Catalyst array. The set property is only kept when - * manipulating the domain objects. The serialization format doesn't keep the set property. 
When - * we have a Catalyst array which contains duplicated elements and convert it to - * `Dataset[Set[T]]` by using the encoder, the elements will be de-duplicated. - * - * @since 3.4.0 - */ - implicit def newSetEncoder[T <: Set[_]: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T] - - // Arrays - private def newArrayEncoder[E]( - elementEncoder: AgnosticEncoder[E]): AgnosticEncoder[Array[E]] = { - ArrayEncoder(elementEncoder, elementEncoder.nullable) - } - - /** @since 3.4.0 */ - implicit val newIntArrayEncoder: Encoder[Array[Int]] = newArrayEncoder(PrimitiveIntEncoder) - - /** @since 3.4.0 */ - implicit val newLongArrayEncoder: Encoder[Array[Long]] = newArrayEncoder(PrimitiveLongEncoder) - - /** @since 3.4.0 */ - implicit val newDoubleArrayEncoder: Encoder[Array[Double]] = - newArrayEncoder(PrimitiveDoubleEncoder) - - /** @since 3.4.0 */ - implicit val newFloatArrayEncoder: Encoder[Array[Float]] = newArrayEncoder( - PrimitiveFloatEncoder) - - /** @since 3.4.0 */ - implicit val newByteArrayEncoder: Encoder[Array[Byte]] = BinaryEncoder - - /** @since 3.4.0 */ - implicit val newShortArrayEncoder: Encoder[Array[Short]] = newArrayEncoder( - PrimitiveShortEncoder) - - /** @since 3.4.0 */ - implicit val newBooleanArrayEncoder: Encoder[Array[Boolean]] = - newArrayEncoder(PrimitiveBooleanEncoder) - - /** @since 3.4.0 */ - implicit val newStringArrayEncoder: Encoder[Array[String]] = newArrayEncoder(StringEncoder) - - /** @since 3.4.0 */ - implicit def newProductArrayEncoder[A <: Product: TypeTag]: Encoder[Array[A]] = { - newArrayEncoder(ScalaReflection.encoderFor[A]) - } - - /** - * Creates a [[Dataset]] from a local Seq. - * @since 3.4.0 - */ - implicit def localSeqToDatasetHolder[T: Encoder](s: Seq[T]): DatasetHolder[T] = { - DatasetHolder(session.createDataset(s)) - } -} - -/** - * Lower priority implicit methods for converting Scala objects into [[Dataset]]s. Conflicting - * implicits are placed here to disambiguate resolution. 
- * - * Reasons for including specific implicits: newProductEncoder - to disambiguate for `List`s which - * are both `Seq` and `Product` - */ -trait LowPrioritySQLImplicits { - - /** @since 3.4.0 */ - implicit def newProductEncoder[T <: Product: TypeTag]: Encoder[T] = - ScalaReflection.encoderFor[T] +/** @inheritdoc */ +abstract class SQLImplicits private[sql] (override val session: SparkSession) + extends api.SQLImplicits { + type DS[U] = Dataset[U] } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index aa6258a14b811..c0590fbd1728f 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.net.URI import java.nio.file.{Files, Paths} import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.{AtomicLong, AtomicReference} +import java.util.concurrent.atomic.AtomicLong import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag @@ -29,14 +29,17 @@ import com.google.common.cache.{CacheBuilder, CacheLoader} import io.grpc.ClientInterceptor import org.apache.arrow.memory.RootAllocator +import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.connect.proto import org.apache.spark.connect.proto.ExecutePlanResponse import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BoxedLongEncoder, UnboundRowEncoder} import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult} import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.connect.client.arrow.ArrowSerializer @@ -69,7 +72,7 @@ import org.apache.spark.util.ArrayImplicits._ class SparkSession private[sql] ( private[sql] val client: SparkConnectClient, private val planIdGenerator: AtomicLong) - extends api.SparkSession[Dataset] + extends api.SparkSession with Logging { private[this] val allocator = new RootAllocator() @@ -84,10 +87,14 @@ class SparkSession private[sql] ( private[sql] val observationRegistry = new ConcurrentHashMap[Long, Observation]() - private[sql] def hijackServerSideSessionIdForTesting(suffix: String) = { + private[sql] def hijackServerSideSessionIdForTesting(suffix: String): Unit = { client.hijackServerSideSessionIdForTesting(suffix) } + /** @inheritdoc */ + override def sparkContext: SparkContext = + throw new UnsupportedOperationException("sparkContext is not supported in Spark Connect.") + /** @inheritdoc */ val conf: RuntimeConfig = new ConnectRuntimeConfig(client) @@ -136,7 +143,7 @@ class SparkSession private[sql] ( /** @inheritdoc */ def createDataset[T: Encoder](data: Seq[T]): Dataset[T] = { - createDataset(encoderFor[T], data.iterator) + createDataset(agnosticEncoderFor[T], data.iterator) } /** @inheritdoc */ @@ -144,6 +151,30 @@ class SparkSession 
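// Sketch of caller-facing usage, which is intended to be unchanged by this consolidation:
// the encoder implicits are now expected to come from the shared api.SQLImplicits base class.
// Assumes a Spark Connect server reachable via the standard "sc://" connection string.
val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()
import spark.implicits._

val ds = Seq(1, 2, 3).toDS()                          // relies on the inherited Int encoder
val kv = Seq(("a", 1), ("b", 2)).toDF("key", "value") // product encoder from the base class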
private[sql] ( createDataset(data.asScala.toSeq) } + /** @inheritdoc */ + override def createDataFrame[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = + throwRddNotSupportedException() + + /** @inheritdoc */ + override def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = + throwRddNotSupportedException() + + /** @inheritdoc */ + override def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = + throwRddNotSupportedException() + + /** @inheritdoc */ + override def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = + throwRddNotSupportedException() + + /** @inheritdoc */ + override def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = + throwRddNotSupportedException() + + /** @inheritdoc */ + override def createDataset[T: Encoder](data: RDD[T]): Dataset[T] = + throwRddNotSupportedException() + /** @inheritdoc */ @Experimental def sql(sqlText: String, args: Array[_]): DataFrame = newDataFrame { builder => @@ -209,17 +240,10 @@ class SparkSession private[sql] ( /** @inheritdoc */ def read: DataFrameReader = new DataFrameReader(this) - /** - * Returns a `DataStreamReader` that can be used to read streaming data in as a `DataFrame`. - * {{{ - * sparkSession.readStream.parquet("/path/to/directory/of/parquet/files") - * sparkSession.readStream.schema(schema).json("/path/to/directory/of/json/files") - * }}} - * - * @since 3.5.0 - */ + /** @inheritdoc */ def readStream: DataStreamReader = new DataStreamReader(this) + /** @inheritdoc */ lazy val streams: StreamingQueryManager = new StreamingQueryManager(this) /** @inheritdoc */ @@ -252,19 +276,8 @@ class SparkSession private[sql] ( lazy val udf: UDFRegistration = new UDFRegistration(this) // scalastyle:off - // Disable style checker so "implicits" object can start with lowercase i - /** - * (Scala-specific) Implicit methods available in Scala for converting common names and Symbols - * into [[Column]]s, and for converting common Scala objects into DataFrame`s. - * - * {{{ - * val sparkSession = SparkSession.builder.getOrCreate() - * import sparkSession.implicits._ - * }}} - * - * @since 3.4.0 - */ - object implicits extends SQLImplicits(this) with Serializable + /** @inheritdoc */ + object implicits extends SQLImplicits(this) // scalastyle:on /** @inheritdoc */ @@ -512,6 +525,8 @@ class SparkSession private[sql] ( } } + override private[sql] def isUsable: Boolean = client.isSessionValid + implicit class RichColumn(c: Column) { def expr: proto.Expression = toExpr(c) def typedExpr[T](e: Encoder[T]): proto.Expression = toTypedExpr(c, e) @@ -520,7 +535,9 @@ class SparkSession private[sql] ( // The minimal builder needed to create a spark session. // TODO: implements all methods mentioned in the scaladoc of [[SparkSession]] -object SparkSession extends Logging { +object SparkSession extends api.BaseSparkSessionCompanion with Logging { + override private[sql] type Session = SparkSession + private val MAX_CACHED_SESSIONS = 100 private val planIdGenerator = new AtomicLong private var server: Option[Process] = None @@ -536,29 +553,6 @@ object SparkSession extends Logging { override def load(c: Configuration): SparkSession = create(c) }) - /** The active SparkSession for the current thread. */ - private val activeThreadSession = new InheritableThreadLocal[SparkSession] - - /** Reference to the root SparkSession. 
*/ - private val defaultSession = new AtomicReference[SparkSession] - - /** - * Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when - * they are not set yet or the associated [[SparkConnectClient]] is unusable. - */ - private def setDefaultAndActiveSession(session: SparkSession): Unit = { - val currentDefault = defaultSession.getAcquire - if (currentDefault == null || !currentDefault.client.isSessionValid) { - // Update `defaultSession` if it is null or the contained session is not valid. There is a - // chance that the following `compareAndSet` fails if a new default session has just been set, - // but that does not matter since that event has happened after this method was invoked. - defaultSession.compareAndSet(currentDefault, session) - } - if (getActiveSession.isEmpty) { - setActiveSession(session) - } - } - /** * Create a new Spark Connect server to connect locally. */ @@ -611,17 +605,6 @@ object SparkSession extends Logging { new SparkSession(configuration.toSparkConnectClient, planIdGenerator) } - /** - * Hook called when a session is closed. - */ - private[sql] def onSessionClose(session: SparkSession): Unit = { - sessions.invalidate(session.client.configuration) - defaultSession.compareAndSet(session, null) - if (getActiveSession.contains(session)) { - clearActiveSession() - } - } - /** * Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]]. * @@ -629,15 +612,15 @@ object SparkSession extends Logging { */ def builder(): Builder = new Builder() - class Builder() extends Logging { + class Builder() extends api.SparkSessionBuilder { // Initialize the connection string of the Spark Connect client builder from SPARK_REMOTE // by default, if it exists. The connection string can be overridden using // the remote() function, as it takes precedence over the SPARK_REMOTE environment variable. private val builder = SparkConnectClient.builder().loadFromEnvironment() private var client: SparkConnectClient = _ - private[this] val options = new scala.collection.mutable.HashMap[String, String] - def remote(connectionString: String): Builder = { + /** @inheritdoc */ + def remote(connectionString: String): this.type = { builder.connectionString(connectionString) this } @@ -649,93 +632,45 @@ object SparkSession extends Logging { * * @since 3.5.0 */ - def interceptor(interceptor: ClientInterceptor): Builder = { + def interceptor(interceptor: ClientInterceptor): this.type = { builder.interceptor(interceptor) this } - private[sql] def client(client: SparkConnectClient): Builder = { + private[sql] def client(client: SparkConnectClient): this.type = { this.client = client this } - /** - * Sets a config option. Options set using this method are automatically propagated to the - * Spark Connect session. Only runtime options are supported. - * - * @since 3.5.0 - */ - def config(key: String, value: String): Builder = synchronized { - options += key -> value - this - } + /** @inheritdoc */ + override def config(key: String, value: String): this.type = super.config(key, value) - /** - * Sets a config option. Options set using this method are automatically propagated to the - * Spark Connect session. Only runtime options are supported. - * - * @since 3.5.0 - */ - def config(key: String, value: Long): Builder = synchronized { - options += key -> value.toString - this - } + /** @inheritdoc */ + override def config(key: String, value: Long): this.type = super.config(key, value) - /** - * Sets a config option. 
Options set using this method are automatically propagated to the - * Spark Connect session. Only runtime options are supported. - * - * @since 3.5.0 - */ - def config(key: String, value: Double): Builder = synchronized { - options += key -> value.toString - this - } + /** @inheritdoc */ + override def config(key: String, value: Double): this.type = super.config(key, value) - /** - * Sets a config option. Options set using this method are automatically propagated to the - * Spark Connect session. Only runtime options are supported. - * - * @since 3.5.0 - */ - def config(key: String, value: Boolean): Builder = synchronized { - options += key -> value.toString - this - } + /** @inheritdoc */ + override def config(key: String, value: Boolean): this.type = super.config(key, value) - /** - * Sets a config a map of options. Options set using this method are automatically propagated - * to the Spark Connect session. Only runtime options are supported. - * - * @since 3.5.0 - */ - def config(map: Map[String, Any]): Builder = synchronized { - map.foreach { kv: (String, Any) => - { - options += kv._1 -> kv._2.toString - } - } - this - } + /** @inheritdoc */ + override def config(map: Map[String, Any]): this.type = super.config(map) - /** - * Sets a config option. Options set using this method are automatically propagated to both - * `SparkConf` and SparkSession's own configuration. - * - * @since 3.5.0 - */ - def config(map: java.util.Map[String, Any]): Builder = synchronized { - config(map.asScala.toMap) - } + /** @inheritdoc */ + override def config(map: java.util.Map[String, Any]): this.type = super.config(map) + /** @inheritdoc */ @deprecated("enableHiveSupport does not work in Spark Connect") - def enableHiveSupport(): Builder = this + override def enableHiveSupport(): this.type = this + /** @inheritdoc */ @deprecated("master does not work in Spark Connect, please use remote instead") - def master(master: String): Builder = this + override def master(master: String): this.type = this + /** @inheritdoc */ @deprecated("appName does not work in Spark Connect") - def appName(name: String): Builder = this + override def appName(name: String): this.type = this private def tryCreateSessionFromClient(): Option[SparkSession] = { if (client != null && client.isSessionValid) { @@ -816,71 +751,12 @@ object SparkSession extends Logging { } } - /** - * Returns the default SparkSession. If the previously set default SparkSession becomes - * unusable, returns None. - * - * @since 3.5.0 - */ - def getDefaultSession: Option[SparkSession] = - Option(defaultSession.get()).filter(_.client.isSessionValid) - - /** - * Sets the default SparkSession. - * - * @since 3.5.0 - */ - def setDefaultSession(session: SparkSession): Unit = { - defaultSession.set(session) - } - - /** - * Clears the default SparkSession. - * - * @since 3.5.0 - */ - def clearDefaultSession(): Unit = { - defaultSession.set(null) - } - - /** - * Returns the active SparkSession for the current thread. If the previously set active - * SparkSession becomes unusable, returns None. - * - * @since 3.5.0 - */ - def getActiveSession: Option[SparkSession] = - Option(activeThreadSession.get()).filter(_.client.isSessionValid) - - /** - * Changes the SparkSession that will be returned in this thread and its children when - * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives - * an isolated SparkSession. 
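Usage sketch for the refactored builder: the `config` overloads are now inherited from `api.SparkSessionBuilder` and return `this.type`, while default/active session tracking moves to `BaseSparkSessionCompanion`. The endpoint and config key below are illustrative, not taken from the patch.

```scala
import org.apache.spark.sql.SparkSession

// "sc://localhost:15002" is an assumed local Connect endpoint; SPARK_REMOTE is
// used when remote() is omitted.
val spark = SparkSession
  .builder()
  .remote("sc://localhost:15002")
  .config("spark.sql.shuffle.partitions", 10L) // forwarded to the server as a runtime conf
  .getOrCreate()

// Session bookkeeping is now provided by the shared companion base class.
val active: Option[SparkSession] = SparkSession.getActiveSession
val current: SparkSession = SparkSession.active

spark.close()
```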
- * - * @since 3.5.0 - */ - def setActiveSession(session: SparkSession): Unit = { - activeThreadSession.set(session) - } + /** @inheritdoc */ + override def getActiveSession: Option[SparkSession] = super.getActiveSession - /** - * Clears the active SparkSession for current thread. - * - * @since 3.5.0 - */ - def clearActiveSession(): Unit = { - activeThreadSession.remove() - } + /** @inheritdoc */ + override def getDefaultSession: Option[SparkSession] = super.getDefaultSession - /** - * Returns the currently active SparkSession, otherwise the default one. If there is no default - * SparkSession, throws an exception. - * - * @since 3.5.0 - */ - def active: SparkSession = { - getActiveSession - .orElse(getDefaultSession) - .getOrElse(throw new IllegalStateException("No active or default Spark session found")) - } + /** @inheritdoc */ + override def active: SparkSession = super.active } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala index 63fa2821a6c6a..bff6db25a21f2 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala @@ -50,8 +50,9 @@ object ConnectRepl { /_/ Type in expressions to have them evaluated. +Spark connect server version %s. Spark session available as 'spark'. - """.format(spark_version) + """ def main(args: Array[String]): Unit = doMain(args) @@ -102,7 +103,7 @@ Spark session available as 'spark'. // Please note that we make ammonite generate classes instead of objects. // Classes tend to have superior serialization behavior when using UDFs. val main = new ammonite.Main( - welcomeBanner = Option(splash), + welcomeBanner = Option(splash.format(spark_version, spark.version)), predefCode = predefCode, replCodeWrapper = ExtendedCodeClassWrapper, scriptCodeWrapper = ExtendedCodeClassWrapper, diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 11a4a044d20e5..86b1dbe4754e6 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.catalog import java.util import org.apache.spark.sql.{api, DataFrame, Dataset} +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.types.StructType /** @inheritdoc */ -abstract class Catalog extends api.Catalog[Dataset] { +abstract class Catalog extends api.Catalog { /** @inheritdoc */ override def listDatabases(): Dataset[Database] diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala new file mode 100644 index 0000000000000..0344152be86e6 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect + +import scala.language.implicitConversions + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.connect.proto +import org.apache.spark.sql._ +import org.apache.spark.sql.internal.ProtoColumnNode + +/** + * Conversions from sql interfaces to the Connect specific implementation. + * + * This class is mainly used by the implementation. It is also meant to be used by extension + * developers. + * + * We provide both a trait and an object. The trait is useful in situations where an extension + * developer needs to use these conversions in a project covering multiple Spark versions. They + * can create a shim for these conversions, the Spark 4+ version of the shim implements this + * trait, and shims for older versions do not. + */ +@DeveloperApi +trait ConnectConversions { + implicit def castToImpl(session: api.SparkSession): SparkSession = + session.asInstanceOf[SparkSession] + + implicit def castToImpl[T](ds: api.Dataset[T]): Dataset[T] = + ds.asInstanceOf[Dataset[T]] + + implicit def castToImpl(rgds: api.RelationalGroupedDataset): RelationalGroupedDataset = + rgds.asInstanceOf[RelationalGroupedDataset] + + implicit def castToImpl[K, V]( + kvds: api.KeyValueGroupedDataset[K, V]): KeyValueGroupedDataset[K, V] = + kvds.asInstanceOf[KeyValueGroupedDataset[K, V]] + + /** + * Create a [[Column]] from a [[proto.Expression]] + * + * This method is meant to be used by Connect plugins. We do not guarantee any compatibility + * between (minor) versions. + */ + @DeveloperApi + def column(expr: proto.Expression): Column = { + Column(ProtoColumnNode(expr)) + } + + /** + * Create a [[Column]] using a function that manipulates an [[proto.Expression.Builder]]. + * + * This method is meant to be used by Connect plugins. We do not guarantee any compatibility + * between (minor) versions. + */ + @DeveloperApi + def column(f: proto.Expression.Builder => Unit): Column = { + val builder = proto.Expression.newBuilder() + f(builder) + column(builder.build()) + } + + /** + * Implicit helper that makes it easy to construct a Column from an Expression or an Expression + * builder. This allows developers to create a Column in the same way as in earlier versions of + * Spark (before 4.0). 
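A sketch of how an extension might use the new `ConnectConversions` helpers; the attribute name and the wrapping object are illustrative.

```scala
import org.apache.spark.sql.{api, Column, Dataset, Row}
import org.apache.spark.sql.connect.ConnectConversions._

object PluginSketch {
  // Build a Column straight from a proto expression builder (plugin-style usage).
  val idCol: Column = column { builder =>
    builder.getUnresolvedAttributeBuilder.setUnparsedIdentifier("id")
  }

  // Narrow an interface type returned by the shared API back to the Connect implementation.
  def selectId(ds: api.Dataset[Row]): Unit = {
    val impl: Dataset[Row] = ds // via the castToImpl implicit
    impl.select(idCol).show()
  }
}
```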
+ */ + @DeveloperApi + implicit class ColumnConstructorExt(val c: Column.type) { + def apply(e: proto.Expression): Column = column(e) + } +} + +object ConnectConversions extends ConnectConversions diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala index 7578e2424fb42..be1a13cb2fed2 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala @@ -38,7 +38,7 @@ class ConnectRuntimeConfig private[sql] (client: SparkConnectClient) } /** @inheritdoc */ - @throws[NoSuchElementException]("if the key is not set") + @throws[NoSuchElementException]("if the key is not set and there is no default value") def get(key: String): String = getOption(key).getOrElse { throw new NoSuchElementException(key) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/UdfToProtoUtils.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/UdfToProtoUtils.scala index 85ce2cb820437..409c43f480b8e 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/UdfToProtoUtils.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/UdfToProtoUtils.scala @@ -25,9 +25,9 @@ import com.google.protobuf.ByteString import org.apache.spark.SparkException import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor import org.apache.spark.sql.connect.common.DataTypeProtoConverter.toConnectProtoType import org.apache.spark.sql.connect.common.UdfPacket -import org.apache.spark.sql.encoderFor import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction} import org.apache.spark.util.{ClosureCleaner, SparkClassUtils, SparkSerDeUtils} @@ -79,12 +79,12 @@ private[sql] object UdfToProtoUtils { udf match { case f: SparkUserDefinedFunction => val outputEncoder = f.outputEncoder - .map(e => encoderFor(e)) + .map(e => agnosticEncoderFor(e)) .getOrElse(RowEncoder.encoderForDataType(f.dataType, lenient = false)) val inputEncoders = if (f.inputEncoders.forall(_.isEmpty)) { Nil // Java UDFs have no bindings for their inputs. } else { - f.inputEncoders.map(e => encoderFor(e.get)) // TODO support Any and UnboundRow. + f.inputEncoders.map(e => agnosticEncoderFor(e.get)) // TODO support Any and UnboundRow. 
} inputEncoders.foreach(e => protoUdf.addInputTypes(toConnectProtoType(e.dataType))) protoUdf @@ -93,8 +93,8 @@ private[sql] object UdfToProtoUtils { .setAggregate(false) f.givenName.foreach(invokeUdf.setFunctionName) case f: UserDefinedAggregator[_, _, _] => - val outputEncoder = encoderFor(f.aggregator.outputEncoder) - val inputEncoder = encoderFor(f.inputEncoder) + val outputEncoder = agnosticEncoderFor(f.aggregator.outputEncoder) + val inputEncoder = agnosticEncoderFor(f.inputEncoder) protoUdf .setPayload(toUdfPacketBytes(f.aggregator, inputEncoder :: Nil, outputEncoder)) .addInputTypes(toConnectProtoType(inputEncoder.dataType)) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala index 45fa449b58ed7..34a8a91a0ddf8 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala @@ -52,9 +52,10 @@ object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) { case Literal(value, Some(dataType), _) => builder.setLiteral(toLiteralProtoBuilder(value, dataType)) - case UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) => + case u @ UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) => + val escapedName = u.sql val b = builder.getUnresolvedAttributeBuilder - .setUnparsedIdentifier(unparsedIdentifier) + .setUnparsedIdentifier(escapedName) if (isMetadataColumn) { // We only set this field when it is needed. If we would always set it, // too many of the verbatims we use for testing would have to be regenerated. diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala index 154f2b0405fcd..5c61b9371f37c 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala @@ -17,39 +17,9 @@ package org.apache.spark -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.connect.proto -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder -import org.apache.spark.sql.internal.ProtoColumnNode - package object sql { type DataFrame = Dataset[Row] - private[sql] def encoderFor[E: Encoder]: AgnosticEncoder[E] = { - implicitly[Encoder[E]].asInstanceOf[AgnosticEncoder[E]] - } - - /** - * Create a [[Column]] from a [[proto.Expression]] - * - * This method is meant to be used by Connect plugins. We do not guarantee any compatility - * between (minor) versions. - */ - @DeveloperApi - def column(expr: proto.Expression): Column = { - Column(ProtoColumnNode(expr)) - } - - /** - * Creat a [[Column]] using a function that manipulates an [[proto.Expression.Builder]]. - * - * This method is meant to be used by Connect plugins. We do not guarantee any compatility - * between (minor) versions. 
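The RDD-based factory methods overridden earlier in this patch all route to `throwRddNotSupportedException()` (added just below), and `sparkContext` now throws as well. A small sketch of the resulting client-side behavior, assuming `spark` is an existing Connect session:

```scala
import scala.util.Try
import org.apache.spark.sql.SparkSession

def driverOnlyApis(spark: SparkSession): Unit = {
  // Expected to fail fast on the client; the RDD-based createDataFrame/createDataset
  // overloads reject their inputs the same way via throwRddNotSupportedException().
  val ctx = Try(spark.sparkContext)
  println(ctx.failed.map(_.getMessage).getOrElse("unexpectedly succeeded"))
}
```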
- */ - @DeveloperApi - def column(f: proto.Expression.Builder => Unit): Column = { - val builder = proto.Expression.newBuilder() - f(builder) - column(builder.build()) - } + private[sql] def throwRddNotSupportedException(): Nothing = + throw new UnsupportedOperationException("RDDs are not supported in Spark Connect.") } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 789425c9daea1..2ff34a6343644 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -21,11 +21,9 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Evolving import org.apache.spark.connect.proto.Read.DataSource -import org.apache.spark.internal.Logging -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Dataset -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder +import org.apache.spark.sql.{api, DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.connect.ConnectConversions._ +import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.types.StructType /** @@ -35,101 +33,49 @@ import org.apache.spark.sql.types.StructType * @since 3.5.0 */ @Evolving -final class DataStreamReader private[sql] (sparkSession: SparkSession) extends Logging { +final class DataStreamReader private[sql] (sparkSession: SparkSession) + extends api.DataStreamReader { - /** - * Specifies the input data source format. - * - * @since 3.5.0 - */ - def format(source: String): DataStreamReader = { + private val sourceBuilder = DataSource.newBuilder() + + /** @inheritdoc */ + def format(source: String): this.type = { sourceBuilder.setFormat(source) this } - /** - * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema - * automatically from data. By specifying the schema here, the underlying data source can skip - * the schema inference step, and thus speed up data loading. - * - * @since 3.5.0 - */ - def schema(schema: StructType): DataStreamReader = { + /** @inheritdoc */ + def schema(schema: StructType): this.type = { if (schema != null) { sourceBuilder.setSchema(schema.json) // Use json. DDL does not retail all the attributes. } this } - /** - * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) - * can infer the input schema automatically from data. By specifying the schema here, the - * underlying data source can skip the schema inference step, and thus speed up data loading. - * - * @since 3.5.0 - */ - def schema(schemaString: String): DataStreamReader = { + /** @inheritdoc */ + override def schema(schemaString: String): this.type = { sourceBuilder.setSchema(schemaString) this } - /** - * Adds an input option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: String): DataStreamReader = { + /** @inheritdoc */ + def option(key: String, value: String): this.type = { sourceBuilder.putOptions(key, value) this } - /** - * Adds an input option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: Boolean): DataStreamReader = option(key, value.toString) - - /** - * Adds an input option for the underlying data source. 
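The reader now implements the shared `api.DataStreamReader`, with `format`/`schema`/`option` returning `this.type` and the per-format methods provided as covariant overrides. A usage sketch with an illustrative path and schema:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

def readJsonStream(spark: SparkSession): DataFrame = {
  val schema = StructType(Seq(StructField("name", StringType), StructField("value", StringType)))
  spark.readStream
    .schema(schema)                    // chaining is unchanged despite the this.type returns
    .option("maxFilesPerTrigger", 1L)  // Long overload comes from the covariant overrides
    .json("/tmp/stream-input")         // json/csv/xml/orc/parquet/text are inherited too
}
```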
- * - * @since 3.5.0 - */ - def option(key: String, value: Long): DataStreamReader = option(key, value.toString) - - /** - * Adds an input option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: Double): DataStreamReader = option(key, value.toString) - - /** - * (Scala-specific) Adds input options for the underlying data source. - * - * @since 3.5.0 - */ - def options(options: scala.collection.Map[String, String]): DataStreamReader = { + /** @inheritdoc */ + def options(options: scala.collection.Map[String, String]): this.type = { this.options(options.asJava) - this } - /** - * (Java-specific) Adds input options for the underlying data source. - * - * @since 3.5.0 - */ - def options(options: java.util.Map[String, String]): DataStreamReader = { + /** @inheritdoc */ + override def options(options: java.util.Map[String, String]): this.type = { sourceBuilder.putAllOptions(options) this } - /** - * Loads input data stream in as a `DataFrame`, for data streams that don't require a path (e.g. - * external key-value stores). - * - * @since 3.5.0 - */ + /** @inheritdoc */ def load(): DataFrame = { sparkSession.newDataFrame { relationBuilder => relationBuilder.getReadBuilder @@ -138,120 +84,14 @@ final class DataStreamReader private[sql] (sparkSession: SparkSession) extends L } } - /** - * Loads input in as a `DataFrame`, for data streams that read from some path. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def load(path: String): DataFrame = { sourceBuilder.clearPaths() sourceBuilder.addPaths(path) load() } - /** - * Loads a JSON file stream and returns the results as a `DataFrame`. - * - * JSON Lines (newline-delimited JSON) is supported by - * default. For JSON (one record per file), set the `multiLine` option to true. - * - * This function goes through the input once to determine the input schema. If you know the - * schema in advance, use the version that specifies the schema to avoid the extra scan. - * - * You can set the following option(s):
  • `maxFilesPerTrigger` (default: no max limit): - * sets the maximum number of new files to be considered in every trigger.
- *
  • `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to - * be considered in every trigger.
- * - * You can find the JSON-specific options for reading JSON file stream in - * Data Source Option in the version you use. - * - * @since 3.5.0 - */ - def json(path: String): DataFrame = { - format("json").load(path) - } - - /** - * Loads a CSV file stream and returns the result as a `DataFrame`. - * - * This function will go through the input once to determine the input schema if `inferSchema` - * is enabled. To avoid going through the entire data once, disable `inferSchema` option or - * specify the schema explicitly using `schema`. - * - * You can set the following option(s):
  • `maxFilesPerTrigger` (default: no max limit): - * sets the maximum number of new files to be considered in every trigger.
- *
  • `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to - * be considered in every trigger.
- * - * You can find the CSV-specific options for reading CSV file stream in - * Data Source Option in the version you use. - * - * @since 3.5.0 - */ - def csv(path: String): DataFrame = format("csv").load(path) - - /** - * Loads a XML file stream and returns the result as a `DataFrame`. - * - * This function will go through the input once to determine the input schema if `inferSchema` - * is enabled. To avoid going through the entire data once, disable `inferSchema` option or - * specify the schema explicitly using `schema`. - * - * You can set the following option(s):
  • `maxFilesPerTrigger` (default: no max limit): - * sets the maximum number of new files to be considered in every trigger.
- *
  • `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to - * be considered in every trigger.
- * - * You can find the XML-specific options for reading XML file stream in - * Data Source Option in the version you use. - * - * @since 4.0.0 - */ - def xml(path: String): DataFrame = format("xml").load(path) - - /** - * Loads a ORC file stream, returning the result as a `DataFrame`. - * - * You can set the following option(s):
  • `maxFilesPerTrigger` (default: no max limit): - * sets the maximum number of new files to be considered in every trigger.
- *
  • `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to - * be considered in every trigger.
- * - * ORC-specific option(s) for reading ORC file stream can be found in Data - * Source Option in the version you use. - * - * @since 3.5.0 - */ - def orc(path: String): DataFrame = format("orc").load(path) - - /** - * Loads a Parquet file stream, returning the result as a `DataFrame`. - * - * You can set the following option(s):
  • `maxFilesPerTrigger` (default: no max limit): - * sets the maximum number of new files to be considered in every trigger.
- *
  • `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to - * be considered in every trigger.
- * - * Parquet-specific option(s) for reading Parquet file stream can be found in Data - * Source Option in the version you use. - * - * @since 3.5.0 - */ - def parquet(path: String): DataFrame = format("parquet").load(path) - - /** - * Define a Streaming DataFrame on a Table. The DataSource corresponding to the table should - * support streaming mode. - * @param tableName - * The name of the table - * @since 3.5.0 - */ + /** @inheritdoc */ def table(tableName: String): DataFrame = { require(tableName != null, "The table name can't be null") sparkSession.newDataFrame { builder => @@ -263,59 +103,44 @@ final class DataStreamReader private[sql] (sparkSession: SparkSession) extends L } } - /** - * Loads text files and returns a `DataFrame` whose schema starts with a string column named - * "value", and followed by partitioned columns if there are any. The text files must be encoded - * as UTF-8. - * - * By default, each line in the text files is a new row in the resulting DataFrame. For example: - * {{{ - * // Scala: - * spark.readStream.text("/path/to/directory/") - * - * // Java: - * spark.readStream().text("/path/to/directory/") - * }}} - * - * You can set the following option(s):
  • `maxFilesPerTrigger` (default: no max limit): - * sets the maximum number of new files to be considered in every trigger.
- *
  • `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to - * be considered in every trigger.
- * - * You can find the text-specific options for reading text files in - * Data Source Option in the version you use. - * - * @since 3.5.0 - */ - def text(path: String): DataFrame = format("text").load(path) - - /** - * Loads text file(s) and returns a `Dataset` of String. The underlying schema of the Dataset - * contains a single string column named "value". The text files must be encoded as UTF-8. - * - * If the directory structure of the text files contains partitioning information, those are - * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. - * - * By default, each line in the text file is a new element in the resulting Dataset. For - * example: - * {{{ - * // Scala: - * spark.readStream.textFile("/path/to/spark/README.md") - * - * // Java: - * spark.readStream().textFile("/path/to/spark/README.md") - * }}} - * - * You can set the text-specific options as specified in `DataStreamReader.text`. - * - * @param path - * input path - * @since 3.5.0 - */ - def textFile(path: String): Dataset[String] = { - text(path).select("value").as[String](StringEncoder) + override protected def assertNoSpecifiedSchema(operation: String): Unit = { + if (sourceBuilder.hasSchema) { + throw DataTypeErrors.userSpecifiedSchemaUnsupportedError(operation) + } } - private val sourceBuilder = DataSource.newBuilder() + /////////////////////////////////////////////////////////////////////////////////////// + // Covariant overrides. + /////////////////////////////////////////////////////////////////////////////////////// + + /** @inheritdoc */ + override def option(key: String, value: Boolean): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Long): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Double): this.type = super.option(key, value) + + /** @inheritdoc */ + override def json(path: String): DataFrame = super.json(path) + + /** @inheritdoc */ + override def csv(path: String): DataFrame = super.csv(path) + + /** @inheritdoc */ + override def xml(path: String): DataFrame = super.xml(path) + + /** @inheritdoc */ + override def orc(path: String): DataFrame = super.orc(path) + + /** @inheritdoc */ + override def parquet(path: String): DataFrame = super.parquet(path) + + /** @inheritdoc */ + override def text(path: String): DataFrame = super.text(path) + + /** @inheritdoc */ + override def textFile(path: String): Dataset[String] = super.textFile(path) + } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index c8c714047788b..9fcc31e562682 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,9 +29,8 @@ import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Command import org.apache.spark.connect.proto.WriteStreamOperationStart -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Dataset, ForeachWriter} -import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, UdfUtils} +import org.apache.spark.sql.{api, Dataset, ForeachWriter} +import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
ForeachWriterPacket} import org.apache.spark.sql.execution.streaming.AvailableNowTrigger import org.apache.spark.sql.execution.streaming.ContinuousTrigger import org.apache.spark.sql.execution.streaming.OneTimeTrigger @@ -47,63 +46,23 @@ import org.apache.spark.util.SparkSerDeUtils * @since 3.5.0 */ @Evolving -final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { +final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends api.DataStreamWriter[T] { + override type DS[U] = Dataset[U] - /** - * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.
  • - * `OutputMode.Append()`: only the new rows in the streaming DataFrame/Dataset will be written - * to the sink.
  • `OutputMode.Complete()`: all the rows in the streaming - * DataFrame/Dataset will be written to the sink every time there are some updates.
  • - * `OutputMode.Update()`: only the rows that were updated in the streaming DataFrame/Dataset - * will be written to the sink every time there are some updates. If the query doesn't contain - * aggregations, it will be equivalent to `OutputMode.Append()` mode.
- * - * @since 3.5.0 - */ - def outputMode(outputMode: OutputMode): DataStreamWriter[T] = { + /** @inheritdoc */ + def outputMode(outputMode: OutputMode): this.type = { sinkBuilder.setOutputMode(outputMode.toString.toLowerCase(Locale.ROOT)) this } - /** - * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.
  • - * `append`: only the new rows in the streaming DataFrame/Dataset will be written to the - * sink.
  • `complete`: all the rows in the streaming DataFrame/Dataset will be written - * to the sink every time there are some updates.
  • `update`: only the rows that were - * updated in the streaming DataFrame/Dataset will be written to the sink every time there are - * some updates. If the query doesn't contain aggregations, it will be equivalent to `append` - * mode.
- * - * @since 3.5.0 - */ - def outputMode(outputMode: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def outputMode(outputMode: String): this.type = { sinkBuilder.setOutputMode(outputMode) this } - /** - * Set the trigger for the stream query. The default value is `ProcessingTime(0)` and it will - * run the query as fast as possible. - * - * Scala Example: - * {{{ - * df.writeStream.trigger(ProcessingTime("10 seconds")) - * - * import scala.concurrent.duration._ - * df.writeStream.trigger(ProcessingTime(10.seconds)) - * }}} - * - * Java Example: - * {{{ - * df.writeStream().trigger(ProcessingTime.create("10 seconds")) - * - * import java.util.concurrent.TimeUnit - * df.writeStream().trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) - * }}} - * - * @since 3.5.0 - */ - def trigger(trigger: Trigger): DataStreamWriter[T] = { + /** @inheritdoc */ + def trigger(trigger: Trigger): this.type = { trigger match { case ProcessingTimeTrigger(intervalMs) => sinkBuilder.setProcessingTimeInterval(s"$intervalMs milliseconds") @@ -117,123 +76,54 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { this } - /** - * Specifies the name of the [[StreamingQuery]] that can be started with `start()`. This name - * must be unique among all the currently active queries in the associated SQLContext. - * - * @since 3.5.0 - */ - def queryName(queryName: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def queryName(queryName: String): this.type = { sinkBuilder.setQueryName(queryName) this } - /** - * Specifies the underlying output data source. - * - * @since 3.5.0 - */ - def format(source: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def format(source: String): this.type = { sinkBuilder.setFormat(source) this } - /** - * Partitions the output by the given columns on the file system. If specified, the output is - * laid out on the file system similar to Hive's partitioning scheme. As an example, when we - * partition a dataset by year and then month, the directory layout would look like: - * - *
  • year=2016/month=01/
  • year=2016/month=02/
- * - * Partitioning is one of the most widely used techniques to optimize physical data layout. It - * provides a coarse-grained index for skipping unnecessary data reads when queries have - * predicates on the partitioned columns. In order for partitioning to work well, the number of - * distinct values in each column should typically be less than tens of thousands. - * - * @since 3.5.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def partitionBy(colNames: String*): DataStreamWriter[T] = { + def partitionBy(colNames: String*): this.type = { sinkBuilder.clearPartitioningColumnNames() sinkBuilder.addAllPartitioningColumnNames(colNames.asJava) this } - /** - * Clusters the output by the given columns. If specified, the output is laid out such that - * records with similar values on the clustering column are grouped together in the same file. - * - * Clustering improves query efficiency by allowing queries with predicates on the clustering - * columns to skip unnecessary data. Unlike partitioning, clustering can be used on very high - * cardinality columns. - * - * @since 4.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def clusterBy(colNames: String*): DataStreamWriter[T] = { + def clusterBy(colNames: String*): this.type = { sinkBuilder.clearClusteringColumnNames() sinkBuilder.addAllClusteringColumnNames(colNames.asJava) this } - /** - * Adds an output option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def option(key: String, value: String): this.type = { sinkBuilder.putOptions(key, value) this } - /** - * Adds an output option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: Boolean): DataStreamWriter[T] = option(key, value.toString) - - /** - * Adds an output option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: Long): DataStreamWriter[T] = option(key, value.toString) - - /** - * Adds an output option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: Double): DataStreamWriter[T] = option(key, value.toString) - - /** - * (Scala-specific) Adds output options for the underlying data source. - * - * @since 3.5.0 - */ - def options(options: scala.collection.Map[String, String]): DataStreamWriter[T] = { + /** @inheritdoc */ + def options(options: scala.collection.Map[String, String]): this.type = { this.options(options.asJava) this } - /** - * Adds output options for the underlying data source. - * - * @since 3.5.0 - */ - def options(options: java.util.Map[String, String]): DataStreamWriter[T] = { + /** @inheritdoc */ + def options(options: java.util.Map[String, String]): this.type = { sinkBuilder.putAllOptions(options) this } - /** - * Sets the output of the streaming query to be processed using the provided writer object. - * object. See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and - * semantics. 
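The foreach sink documentation above is folded into `@inheritdoc`; a sketch of wiring a `ForeachWriter` through the Connect writer (the writer body and the rate source are illustrative):

```scala
import org.apache.spark.sql.{ForeachWriter, Row, SparkSession}
import org.apache.spark.sql.streaming.StreamingQuery

def startForeachSink(spark: SparkSession): StreamingQuery = {
  val writer = new ForeachWriter[Row] {
    def open(partitionId: Long, epochId: Long): Boolean = true
    def process(value: Row): Unit = println(value) // illustrative side effect
    def close(errorOrNull: Throwable): Unit = ()
  }
  spark.readStream
    .format("rate") // built-in test source
    .load()
    .writeStream
    .outputMode("append")
    .foreach(writer) // serialized into a ForeachWriterPacket and sent to the server
    .start()
}
```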
- * @since 3.5.0 - */ - def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { + /** @inheritdoc */ + def foreach(writer: ForeachWriter[T]): this.type = { val serialized = SparkSerDeUtils.serialize(ForeachWriterPacket(writer, ds.agnosticEncoder)) val scalaWriterBuilder = proto.ScalarScalaUDF .newBuilder() @@ -242,21 +132,9 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { this } - /** - * :: Experimental :: - * - * (Scala-specific) Sets the output of the streaming query to be processed using the provided - * function. This is supported only in the micro-batch execution modes (that is, when the - * trigger is not continuous). In every micro-batch, the provided function will be called in - * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. The - * batchId can be used to deduplicate and transactionally write the output (that is, the - * provided Dataset) to external systems. The output Dataset is guaranteed to be exactly the - * same for the same batchId (assuming all operations are deterministic in the query). - * - * @since 3.5.0 - */ + /** @inheritdoc */ @Evolving - def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = { + def foreachBatch(function: (Dataset[T], Long) => Unit): this.type = { val serializedFn = SparkSerDeUtils.serialize(function) sinkBuilder.getForeachBatchBuilder.getScalaFunctionBuilder .setPayload(ByteString.copyFrom(serializedFn)) @@ -265,48 +143,13 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { this } - /** - * :: Experimental :: - * - * (Java-specific) Sets the output of the streaming query to be processed using the provided - * function. This is supported only in the micro-batch execution modes (that is, when the - * trigger is not continuous). In every micro-batch, the provided function will be called in - * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. The - * batchId can be used to deduplicate and transactionally write the output (that is, the - * provided Dataset) to external systems. The output Dataset is guaranteed to be exactly the - * same for the same batchId (assuming all operations are deterministic in the query). - * - * @since 3.5.0 - */ - @Evolving - def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): DataStreamWriter[T] = { - foreachBatch(UdfUtils.foreachBatchFuncToScalaFunc(function)) - } - - /** - * Starts the execution of the streaming query, which will continually output results to the - * given path as new data arrives. The returned [[StreamingQuery]] object can be used to - * interact with the stream. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def start(path: String): StreamingQuery = { sinkBuilder.setPath(path) start() } - /** - * Starts the execution of the streaming query, which will continually output results to the - * given path as new data arrives. The returned [[StreamingQuery]] object can be used to - * interact with the stream. 
Throws a `TimeoutException` if the following conditions are met: - * - Another run of the same streaming query, that is a streaming query sharing the same - * checkpoint location, is already active on the same Spark Driver - * - The SQL configuration `spark.sql.streaming.stopActiveRunOnRestart` is enabled - * - The active run cannot be stopped within the timeout controlled by the SQL configuration - * `spark.sql.streaming.stopTimeout` - * - * @since 3.5.0 - */ + /** @inheritdoc */ @throws[TimeoutException] def start(): StreamingQuery = { val startCmd = Command @@ -323,22 +166,7 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { RemoteStreamingQuery.fromStartCommandResponse(ds.sparkSession, resp) } - /** - * Starts the execution of the streaming query, which will continually output results to the - * given table as new data arrives. The returned [[StreamingQuery]] object can be used to - * interact with the stream. - * - * For v1 table, partitioning columns provided by `partitionBy` will be respected no matter the - * table exists or not. A new table will be created if the table not exists. - * - * For v2 table, `partitionBy` will be ignored if the table already exists. `partitionBy` will - * be respected only if the v2 table does not exist. Besides, the v2 table created by this API - * lacks some functionalities (e.g., customized properties, options, and serde info). If you - * need them, please create the v2 table manually before the execution to avoid creating a table - * with incomplete information. - * - * @since 3.5.0 - */ + /** @inheritdoc */ @Evolving @throws[TimeoutException] def toTable(tableName: String): StreamingQuery = { @@ -346,6 +174,24 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { start() } + /////////////////////////////////////////////////////////////////////////////////////// + // Covariant Overrides + /////////////////////////////////////////////////////////////////////////////////////// + + /** @inheritdoc */ + override def option(key: String, value: Boolean): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Long): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Double): this.type = super.option(key, value) + + /** @inheritdoc */ + @Evolving + override def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): this.type = + super.foreachBatch(function) + private val sinkBuilder = WriteStreamOperationStart .newBuilder() .setInput(ds.plan.getRoot) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index 3b47269875f4a..29fbcc443deb9 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -26,10 +26,10 @@ import org.apache.spark.connect.proto.ExecutePlanResponse import org.apache.spark.connect.proto.StreamingQueryCommand import org.apache.spark.connect.proto.StreamingQueryCommandResult import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance -import org.apache.spark.sql.{api, Dataset, SparkSession} +import org.apache.spark.sql.{api, SparkSession} /** @inheritdoc */ -trait StreamingQuery extends api.StreamingQuery[Dataset] { +trait 
StreamingQuery extends api.StreamingQuery { /** @inheritdoc */ override def sparkSession: SparkSession diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 7efced227d6d1..647d29c714dbb 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -27,7 +27,7 @@ import org.apache.spark.connect.proto.Command import org.apache.spark.connect.proto.StreamingQueryManagerCommand import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{api, SparkSession} import org.apache.spark.sql.connect.common.InvalidPlanInput /** @@ -36,7 +36,9 @@ import org.apache.spark.sql.connect.common.InvalidPlanInput * @since 3.5.0 */ @Evolving -class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging { +class StreamingQueryManager private[sql] (sparkSession: SparkSession) + extends api.StreamingQueryManager + with Logging { // Mapping from id to StreamingQueryListener. There's another mapping from id to // StreamingQueryListener on server side. This is used by removeListener() to find the id @@ -53,29 +55,17 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo streamingQueryListenerBus.close() } - /** - * Returns a list of active queries associated with this SQLContext - * - * @since 3.5.0 - */ + /** @inheritdoc */ def active: Array[StreamingQuery] = { executeManagerCmd(_.setActive(true)).getActive.getActiveQueriesList.asScala.map { q => RemoteStreamingQuery.fromStreamingQueryInstanceResponse(sparkSession, q) }.toArray } - /** - * Returns the query if there is an active query with the given id, or null. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def get(id: UUID): StreamingQuery = get(id.toString) - /** - * Returns the query if there is an active query with the given id, or null. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def get(id: String): StreamingQuery = { val response = executeManagerCmd(_.setGetQuery(id)) if (response.hasQuery) { @@ -85,52 +75,13 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } } - /** - * Wait until any of the queries on the associated SQLContext has terminated since the creation - * of the context, or since `resetTerminated()` was called. If any query was terminated with an - * exception, then the exception will be thrown. - * - * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either - * return immediately (if the query was terminated by `query.stop()`), or throw the exception - * immediately (if the query was terminated with exception). Use `resetTerminated()` to clear - * past terminations and wait for new terminations. - * - * In the case where multiple queries have terminated since `resetTermination()` was called, if - * any query has terminated with exception, then `awaitAnyTermination()` will throw any of the - * exception. For correctly documenting exceptions across multiple queries, users need to stop - * all of them after any of them terminates with exception, and then check the - * `query.exception()` for each query. 
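A sketch of the termination-waiting calls whose documentation is now inherited; the timeout value is illustrative.

```scala
import org.apache.spark.sql.SparkSession

def waitForAnyQuery(spark: SparkSession): Unit = {
  // Returns true if any query terminated within the timeout; a query that failed
  // surfaces its StreamingQueryException here instead.
  val terminated = spark.streams.awaitAnyTermination(timeoutMs = 30000L)
  if (terminated) {
    spark.streams.resetTerminated() // clear past terminations before waiting again
  }
}
```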
- * - * @throws StreamingQueryException - * if any query has terminated with an exception - * @since 3.5.0 - */ + /** @inheritdoc */ @throws[StreamingQueryException] def awaitAnyTermination(): Unit = { executeManagerCmd(_.getAwaitAnyTerminationBuilder.build()) } - /** - * Wait until any of the queries on the associated SQLContext has terminated since the creation - * of the context, or since `resetTerminated()` was called. Returns whether any query has - * terminated or not (multiple may have terminated). If any query has terminated with an - * exception, then the exception will be thrown. - * - * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either - * return `true` immediately (if the query was terminated by `query.stop()`), or throw the - * exception immediately (if the query was terminated with exception). Use `resetTerminated()` - * to clear past terminations and wait for new terminations. - * - * In the case where multiple queries have terminated since `resetTermination()` was called, if - * any query has terminated with exception, then `awaitAnyTermination()` will throw any of the - * exception. For correctly documenting exceptions across multiple queries, users need to stop - * all of them after any of them terminates with exception, and then check the - * `query.exception()` for each query. - * - * @throws StreamingQueryException - * if any query has terminated with an exception - * @since 3.5.0 - */ + /** @inheritdoc */ @throws[StreamingQueryException] def awaitAnyTermination(timeoutMs: Long): Boolean = { require(timeoutMs > 0, "Timeout has to be positive") @@ -139,40 +90,22 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo timeoutMs)).getAwaitAnyTermination.getTerminated } - /** - * Forget about past terminated queries so that `awaitAnyTermination()` can be used again to - * wait for new terminations. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def resetTerminated(): Unit = { executeManagerCmd(_.setResetTerminated(true)) } - /** - * Register a [[StreamingQueryListener]] to receive up-calls for life cycle events of - * [[StreamingQuery]]. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def addListener(listener: StreamingQueryListener): Unit = { streamingQueryListenerBus.append(listener) } - /** - * Deregister a [[StreamingQueryListener]]. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def removeListener(listener: StreamingQueryListener): Unit = { streamingQueryListenerBus.remove(listener) } - /** - * List all [[StreamingQueryListener]]s attached to this [[StreamingQueryManager]]. 
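A sketch of registering a listener on the client-side listener bus used by this manager; the handler bodies are illustrative.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

def logTerminations(spark: SparkSession): StreamingQueryListener = {
  val listener = new StreamingQueryListener {
    override def onQueryStarted(event: QueryStartedEvent): Unit = ()
    override def onQueryProgress(event: QueryProgressEvent): Unit = ()
    override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
      println(s"query ${event.id} terminated")
  }
  spark.streams.addListener(listener) // appended to streamingQueryListenerBus
  listener                            // keep the reference so removeListener can find it later
}
```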
- * - * @since 3.5.0 - */ + /** @inheritdoc */ def listListeners(): Array[StreamingQueryListener] = { streamingQueryListenerBus.list() } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala index 0e3a683d2701d..ce552bdd4f0f0 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala @@ -69,7 +69,7 @@ class CatalogSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelpe val exception = intercept[SparkException] { spark.catalog.setCurrentCatalog("notExists") } - assert(exception.getErrorClass == "CATALOG_NOT_FOUND") + assert(exception.getCondition == "CATALOG_NOT_FOUND") spark.catalog.setCurrentCatalog("testcat") assert(spark.catalog.currentCatalog().equals("testcat")) val catalogsAfterChange = spark.catalog.listCatalogs().collect() diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala index e57b051890f56..0d9685d9c710f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.apache.spark.SparkException import org.apache.spark.connect.proto import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper} +import org.apache.spark.storage.StorageLevel class CheckpointSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelper { @@ -50,12 +51,20 @@ class CheckpointSuite extends ConnectFunSuite with RemoteSparkSession with SQLHe checkFragments(captureStdOut(block), fragmentsToCheck) } - test("checkpoint") { + test("localCheckpoint") { val df = spark.range(100).localCheckpoint() testCapturedStdOut(df.explain(), "ExistingRDD") } - test("checkpoint gc") { + test("localCheckpoint with StorageLevel") { + // We don't have a way to reach into the server and assert the storage level server side, but + // this test should cover for unexpected errors in the API. + val df = + spark.range(100).localCheckpoint(eager = true, storageLevel = StorageLevel.DISK_ONLY) + df.collect() + } + + test("localCheckpoint gc") { val df = spark.range(100).localCheckpoint(eager = true) val encoder = df.agnosticEncoder val dfId = df.plan.getRoot.getCachedRemoteRelation.getRelationId @@ -77,7 +86,7 @@ class CheckpointSuite extends ConnectFunSuite with RemoteSparkSession with SQLHe // This test is flaky because cannot guarantee GC // You can locally run this to verify the behavior. 
- ignore("checkpoint gc derived DataFrame") { + ignore("localCheckpoint gc derived DataFrame") { var df1 = spark.range(100).localCheckpoint(eager = true) var derived = df1.repartition(10) val encoder = df1.agnosticEncoder diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala index 88281352f2479..84ed624a95214 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala @@ -251,16 +251,16 @@ class ClientDataFrameStatSuite extends ConnectFunSuite with RemoteSparkSession { val error1 = intercept[AnalysisException] { df.stat.bloomFilter("id", -1000, 100) } - assert(error1.getErrorClass === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE") + assert(error1.getCondition === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE") val error2 = intercept[AnalysisException] { df.stat.bloomFilter("id", 1000, -100) } - assert(error2.getErrorClass === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE") + assert(error2.getCondition === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE") val error3 = intercept[AnalysisException] { df.stat.bloomFilter("id", 1000, -1.0) } - assert(error3.getErrorClass === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE") + assert(error3.getCondition === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE") } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 52cdbd47357f3..0371981b728d1 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -23,7 +23,7 @@ import java.util.Properties import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} -import scala.concurrent.duration.DurationInt +import scala.concurrent.duration.{DurationInt, FiniteDuration} import scala.jdk.CollectionConverters._ import org.apache.commons.io.FileUtils @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult} +import org.apache.spark.sql.connect.client.{RetryPolicy, SparkConnectClient, SparkResult} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.test.{ConnectFunSuite, IntegrationTestUtils, RemoteSparkSession, SQLHelper} @@ -95,7 +95,7 @@ class ClientE2ETestSuite .collect() } assert( - ex.getErrorClass === + ex.getCondition === "INCONSISTENT_BEHAVIOR_CROSS_VERSION.PARSE_DATETIME_BY_NEW_PARSER") assert( ex.getMessageParameters.asScala == Map( @@ -122,12 +122,12 @@ class ClientE2ETestSuite Seq("1").toDS().withColumn("udf_val", throwException($"value")).collect() } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) assert(!ex.getMessageParameters.isEmpty) assert(ex.getCause.isInstanceOf[SparkException]) val cause = ex.getCause.asInstanceOf[SparkException] - assert(cause.getErrorClass == null) + assert(cause.getCondition == null) 
assert(cause.getMessageParameters.isEmpty) assert(cause.getMessage.contains("test" * 10000)) } @@ -141,7 +141,7 @@ class ClientE2ETestSuite val ex = intercept[AnalysisException] { spark.sql("select x").collect() } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) assert(!ex.messageParameters.isEmpty) assert(ex.getSqlState != null) assert(!ex.isInternalError) @@ -169,14 +169,14 @@ class ClientE2ETestSuite val ex = intercept[NoSuchNamespaceException] { spark.sql("use database123") } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) } test("table not found for spark.catalog.getTable") { val ex = intercept[AnalysisException] { spark.catalog.getTable("test_table") } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) } test("throw NamespaceAlreadyExistsException") { @@ -185,7 +185,7 @@ class ClientE2ETestSuite val ex = intercept[NamespaceAlreadyExistsException] { spark.sql("create database test_db") } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) } finally { spark.sql("drop database test_db") } @@ -197,7 +197,7 @@ class ClientE2ETestSuite val ex = intercept[TempTableAlreadyExistsException] { spark.sql("create temporary view test_view as select 1") } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) } finally { spark.sql("drop view test_view") } @@ -209,7 +209,7 @@ class ClientE2ETestSuite val ex = intercept[TableAlreadyExistsException] { spark.sql(s"create table testcat.test_table (id int)") } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) } } @@ -217,7 +217,7 @@ class ClientE2ETestSuite val ex = intercept[ParseException] { spark.sql("selet 1").collect() } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) assert(!ex.messageParameters.isEmpty) assert(ex.getSqlState != null) assert(!ex.isInternalError) @@ -1566,6 +1566,25 @@ class ClientE2ETestSuite val result = df.select(trim(col("col"), " ").as("trimmed_col")).collect() assert(result sameElements Array(Row("a"), Row("b"), Row("c"))) } + + test("SPARK-49673: new batch size, multiple batches") { + val maxBatchSize = spark.conf.get("spark.connect.grpc.arrow.maxBatchSize").dropRight(1).toInt + // Adjust client grpcMaxMessageSize to maxBatchSize (10MiB; set in RemoteSparkSession config) + val sparkWithLowerMaxMessageSize = SparkSession + .builder() + .client( + SparkConnectClient + .builder() + .userId("test") + .port(port) + .grpcMaxMessageSize(maxBatchSize) + .retryPolicy(RetryPolicy + .defaultPolicy() + .copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s")))) + .build()) + .create() + assert(sparkWithLowerMaxMessageSize.range(maxBatchSize).collect().length == maxBatchSize) + } } private[sql] case class ClassData(a: String, b: Int) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala index c37100b729029..86c7a20136851 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala @@ -173,8 +173,8 @@ class ColumnTestSuite extends ConnectFunSuite { assert(explain1 != explain2) assert(explain1.strip() == "+(a, b)") assert(explain2.contains("UnresolvedFunction(+")) - assert(explain2.contains("UnresolvedAttribute(a")) - assert(explain2.contains("UnresolvedAttribute(b")) + assert(explain2.contains("UnresolvedAttribute(List(a")) + 
assert(explain2.contains("UnresolvedAttribute(List(b")) } private def testColName(dataType: DataType, f: ColumnName => StructField): Unit = { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 315f80e13eff7..c557b54732797 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.avro.{functions => avroFn} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.client.SparkConnectClient import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala index 57342e12fcb51..b3b8020b1e4c7 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala @@ -26,6 +26,7 @@ import org.apache.arrow.memory.RootAllocator import org.apache.commons.lang3.SystemUtils import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor import org.apache.spark.sql.connect.client.SparkConnectClient import org.apache.spark.sql.connect.client.arrow.{ArrowDeserializers, ArrowSerializer} import org.apache.spark.sql.test.ConnectFunSuite @@ -55,7 +56,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { import org.apache.spark.util.ArrayImplicits._ import spark.implicits._ def testImplicit[T: Encoder](expected: T): Unit = { - val encoder = encoderFor[T] + val encoder = agnosticEncoderFor[T] val allocator = new RootAllocator() try { val batch = ArrowSerializer.serialize( diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala similarity index 52% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DatasetHolder.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala index 66f591bf1fb99..ed930882ac2fd 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala @@ -16,26 +16,18 @@ */ package org.apache.spark.sql +import org.apache.spark.sql.api.SparkSessionBuilder +import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession} + /** - * A container for a [[Dataset]], used for implicit conversions in Scala. - * - * To use this, import implicit conversions in SQL: - * {{{ - * val spark: SparkSession = ... - * import spark.implicits._ - * }}} - * - * @since 3.4.0 + * Make sure the api.SparkSessionBuilder binds to Connect implementation. 
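For context, the binding suite that follows points the shared builder API at a Connect server via remote(); a standalone connection follows the same shape. The address below is a placeholder, and the suite itself substitutes the randomly chosen test server port at runtime.
{{{
import org.apache.spark.sql.SparkSession

// Placeholder address; the test suite wires in sc://localhost:<random server port>.
val spark = SparkSession
  .builder()
  .remote("sc://localhost:15002")
  .getOrCreate()
assert(spark.range(5).count() == 5)
}}}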
*/ -case class DatasetHolder[T] private[sql] (private val ds: Dataset[T]) { - - // This is declared with parentheses to prevent the Scala compiler from treating - // `rdd.toDS("1")` as invoking this toDS and then apply on the returned Dataset. - def toDS(): Dataset[T] = ds - - // This is declared with parentheses to prevent the Scala compiler from treating - // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDF(): DataFrame = ds.toDF() - - def toDF(colNames: String*): DataFrame = ds.toDF(colNames: _*) +class SparkSessionBuilderImplementationBindingSuite + extends ConnectFunSuite + with api.SparkSessionBuilderImplementationBindingSuite + with RemoteSparkSession { + override protected def configure(builder: SparkSessionBuilder): builder.type = { + // We need to set this configuration because the port used by the server is random. + builder.remote(s"sc://localhost:$serverPort") + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala index 8abc41639fdd2..dec56554d143e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala @@ -22,6 +22,7 @@ import scala.util.control.NonFatal import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor} +import org.apache.spark.SparkException import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.util.SparkSerDeUtils @@ -113,7 +114,7 @@ class SparkSessionSuite extends ConnectFunSuite { SparkSession.clearActiveSession() assert(SparkSession.getDefaultSession.isEmpty) assert(SparkSession.getActiveSession.isEmpty) - intercept[IllegalStateException](SparkSession.active) + intercept[SparkException](SparkSession.active) // Create a session val session1 = SparkSession.builder().remote(connectionString1).getOrCreate() diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index abf03cfbc6722..693c807ec71ea 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -158,6 +158,7 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.columnar.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.classic.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.jdbc.*"), @@ -226,6 +227,8 @@ object CheckConnectJvmClientCompatibility { "org.apache.spark.sql.SparkSession.baseRelationToDataFrame"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataset"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.executeCommand"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.SparkSession.canUseSession"), // SparkSession#implicits 
ProblemFilters.exclude[DirectMissingMethodProblem]( @@ -303,7 +306,13 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.DataFrameReader.validateJsonSchema"), ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.DataFrameReader.validateXmlSchema")) + "org.apache.spark.sql.DataFrameReader.validateXmlSchema"), + + // Protected DataStreamReader methods... + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.streaming.DataStreamReader.validateJsonSchema"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.streaming.DataStreamReader.validateXmlSchema")) checkMiMaCompatibility(clientJar, sqlJar, includedRules, excludeRules) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index 46aeaeff43d2f..ac56600392aa3 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -224,7 +224,7 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { val error = constructor(testParams).asInstanceOf[Throwable with SparkThrowable] assert(error.getMessage.contains(testParams.message)) assert(error.getCause == null) - assert(error.getErrorClass == testParams.errorClass.get) + assert(error.getCondition == testParams.errorClass.get) assert(error.getMessageParameters.asScala == testParams.messageParameters) assert(error.getQueryContext.isEmpty) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 5397dae9dcc5f..10e4c11c406fe 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -30,11 +30,11 @@ import org.apache.arrow.memory.{BufferAllocator, RootAllocator} import org.apache.arrow.vector.VarBinaryVector import org.scalatest.BeforeAndAfterAll -import org.apache.spark.{sql, SparkRuntimeException, SparkUnsupportedOperationException} +import org.apache.spark.{SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.sql.{AnalysisException, Encoders, Row} import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec, OuterScopes} -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, RowEncoder, ScalaDecimalEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, 
YearMonthIntervalEncoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, RowEncoder, ScalaDecimalEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder} import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => toRowEncoder} import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkStringUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND @@ -770,7 +770,7 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { } test("java serialization") { - val encoder = sql.encoderFor(Encoders.javaSerialization[(Int, String)]) + val encoder = agnosticEncoderFor(Encoders.javaSerialization[(Int, String)]) roundTripAndCheckIdentical(encoder) { () => Iterator.tabulate(10)(i => (i, "itr_" + i)) } @@ -778,12 +778,12 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { test("kryo serialization") { val e = intercept[SparkRuntimeException] { - val encoder = sql.encoderFor(Encoders.kryo[(Int, String)]) + val encoder = agnosticEncoderFor(Encoders.kryo[(Int, String)]) roundTripAndCheckIdentical(encoder) { () => Iterator.tabulate(10)(i => (i, "itr_" + i)) } } - assert(e.getErrorClass == "CANNOT_USE_KRYO") + assert(e.getCondition == "CANNOT_USE_KRYO") } test("transforming encoder") { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala index 27b1ee014a719..b1a7d81916e92 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala @@ -331,7 +331,7 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L query.awaitTermination() } - assert(exception.getErrorClass != null) + assert(exception.getCondition != null) assert(exception.getMessageParameters().get("id") == query.id.toString) assert(exception.getMessageParameters().get("runId") == query.runId.toString) assert(exception.getCause.isInstanceOf[SparkException]) @@ -369,7 +369,7 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L spark.streams.awaitAnyTermination() } - assert(exception.getErrorClass != null) + assert(exception.getCondition != null) assert(exception.getMessageParameters().get("id") == query.id.toString) assert(exception.getMessageParameters().get("runId") == query.runId.toString) assert(exception.getCause.isInstanceOf[SparkException]) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala index e0de73e496d95..36aaa2cc7fbf6 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala 
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala @@ -124,6 +124,8 @@ object SparkConnectServerUtils { // to make the tests exercise reattach. "spark.connect.execute.reattachable.senderMaxStreamDuration=1s", "spark.connect.execute.reattachable.senderMaxStreamSize=123", + // Testing SPARK-49673, setting maxBatchSize to 10MiB + s"spark.connect.grpc.arrow.maxBatchSize=${10 * 1024 * 1024}", // Disable UI "spark.ui.enabled=false") Seq("--jars", catalystTestJar) ++ confs.flatMap(v => "--conf" :: v :: Nil) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala index 8d17e0b4e36e6..1df01bd3bfb62 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala @@ -115,7 +115,7 @@ abstract class DockerJDBCIntegrationSuite protected val startContainerTimeout: Long = timeStringAsSeconds(sys.props.getOrElse("spark.test.docker.startContainerTimeout", "5min")) protected val connectionTimeout: PatienceConfiguration.Timeout = { - val timeoutStr = sys.props.getOrElse("spark.test.docker.connectionTimeout", "5min") + val timeoutStr = sys.props.getOrElse("spark.test.docker.connectionTimeout", "10min") timeout(timeStringAsSeconds(timeoutStr).seconds) } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 3076b599ef4ef..071b976f044c3 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -32,9 +32,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:16.4-alpine): + * To run this test suite for a specific version (e.g., postgres:17.0-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.4-alpine + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:17.0-alpine * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly org.apache.spark.sql.jdbc.PostgresIntegrationSuite" * }}} @@ -42,7 +42,7 @@ import org.apache.spark.tags.DockerTest @DockerTest class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.4-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:17.0-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala index 5acb6423bbd9b..62f9c6e0256f3 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala +++ 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.execution.datasources.jdbc.connection.SecureConnecti import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:16.4-alpine): + * To run this test suite for a specific version (e.g., postgres:17.0-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.4-alpine + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:17.0-alpine * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly *PostgresKrbIntegrationSuite" * }}} @@ -38,7 +38,7 @@ class PostgresKrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { override protected val keytabFileName = "postgres.keytab" override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.4-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:17.0-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/GeneratedSubquerySuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/GeneratedSubquerySuite.scala index 8d367f476403f..a79bbf39a71b8 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/GeneratedSubquerySuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/GeneratedSubquerySuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.tags.DockerTest /** * This suite is used to generate subqueries, and test Spark against Postgres. - * To run this test suite for a specific version (e.g., postgres:16.4-alpine): + * To run this test suite for a specific version (e.g., postgres:17.0-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.4-alpine + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:17.0-alpine * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly org.apache.spark.sql.jdbc.GeneratedSubquerySuite" * }}} @@ -39,7 +39,7 @@ import org.apache.spark.tags.DockerTest class GeneratedSubquerySuite extends DockerJDBCIntegrationSuite with QueryGeneratorHelper { override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.4-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:17.0-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/PostgreSQLQueryTestSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/PostgreSQLQueryTestSuite.scala index f3a08541365c1..80ba35df6c893 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/PostgreSQLQueryTestSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/PostgreSQLQueryTestSuite.scala @@ -30,9 +30,9 @@ import org.apache.spark.tags.DockerTest * confidence, and you won't have to manually verify the golden files generated with your test. * 2. 
Add this line to your .sql file: --ONLY_IF spark * - * Note: To run this test suite for a specific version (e.g., postgres:16.4-alpine): + * Note: To run this test suite for a specific version (e.g., postgres:17.0-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.4-alpine + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:17.0-alpine * ./build/sbt -Pdocker-integration-tests * "testOnly org.apache.spark.sql.jdbc.PostgreSQLQueryTestSuite" * }}} @@ -45,7 +45,7 @@ class PostgreSQLQueryTestSuite extends CrossDbmsQueryTestSuite { protected val customInputFilePath: String = new File(inputFilePath, "subquery").getAbsolutePath override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.4-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:17.0-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index e5fd453cb057c..aaaaa28558342 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -115,7 +115,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD exception = intercept[SparkSQLFeatureNotSupportedException] { sql(s"ALTER TABLE $tbl ALTER COLUMN ID DROP NOT NULL") }, - condition = "_LEGACY_ERROR_TEMP_2271") + condition = "UNSUPPORTED_FEATURE.UPDATE_COLUMN_NULLABILITY") } test("SPARK-47440: SQLServer does not support boolean expression in binary comparison") { diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala index 700c05b54a256..a895739254373 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -142,7 +142,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest exception = intercept[SparkSQLFeatureNotSupportedException] { sql(s"ALTER TABLE $tbl ALTER COLUMN ID DROP NOT NULL") }, - condition = "_LEGACY_ERROR_TEMP_2271") + condition = "UNSUPPORTED_FEATURE.UPDATE_COLUMN_NULLABILITY") } override def testCreateTableWithProperty(tbl: String): Unit = { diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index 850391e8dc33c..05f02a402353b 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:16.4-alpine) + * To run this test suite for a specific version 
(e.g., postgres:17.0-alpine) * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.4-alpine + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:17.0-alpine * ./build/sbt -Pdocker-integration-tests "testOnly *v2.PostgresIntegrationSuite" * }}} */ @@ -38,7 +38,7 @@ import org.apache.spark.tags.DockerTest class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "postgresql" override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.4-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:17.0-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) @@ -65,6 +65,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT |) """.stripMargin ).executeUpdate() + connection.prepareStatement( + "CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)") + .executeUpdate() + } + + override def dataPreparation(connection: Connection): Unit = { + super.dataPreparation(connection) + connection.prepareStatement("INSERT INTO datetime VALUES " + + "('amy', '2022-05-19', '2022-05-19 00:00:00')").executeUpdate() + connection.prepareStatement("INSERT INTO datetime VALUES " + + "('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { @@ -123,4 +134,77 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT ) } } + + override def testDatetime(tbl: String): Unit = { + val df1 = sql(s"SELECT name FROM $tbl WHERE " + + "dayofyear(date1) > 100 AND dayofmonth(date1) > 10 ") + checkFilterPushed(df1) + val rows1 = df1.collect() + assert(rows1.length === 2) + assert(rows1(0).getString(0) === "amy") + assert(rows1(1).getString(0) === "alex") + + val df2 = sql(s"SELECT name FROM $tbl WHERE year(date1) = 2022 AND quarter(date1) = 2") + checkFilterPushed(df2) + val rows2 = df2.collect() + assert(rows2.length === 2) + assert(rows2(0).getString(0) === "amy") + assert(rows2(1).getString(0) === "alex") + + val df3 = sql(s"SELECT name FROM $tbl WHERE second(time1) = 0 AND month(date1) = 5") + checkFilterPushed(df3) + val rows3 = df3.collect() + assert(rows3.length === 2) + assert(rows3(0).getString(0) === "amy") + assert(rows3(1).getString(0) === "alex") + + val df4 = sql(s"SELECT name FROM $tbl WHERE hour(time1) = 0 AND minute(time1) = 0") + checkFilterPushed(df4) + val rows4 = df4.collect() + assert(rows4.length === 2) + assert(rows4(0).getString(0) === "amy") + assert(rows4(1).getString(0) === "alex") + + val df5 = sql(s"SELECT name FROM $tbl WHERE " + + "extract(WEEk from date1) > 10 AND extract(YEAROFWEEK from date1) = 2022") + checkFilterPushed(df5) + val rows5 = df5.collect() + assert(rows5.length === 2) + assert(rows5(0).getString(0) === "amy") + assert(rows5(1).getString(0) === "alex") + + val df6 = sql(s"SELECT name FROM $tbl WHERE date_add(date1, 1) = date'2022-05-20' " + + "AND datediff(date1, '2022-05-10') > 0") + checkFilterPushed(df6, false) + val rows6 = df6.collect() + assert(rows6.length === 1) + assert(rows6(0).getString(0) === "amy") + + val df7 = sql(s"SELECT name FROM $tbl WHERE weekday(date1) = 2") + checkFilterPushed(df7) + val rows7 = df7.collect() + assert(rows7.length === 1) + assert(rows7(0).getString(0) === "alex") + + val df8 = sql(s"SELECT name FROM $tbl WHERE dayofweek(date1) = 4") + checkFilterPushed(df8) + val rows8 = 
df8.collect() + assert(rows8.length === 1) + assert(rows8(0).getString(0) === "alex") + + val df9 = sql(s"SELECT name FROM $tbl WHERE " + + "dayofyear(date1) > 100 order by dayofyear(date1) limit 1") + checkFilterPushed(df9) + val rows9 = df9.collect() + assert(rows9.length === 1) + assert(rows9(0).getString(0) === "alex") + + // Postgres does not support + val df10 = sql(s"SELECT name FROM $tbl WHERE trunc(date1, 'week') = date'2022-05-16'") + checkFilterPushed(df10, false) + val rows10 = df10.collect() + assert(rows10.length === 2) + assert(rows10(0).getString(0) === "amy") + assert(rows10(1).getString(0) === "alex") + } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala index 665746f1d5770..6d4f1cc2fd3fc 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala @@ -26,16 +26,16 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:16.4-alpine): + * To run this test suite for a specific version (e.g., postgres:17.0-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.4-alpine + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:17.0-alpine * ./build/sbt -Pdocker-integration-tests "testOnly *v2.PostgresNamespaceSuite" * }}} */ @DockerTest class PostgresNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.4-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:17.0-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/kafka-0-10-sql/pom.xml b/connector/kafka-0-10-sql/pom.xml index 35f58134f1a85..66e1c24e821c8 100644 --- a/connector/kafka-0-10-sql/pom.xml +++ b/connector/kafka-0-10-sql/pom.xml @@ -148,6 +148,16 @@ mockito-core test
+ + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaExceptions.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaExceptions.scala index 13a68e72269f0..c4adb6b3f26e1 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaExceptions.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaExceptions.scala @@ -184,5 +184,5 @@ private[kafka010] class KafkaIllegalStateException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 7852bc814ccd4..c3f02eebab23a 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -176,7 +176,7 @@ class KafkaTestUtils( } kdc.getKrb5conf.delete() - Files.write(krb5confStr, kdc.getKrb5conf, StandardCharsets.UTF_8) + Files.asCharSink(kdc.getKrb5conf, StandardCharsets.UTF_8).write(krb5confStr) logDebug(s"krb5.conf file content: $krb5confStr") } @@ -240,7 +240,7 @@ class KafkaTestUtils( | principal="$kafkaServerUser@$realm"; |}; """.stripMargin.trim - Files.write(content, file, StandardCharsets.UTF_8) + Files.asCharSink(file, StandardCharsets.UTF_8).write(content) logDebug(s"Created JAAS file: ${file.getPath}") logDebug(s"JAAS file content: $content") file.getAbsolutePath() diff --git a/connector/kafka-0-10-token-provider/pom.xml b/connector/kafka-0-10-token-provider/pom.xml index 2b2707b9da320..3cbfc34e7d806 100644 --- a/connector/kafka-0-10-token-provider/pom.xml +++ b/connector/kafka-0-10-token-provider/pom.xml @@ -64,6 +64,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.apache.hadoop hadoop-client-runtime diff --git a/connector/kafka-0-10/pom.xml b/connector/kafka-0-10/pom.xml index 1b26839a371ce..a42410e6ce885 100644 --- a/connector/kafka-0-10/pom.xml +++ b/connector/kafka-0-10/pom.xml @@ -119,6 +119,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.apache.spark spark-tags_${scala.binary.version} diff --git a/connector/kinesis-asl/pom.xml b/connector/kinesis-asl/pom.xml index 9a7f40443bbc9..7eba26ffdff74 100644 --- a/connector/kinesis-asl/pom.xml +++ b/connector/kinesis-asl/pom.xml @@ -81,6 +81,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala deleted file mode 100644 index 3b0def8fc73f7..0000000000000 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala +++ /dev/null @@ -1,324 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.protobuf - -import scala.jdk.CollectionConverters._ - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.Column -import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.protobuf.utils.ProtobufUtils - -// scalastyle:off: object.name -object functions { -// scalastyle:on: object.name - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value. The - * Protobuf definition is provided through Protobuf descriptor file. - * - * @param data - * the binary column. - * @param messageName - * the protobuf message name to look for in descriptor file. - * @param descFilePath - * The Protobuf descriptor file. This file is usually created using `protoc` with - * `--descriptor_set_out` and `--include_imports` options. - * @param options - * @since 3.4.0 - */ - @Experimental - def from_protobuf( - data: Column, - messageName: String, - descFilePath: String, - options: java.util.Map[String, String]): Column = { - val descriptorFileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - from_protobuf(data, messageName, descriptorFileContent, options) - } - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value.The - * Protobuf definition is provided through Protobuf `FileDescriptorSet`. - * - * @param data - * the binary column. - * @param messageName - * the protobuf MessageName to look for in the descriptor set. - * @param binaryFileDescriptorSet - * Serialized Protobuf descriptor (`FileDescriptorSet`). Typically contents of file created - * using `protoc` with `--descriptor_set_out` and `--include_imports` options. - * @param options - * @since 3.5.0 - */ - @Experimental - def from_protobuf( - data: Column, - messageName: String, - binaryFileDescriptorSet: Array[Byte], - options: java.util.Map[String, String]): Column = { - Column.fnWithOptions( - "from_protobuf", - options.asScala.iterator, - data, - lit(messageName), - lit(binaryFileDescriptorSet) - ) - } - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value. The - * Protobuf definition is provided through Protobuf descriptor file. - * - * @param data - * the binary column. - * @param messageName - * the protobuf MessageName to look for in descriptor file. - * @param descFilePath - * The Protobuf descriptor file. This file is usually created using `protoc` with - * `--descriptor_set_out` and `--include_imports` options. 
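The descriptor-file overload documented above is typically used as in this sketch; the column name, message name, and descriptor path are illustrative, not taken from the patch.
{{{
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.protobuf.functions.from_protobuf

// Illustrative: events.desc produced by
//   protoc --descriptor_set_out=events.desc --include_imports example.proto
def parseEvents(raw: DataFrame): DataFrame =
  raw.select(from_protobuf(col("payload"), "ExampleEvent", "/tmp/events.desc").as("event"))
}}}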
- * @since 3.4.0 - */ - @Experimental - def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = { - val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - from_protobuf(data, messageName, fileContent) - } - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value.The - * Protobuf definition is provided through Protobuf `FileDescriptorSet`. - * - * @param data - * the binary column. - * @param messageName - * the protobuf MessageName to look for in the descriptor set. - * @param binaryFileDescriptorSet - * Serialized Protobuf descriptor (`FileDescriptorSet`). Typically contents of file created - * using `protoc` with `--descriptor_set_out` and `--include_imports` options. - * @since 3.5.0 - */ - @Experimental - def from_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte]) - : Column = { - Column.fn( - "from_protobuf", - data, - lit(messageName), - lit(binaryFileDescriptorSet) - ) - } - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value. - * `messageClassName` points to Protobuf Java class. The jar containing Java class should be - * shaded. Specifically, `com.google.protobuf.*` should be shaded to - * `org.sparkproject.spark_protobuf.protobuf.*`. - * https://github.com/rangadi/shaded-protobuf-classes is useful to create shaded jar from - * Protobuf files. - * - * @param data - * the binary column. - * @param messageClassName - * The full name for Protobuf Java class. E.g. com.example.protos.ExampleEvent. - * The jar with these classes needs to be shaded as described above. - * @since 3.4.0 - */ - @Experimental - def from_protobuf(data: Column, messageClassName: String): Column = { - Column.fn( - "from_protobuf", - data, - lit(messageClassName) - ) - } - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value. - * `messageClassName` points to Protobuf Java class. The jar containing Java class should be - * shaded. Specifically, `com.google.protobuf.*` should be shaded to - * `org.sparkproject.spark_protobuf.protobuf.*`. - * https://github.com/rangadi/shaded-protobuf-classes is useful to create shaded jar from - * Protobuf files. - * - * @param data - * the binary column. - * @param messageClassName - * The full name for Protobuf Java class. E.g. com.example.protos.ExampleEvent. - * The jar with these classes needs to be shaded as described above. - * @param options - * @since 3.4.0 - */ - @Experimental - def from_protobuf( - data: Column, - messageClassName: String, - options: java.util.Map[String, String]): Column = { - Column.fnWithOptions( - "from_protobuf", - options.asScala.iterator, - data, - lit(messageClassName) - ) - } - - /** - * Converts a column into binary of protobuf format. The Protobuf definition is provided - * through Protobuf descriptor file. - * - * @param data - * the data column. - * @param messageName - * the protobuf MessageName to look for in descriptor file. - * @param descFilePath - * The Protobuf descriptor file. This file is usually created using `protoc` with - * `--descriptor_set_out` and `--include_imports` options. - * @since 3.4.0 - */ - @Experimental - def to_protobuf(data: Column, messageName: String, descFilePath: String): Column = { - to_protobuf(data, messageName, descFilePath, Map.empty[String, String].asJava) - } - - /** - * Converts a column into binary of protobuf format.The Protobuf definition is provided - * through Protobuf `FileDescriptorSet`. 
- * - * @param data - * the binary column. - * @param messageName - * the protobuf MessageName to look for in the descriptor set. - * @param binaryFileDescriptorSet - * Serialized Protobuf descriptor (`FileDescriptorSet`). Typically contents of file created - * using `protoc` with `--descriptor_set_out` and `--include_imports` options. - * - * @since 3.5.0 - */ - @Experimental - def to_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte]) - : Column = { - Column.fn( - "to_protobuf", - data, - lit(messageName), - lit(binaryFileDescriptorSet) - ) - } - /** - * Converts a column into binary of protobuf format. The Protobuf definition is provided - * through Protobuf descriptor file. - * - * @param data - * the data column. - * @param messageName - * the protobuf MessageName to look for in descriptor file. - * @param descFilePath - * the protobuf descriptor file. - * @param options - * @since 3.4.0 - */ - @Experimental - def to_protobuf( - data: Column, - messageName: String, - descFilePath: String, - options: java.util.Map[String, String]): Column = { - val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - to_protobuf(data, messageName, fileContent, options) - } - - /** - * Converts a column into binary of protobuf format.The Protobuf definition is provided - * through Protobuf `FileDescriptorSet`. - * - * @param data - * the binary column. - * @param messageName - * the protobuf MessageName to look for in the descriptor set. - * @param binaryFileDescriptorSet - * Serialized Protobuf descriptor (`FileDescriptorSet`). Typically contents of file created - * using `protoc` with `--descriptor_set_out` and `--include_imports` options. - * @param options - * @since 3.5.0 - */ - @Experimental - def to_protobuf( - data: Column, - messageName: String, - binaryFileDescriptorSet: Array[Byte], - options: java.util.Map[String, String] - ): Column = { - Column.fnWithOptions( - "to_protobuf", - options.asScala.iterator, - data, - lit(messageName), - lit(binaryFileDescriptorSet) - ) - } - - /** - * Converts a column into binary of protobuf format. - * `messageClassName` points to Protobuf Java class. The jar containing Java class should be - * shaded. Specifically, `com.google.protobuf.*` should be shaded to - * `org.sparkproject.spark_protobuf.protobuf.*`. - * https://github.com/rangadi/shaded-protobuf-classes is useful to create shaded jar from - * Protobuf files. - * - * @param data - * the data column. - * @param messageClassName - * The full name for Protobuf Java class. E.g. com.example.protos.ExampleEvent. - * The jar with these classes needs to be shaded as described above. - * @since 3.4.0 - */ - @Experimental - def to_protobuf(data: Column, messageClassName: String): Column = { - Column.fn( - "to_protobuf", - data, - lit(messageClassName) - ) - } - - /** - * Converts a column into binary of protobuf format. - * `messageClassName` points to Protobuf Java class. The jar containing Java class should be - * shaded. Specifically, `com.google.protobuf.*` should be shaded to - * `org.sparkproject.spark_protobuf.protobuf.*`. - * https://github.com/rangadi/shaded-protobuf-classes is useful to create shaded jar from - * Protobuf files. - * - * @param data - * the data column. - * @param messageClassName - * The full name for Protobuf Java class. E.g. com.example.protos.ExampleEvent. - * The jar with these classes needs to be shaded as described above. 
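The messageClassName variants documented here take a shaded Protobuf Java class instead of a descriptor file; a hedged round-trip sketch, where the class and column names are illustrative placeholders:
{{{
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, struct}
import org.apache.spark.sql.protobuf.functions.{from_protobuf, to_protobuf}

// "com.example.protos.ExampleEvent" stands in for a class shipped in a jar shaded as described above.
def roundTrip(events: DataFrame): DataFrame = {
  val encoded = events.select(
    to_protobuf(struct(col("*")), "com.example.protos.ExampleEvent").as("bytes"))
  encoded.select(
    from_protobuf(col("bytes"), "com.example.protos.ExampleEvent").as("event"))
}
}}}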
- * @param options - * @since 3.4.0 - */ - @Experimental - def to_protobuf(data: Column, messageClassName: String, options: java.util.Map[String, String]) - : Column = { - Column.fnWithOptions( - "to_protobuf", - options.asScala.iterator, - data, - lit(messageClassName) - ) - } -} diff --git a/core/benchmarks/ZStandardBenchmark-jdk21-results.txt b/core/benchmarks/ZStandardBenchmark-jdk21-results.txt index b3bffea826e5f..f6bd681451d5e 100644 --- a/core/benchmarks/ZStandardBenchmark-jdk21-results.txt +++ b/core/benchmarks/ZStandardBenchmark-jdk21-results.txt @@ -2,48 +2,48 @@ Benchmark ZStandardCompressionCodec ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------- -Compression 10000 times at level 1 without buffer pool 657 670 14 0.0 65699.2 1.0X -Compression 10000 times at level 2 without buffer pool 697 697 1 0.0 69673.4 0.9X -Compression 10000 times at level 3 without buffer pool 799 802 3 0.0 79855.2 0.8X -Compression 10000 times at level 1 with buffer pool 593 595 1 0.0 59326.9 1.1X -Compression 10000 times at level 2 with buffer pool 622 624 3 0.0 62194.1 1.1X -Compression 10000 times at level 3 with buffer pool 732 733 1 0.0 73178.6 0.9X +Compression 10000 times at level 1 without buffer pool 659 676 16 0.0 65860.7 1.0X +Compression 10000 times at level 2 without buffer pool 721 723 2 0.0 72135.5 0.9X +Compression 10000 times at level 3 without buffer pool 815 816 1 0.0 81500.6 0.8X +Compression 10000 times at level 1 with buffer pool 608 609 0 0.0 60846.6 1.1X +Compression 10000 times at level 2 with buffer pool 645 647 3 0.0 64476.3 1.0X +Compression 10000 times at level 3 with buffer pool 746 746 1 0.0 74584.0 0.9X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------ -Decompression 10000 times from level 1 without buffer pool 813 820 11 0.0 81273.2 1.0X -Decompression 10000 times from level 2 without buffer pool 810 813 3 0.0 80986.2 1.0X -Decompression 10000 times from level 3 without buffer pool 812 813 2 0.0 81183.1 1.0X -Decompression 10000 times from level 1 with buffer pool 746 747 2 0.0 74568.7 1.1X -Decompression 10000 times from level 2 with buffer pool 744 746 2 0.0 74414.5 1.1X -Decompression 10000 times from level 3 with buffer pool 745 746 1 0.0 74538.6 1.1X +Decompression 10000 times from level 1 without buffer pool 828 829 1 0.0 82822.6 1.0X +Decompression 10000 times from level 2 without buffer pool 829 829 1 0.0 82900.7 1.0X +Decompression 10000 times from level 3 without buffer pool 828 833 8 0.0 82784.4 1.0X +Decompression 10000 times from level 1 with buffer pool 758 760 2 0.0 75756.5 1.1X +Decompression 10000 times from level 2 with buffer pool 758 758 1 0.0 75772.3 1.1X +Decompression 10000 times from level 3 with buffer pool 759 759 0 0.0 75852.7 1.1X -OpenJDK 64-Bit Server VM 
21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 3: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 48 49 1 0.0 374256.1 1.0X -Parallel Compression with 1 workers 34 36 3 0.0 267557.3 1.4X -Parallel Compression with 2 workers 34 38 2 0.0 263684.3 1.4X -Parallel Compression with 4 workers 37 39 2 0.0 289956.1 1.3X -Parallel Compression with 8 workers 39 41 1 0.0 306975.2 1.2X -Parallel Compression with 16 workers 44 45 1 0.0 340992.0 1.1X +Parallel Compression with 0 workers 58 59 1 0.0 452489.9 1.0X +Parallel Compression with 1 workers 42 45 4 0.0 330066.0 1.4X +Parallel Compression with 2 workers 40 42 1 0.0 312560.3 1.4X +Parallel Compression with 4 workers 40 42 2 0.0 308802.7 1.5X +Parallel Compression with 8 workers 41 45 3 0.0 321331.3 1.4X +Parallel Compression with 16 workers 44 45 1 0.0 343311.5 1.3X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 9: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 156 158 1 0.0 1220760.5 1.0X -Parallel Compression with 1 workers 191 192 2 0.0 1495168.2 0.8X -Parallel Compression with 2 workers 111 117 5 0.0 864459.9 1.4X -Parallel Compression with 4 workers 106 109 2 0.0 831025.5 1.5X -Parallel Compression with 8 workers 112 115 2 0.0 875732.7 1.4X -Parallel Compression with 16 workers 110 114 2 0.0 858160.9 1.4X +Parallel Compression with 0 workers 158 160 2 0.0 1234257.6 1.0X +Parallel Compression with 1 workers 193 194 1 0.0 1507686.4 0.8X +Parallel Compression with 2 workers 113 127 11 0.0 881068.0 1.4X +Parallel Compression with 4 workers 109 111 2 0.0 849241.3 1.5X +Parallel Compression with 8 workers 111 115 3 0.0 869455.2 1.4X +Parallel Compression with 16 workers 113 116 2 0.0 881832.5 1.4X diff --git a/core/benchmarks/ZStandardBenchmark-results.txt b/core/benchmarks/ZStandardBenchmark-results.txt index b230f825fecac..136f0333590cc 100644 --- a/core/benchmarks/ZStandardBenchmark-results.txt +++ b/core/benchmarks/ZStandardBenchmark-results.txt @@ -2,48 +2,48 @@ Benchmark ZStandardCompressionCodec ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------- -Compression 10000 times at level 1 without buffer pool 638 638 0 0.0 63765.0 1.0X -Compression 10000 times at level 2 without buffer pool 675 676 1 0.0 67529.4 0.9X -Compression 10000 times at level 3 without buffer pool 775 783 11 0.0 77531.6 0.8X -Compression 10000 times at level 1 with buffer pool 572 573 1 0.0 57223.2 1.1X -Compression 10000 times at level 2 with buffer pool 603 605 1 0.0 60323.7 1.1X -Compression 10000 times at level 3 with 
buffer pool 720 727 6 0.0 71980.9 0.9X +Compression 10000 times at level 1 without buffer pool 257 259 2 0.0 25704.2 1.0X +Compression 10000 times at level 2 without buffer pool 674 676 2 0.0 67396.3 0.4X +Compression 10000 times at level 3 without buffer pool 775 787 11 0.0 77497.9 0.3X +Compression 10000 times at level 1 with buffer pool 573 574 0 0.0 57347.3 0.4X +Compression 10000 times at level 2 with buffer pool 602 603 2 0.0 60162.8 0.4X +Compression 10000 times at level 3 with buffer pool 722 725 3 0.0 72247.3 0.4X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------ -Decompression 10000 times from level 1 without buffer pool 584 585 1 0.0 58381.0 1.0X -Decompression 10000 times from level 2 without buffer pool 585 585 0 0.0 58465.9 1.0X -Decompression 10000 times from level 3 without buffer pool 585 586 1 0.0 58499.5 1.0X -Decompression 10000 times from level 1 with buffer pool 534 534 0 0.0 53375.7 1.1X -Decompression 10000 times from level 2 with buffer pool 533 533 0 0.0 53312.3 1.1X -Decompression 10000 times from level 3 with buffer pool 533 533 1 0.0 53255.1 1.1X +Decompression 10000 times from level 1 without buffer pool 176 177 1 0.1 17641.2 1.0X +Decompression 10000 times from level 2 without buffer pool 176 178 1 0.1 17628.9 1.0X +Decompression 10000 times from level 3 without buffer pool 175 176 0 0.1 17506.1 1.0X +Decompression 10000 times from level 1 with buffer pool 151 152 1 0.1 15051.5 1.2X +Decompression 10000 times from level 2 with buffer pool 150 151 1 0.1 14998.0 1.2X +Decompression 10000 times from level 3 with buffer pool 150 151 0 0.1 15019.4 1.2X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 3: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 46 48 1 0.0 360483.5 1.0X -Parallel Compression with 1 workers 34 36 2 0.0 265816.1 1.4X -Parallel Compression with 2 workers 33 36 2 0.0 254525.8 1.4X -Parallel Compression with 4 workers 34 37 1 0.0 266270.8 1.4X -Parallel Compression with 8 workers 37 39 1 0.0 289289.2 1.2X -Parallel Compression with 16 workers 41 43 1 0.0 320243.3 1.1X +Parallel Compression with 0 workers 57 57 0 0.0 444425.2 1.0X +Parallel Compression with 1 workers 42 44 3 0.0 325107.6 1.4X +Parallel Compression with 2 workers 38 39 2 0.0 294840.0 1.5X +Parallel Compression with 4 workers 36 37 1 0.0 282143.1 1.6X +Parallel Compression with 8 workers 39 40 1 0.0 303793.6 1.5X +Parallel Compression with 16 workers 41 43 1 0.0 324165.5 1.4X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 9: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 154 156 2 0.0 1205934.0 
1.0X -Parallel Compression with 1 workers 191 194 4 0.0 1495729.9 0.8X -Parallel Compression with 2 workers 110 114 5 0.0 859158.9 1.4X -Parallel Compression with 4 workers 105 108 3 0.0 822932.2 1.5X -Parallel Compression with 8 workers 109 113 2 0.0 851560.0 1.4X -Parallel Compression with 16 workers 111 115 2 0.0 870695.9 1.4X +Parallel Compression with 0 workers 156 158 1 0.0 1220298.8 1.0X +Parallel Compression with 1 workers 188 189 1 0.0 1467911.4 0.8X +Parallel Compression with 2 workers 111 118 7 0.0 866985.2 1.4X +Parallel Compression with 4 workers 106 109 2 0.0 827592.1 1.5X +Parallel Compression with 8 workers 114 116 2 0.0 888419.5 1.4X +Parallel Compression with 16 workers 111 115 2 0.0 868463.5 1.4X diff --git a/core/pom.xml b/core/pom.xml index 19f58940ed942..7805a3f37ae53 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -393,6 +393,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java index 5e9f1b78273a5..7dd87df713e6e 100644 --- a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java +++ b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java @@ -120,7 +120,8 @@ private boolean isEndOfStream() { private void checkReadException() throws IOException { if (readAborted) { - Throwables.propagateIfPossible(readException, IOException.class); + Throwables.throwIfInstanceOf(readException, IOException.class); + Throwables.throwIfUnchecked(readException); throw new IOException(readException); } } diff --git a/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java b/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java index 8ec5c2221b6e9..fa71eb066ff89 100644 --- a/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java +++ b/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java @@ -52,7 +52,7 @@ public Map getMessageParameters() { } @Override - public String getErrorClass() { + public String getCondition() { return errorClass; } } diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 3b7c7778e26ce..573608c4327e0 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -173,7 +173,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( } private def canShuffleMergeBeEnabled(): Boolean = { - val isPushShuffleEnabled = Utils.isPushBasedShuffleEnabled(rdd.sparkContext.getConf, + val isPushShuffleEnabled = Utils.isPushBasedShuffleEnabled(rdd.sparkContext.conf, // invoked at driver isDriver = true) if (isPushShuffleEnabled && rdd.isBarrier()) { diff --git a/core/src/main/scala/org/apache/spark/SparkFileAlreadyExistsException.scala b/core/src/main/scala/org/apache/spark/SparkFileAlreadyExistsException.scala index 0e578f045452e..82a0261f32ae7 100644 --- a/core/src/main/scala/org/apache/spark/SparkFileAlreadyExistsException.scala +++ b/core/src/main/scala/org/apache/spark/SparkFileAlreadyExistsException.scala @@ -33,5 +33,5 @@ private[spark] class SparkFileAlreadyExistsException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } diff --git 
a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 5e3078d7292ba..fed15a067c00f 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -421,7 +421,7 @@ private[spark] object TestUtils extends SparkTestUtils { def createTempScriptWithExpectedOutput(dir: File, prefix: String, output: String): String = { val file = File.createTempFile(prefix, ".sh", dir) val script = s"cat < x._1.matches("SPARK_LOCAL_(IP|HOSTNAME)")) + Option(request.environmentVariables).getOrElse(Map.empty[String, String]) + .filterNot(x => x._1.matches("SPARK_LOCAL_(IP|HOSTNAME)")) .map(x => (x._1, replacePlaceHolder(x._2))) // Construct driver description diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala index 7f462148c71a1..63882259adcb5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala @@ -47,8 +47,6 @@ private[rest] class CreateSubmissionRequest extends SubmitRestProtocolRequest { super.doValidate() assert(sparkProperties != null, "No Spark properties set!") assertFieldIsSet(appResource, "appResource") - assertFieldIsSet(appArgs, "appArgs") - assertFieldIsSet(environmentVariables, "environmentVariables") assertPropertyIsSet("spark.app.name") assertPropertyIsBoolean(config.DRIVER_SUPERVISE.key) assertPropertyIsNumeric(config.DRIVER_CORES.key) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index bb96ecb38a640..ca0e024ad1aed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets import scala.jdk.CollectionConverters._ -import com.google.common.io.Files +import com.google.common.io.{Files, FileWriteMode} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil} @@ -216,7 +216,7 @@ private[deploy] class DriverRunner( val redactedCommand = Utils.redactCommandLineArgs(conf, builder.command.asScala.toSeq) .mkString("\"", "\" \"", "\"") val header = "Launch Command: %s\n%s\n\n".format(redactedCommand, "=" * 40) - Files.append(header, stderr, StandardCharsets.UTF_8) + Files.asCharSink(stderr, StandardCharsets.UTF_8, FileWriteMode.APPEND).write(header) CommandUtils.redirectStream(process.getErrorStream, stderr) } runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 8d0fb7a54f72a..d21904dd16ea7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -191,7 +191,7 @@ private[deploy] class ExecutorRunner( stdoutAppender = FileAppender(process.getInputStream, stdout, conf, true) val stderr = new File(executorDir, "stderr") - Files.write(header, stderr, StandardCharsets.UTF_8) + Files.asCharSink(stderr, StandardCharsets.UTF_8).write(header) stderrAppender = FileAppender(process.getErrorStream, stderr, conf, true) state = 
ExecutorState.RUNNING diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 94a27e1a3e6da..f24cd59418300 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -22,6 +22,7 @@ import java.lang.management.ManagementFactory import scala.annotation.tailrec import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging import org.apache.spark.internal.config.Worker._ import org.apache.spark.util.{IntParam, MemoryParam, Utils} @@ -59,6 +60,9 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { // This mutates the SparkConf, so all accesses to it must be made after this line propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + // Initialize logging system again after `spark.log.structuredLogging.enabled` takes effect + Utils.resetStructuredLogging(conf) + Logging.uninitialize() conf.get(WORKER_UI_PORT).foreach { webUiPort = _ } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index eaa07b9a81f5b..e880cf8da9ec2 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -468,6 +468,10 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } } + // Initialize logging system again after `spark.log.structuredLogging.enabled` takes effect + Utils.resetStructuredLogging(driverConf) + Logging.uninitialize() + cfg.hadoopDelegationCreds.foreach { tokens => SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf) } diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index f0d6cba6ae734..3c3017a9a64c1 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -45,8 +45,8 @@ private[spark] abstract class StreamFileInputFormat[T] * which is set through setMaxSplitSize */ def setMinPartitions(sc: SparkContext, context: JobContext, minPartitions: Int): Unit = { - val defaultMaxSplitBytes = sc.getConf.get(config.FILES_MAX_PARTITION_BYTES) - val openCostInBytes = sc.getConf.get(config.FILES_OPEN_COST_IN_BYTES) + val defaultMaxSplitBytes = sc.conf.get(config.FILES_MAX_PARTITION_BYTES) + val openCostInBytes = sc.conf.get(config.FILES_OPEN_COST_IN_BYTES) val defaultParallelism = Math.max(sc.defaultParallelism, minPartitions) val files = listStatus(context).asScala val totalBytes = files.filterNot(_.isDirectory).map(_.getLen + openCostInBytes).sum diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 47019c04aada2..134d1d6bdb885 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -273,15 +273,15 @@ package object config { private[spark] val EVENT_LOG_INCLUDE_TASK_METRICS_ACCUMULATORS = ConfigBuilder("spark.eventLog.includeTaskMetricsAccumulators") - .doc("Whether to include TaskMetrics' underlying accumulator values in the event log (as " + - "part of the Task/Stage/Job 
metrics' 'Accumulables' fields. This configuration defaults " + - "to false because the TaskMetrics values are already logged in the 'Task Metrics' " + - "fields (so the accumulator updates are redundant). This flag exists only as a " + - "backwards-compatibility escape hatch for applications that might rely on the old " + - "behavior. See SPARK-42204 for details.") + .doc("Whether to include TaskMetrics' underlying accumulator values in the event log " + + "(as part of the Task/Stage/Job metrics' 'Accumulables' fields. The TaskMetrics " + + "values are already logged in the 'Task Metrics' fields (so the accumulator updates " + + "are redundant). This flag defaults to true for behavioral backwards compatibility " + + "for applications that might rely on the redundant logging. " + + "See SPARK-42204 for details.") .version("4.0.0") .booleanConf - .createWithDefault(false) + .createWithDefault(true) private[spark] val EVENT_LOG_OVERWRITE = ConfigBuilder("spark.eventLog.overwrite") @@ -1386,7 +1386,6 @@ package object config { private[spark] val SHUFFLE_ACCURATE_BLOCK_SKEWED_FACTOR = ConfigBuilder("spark.shuffle.accurateBlockSkewedFactor") - .internal() .doc("A shuffle block is considered as skewed and will be accurately recorded in " + "HighlyCompressedMapStatus if its size is larger than this factor multiplying " + "the median shuffle block size or SHUFFLE_ACCURATE_BLOCK_THRESHOLD. It is " + diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 2c89fe7885d08..4f7338f74e298 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -174,7 +174,7 @@ private[spark] class DAGScheduler( // `NUM_CANCELLED_JOB_GROUPS_TO_TRACK` stored. On a new job submission, if its job group is in // this set, the job will be immediately cancelled. private[scheduler] val cancelledJobGroups = - new LimitedSizeFIFOSet[String](sc.getConf.get(config.NUM_CANCELLED_JOB_GROUPS_TO_TRACK)) + new LimitedSizeFIFOSet[String](sc.conf.get(config.NUM_CANCELLED_JOB_GROUPS_TO_TRACK)) /** * Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids @@ -224,9 +224,9 @@ private[spark] class DAGScheduler( private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ - private val disallowStageRetryForTest = sc.getConf.get(TEST_NO_STAGE_RETRY) + private val disallowStageRetryForTest = sc.conf.get(TEST_NO_STAGE_RETRY) - private val shouldMergeResourceProfiles = sc.getConf.get(config.RESOURCE_PROFILE_MERGE_CONFLICTS) + private val shouldMergeResourceProfiles = sc.conf.get(config.RESOURCE_PROFILE_MERGE_CONFLICTS) /** * Whether to unregister all the outputs on the host in condition that we receive a FetchFailure, @@ -234,19 +234,19 @@ private[spark] class DAGScheduler( * executor(instead of the host) on a FetchFailure. */ private[scheduler] val unRegisterOutputOnHostOnFetchFailure = - sc.getConf.get(config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE) + sc.conf.get(config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE) /** * Number of consecutive stage attempts allowed before a stage is aborted. */ private[scheduler] val maxConsecutiveStageAttempts = - sc.getConf.get(config.STAGE_MAX_CONSECUTIVE_ATTEMPTS) + sc.conf.get(config.STAGE_MAX_CONSECUTIVE_ATTEMPTS) /** * Max stage attempts allowed before a stage is aborted. 
*/ private[scheduler] val maxStageAttempts: Int = { - Math.max(maxConsecutiveStageAttempts, sc.getConf.get(config.STAGE_MAX_ATTEMPTS)) + Math.max(maxConsecutiveStageAttempts, sc.conf.get(config.STAGE_MAX_ATTEMPTS)) } /** @@ -254,7 +254,7 @@ private[spark] class DAGScheduler( * count spark.stage.maxConsecutiveAttempts */ private[scheduler] val ignoreDecommissionFetchFailure = - sc.getConf.get(config.STAGE_IGNORE_DECOMMISSION_FETCH_FAILURE) + sc.conf.get(config.STAGE_IGNORE_DECOMMISSION_FETCH_FAILURE) /** * Number of max concurrent tasks check failures for each barrier job. @@ -264,14 +264,14 @@ private[spark] class DAGScheduler( /** * Time in seconds to wait between a max concurrent tasks check failure and the next check. */ - private val timeIntervalNumTasksCheck = sc.getConf + private val timeIntervalNumTasksCheck = sc.conf .get(config.BARRIER_MAX_CONCURRENT_TASKS_CHECK_INTERVAL) /** * Max number of max concurrent tasks check failures allowed for a job before fail the job * submission. */ - private val maxFailureNumTasksCheck = sc.getConf + private val maxFailureNumTasksCheck = sc.conf .get(config.BARRIER_MAX_CONCURRENT_TASKS_CHECK_MAX_FAILURES) private val messageScheduler = @@ -286,26 +286,26 @@ private[spark] class DAGScheduler( taskScheduler.setDAGScheduler(this) - private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.getConf, isDriver = true) + private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.conf, isDriver = true) private val blockManagerMasterDriverHeartbeatTimeout = - sc.getConf.get(config.STORAGE_BLOCKMANAGER_MASTER_DRIVER_HEARTBEAT_TIMEOUT).millis + sc.conf.get(config.STORAGE_BLOCKMANAGER_MASTER_DRIVER_HEARTBEAT_TIMEOUT).millis private val shuffleMergeResultsTimeoutSec = - sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT) + sc.conf.get(config.PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT) private val shuffleMergeFinalizeWaitSec = - sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT) + sc.conf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT) private val shuffleMergeWaitMinSizeThreshold = - sc.getConf.get(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT) + sc.conf.get(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT) - private val shufflePushMinRatio = sc.getConf.get(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO) + private val shufflePushMinRatio = sc.conf.get(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO) private val shuffleMergeFinalizeNumThreads = - sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_THREADS) + sc.conf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_THREADS) - private val shuffleFinalizeRpcThreads = sc.getConf.get(config.PUSH_SHUFFLE_FINALIZE_RPC_THREADS) + private val shuffleFinalizeRpcThreads = sc.conf.get(config.PUSH_SHUFFLE_FINALIZE_RPC_THREADS) // Since SparkEnv gets initialized after DAGScheduler, externalShuffleClient needs to be // initialized lazily @@ -328,11 +328,10 @@ private[spark] class DAGScheduler( ThreadUtils.newDaemonFixedThreadPool(shuffleFinalizeRpcThreads, "shuffle-merge-finalize-rpc") /** Whether rdd cache visibility tracking is enabled. */ - private val trackingCacheVisibility: Boolean = - sc.getConf.get(RDD_CACHE_VISIBILITY_TRACKING_ENABLED) + private val trackingCacheVisibility: Boolean = sc.conf.get(RDD_CACHE_VISIBILITY_TRACKING_ENABLED) /** Whether to abort a stage after canceling all of its tasks. 
*/ - private val legacyAbortStageAfterKillTasks = sc.getConf.get(LEGACY_ABORT_STAGE_AFTER_KILL_TASKS) + private val legacyAbortStageAfterKillTasks = sc.conf.get(LEGACY_ABORT_STAGE_AFTER_KILL_TASKS) /** * Called by the TaskSetManager to report task's starting. @@ -557,7 +556,7 @@ private[spark] class DAGScheduler( * TODO SPARK-24942 Improve cluster resource management with jobs containing barrier stage */ private def checkBarrierStageWithDynamicAllocation(rdd: RDD[_]): Unit = { - if (rdd.isBarrier() && Utils.isDynamicAllocationEnabled(sc.getConf)) { + if (rdd.isBarrier() && Utils.isDynamicAllocationEnabled(sc.conf)) { throw SparkCoreErrors.barrierStageWithDynamicAllocationError() } } @@ -2163,7 +2162,7 @@ private[spark] class DAGScheduler( case mapStage: ShuffleMapStage => val numMissingPartitions = mapStage.findMissingPartitions().length if (numMissingPartitions < mapStage.numTasks) { - if (sc.getConf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) { + if (sc.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) { val reason = "A shuffle map stage with indeterminate output was failed " + "and retried. However, Spark can only do this while using the new " + "shuffle block fetching protocol. Please check the config " + @@ -2893,8 +2892,8 @@ private[spark] class DAGScheduler( val finalException = exception.collect { // If the error is user-facing (defines error class and is not internal error), we don't // wrap it with "Job aborted" and expose this error to the end users directly. - case st: Exception with SparkThrowable if st.getErrorClass != null && - !SparkThrowableHelper.isInternalError(st.getErrorClass) => + case st: Exception with SparkThrowable if st.getCondition != null && + !SparkThrowableHelper.isInternalError(st.getCondition) => st }.getOrElse { new SparkException(s"Job aborted due to stage failure: $reason", cause = exception.orNull) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a6d62005e4e66..9843066d14ee8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -2126,8 +2126,10 @@ private[spark] class BlockManager( hasRemoveBlock = true if (tellMaster) { // Only update storage level from the captured block status before deleting, so that - // memory size and disk size are being kept for calculating delta. - reportBlockStatus(blockId, blockStatus.get.copy(storageLevel = StorageLevel.NONE)) + // memory size and disk size are being kept for calculating delta. Reset the replica + // count 0 in storage level to notify that it is a remove operation. 
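+ // For example, a block cached as MEMORY_AND_DISK_2 would be reported here with its memory/disk + // flags kept but replication set to 0 (an illustrative reading of this change, see the test updates + // below); such a level is no longer valid, so the master can treat the update as a removal while + // still seeing which tiers held the block.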
+ val storageLevel = StorageLevel(blockStatus.get.storageLevel.toInt, 0) + reportBlockStatus(blockId, blockStatus.get.copy(storageLevel = storageLevel)) } } finally { if (!hasRemoveBlock) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 73f89ea0e86e5..fc4e6e771aad7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -1059,13 +1059,13 @@ private[spark] class BlockManagerInfo( _blocks.put(blockId, blockStatus) _remainingMem -= memSize if (blockExists) { - logInfo(log"Updated ${MDC(BLOCK_ID, blockId)} in memory on " + + logDebug(log"Updated ${MDC(BLOCK_ID, blockId)} in memory on " + log"${MDC(HOST_PORT, blockManagerId.hostPort)} (current size: " + log"${MDC(CURRENT_MEMORY_SIZE, Utils.bytesToString(memSize))}, original " + log"size: ${MDC(ORIGINAL_MEMORY_SIZE, Utils.bytesToString(originalMemSize))}, " + log"free: ${MDC(FREE_MEMORY_SIZE, Utils.bytesToString(_remainingMem))})") } else { - logInfo(log"Added ${MDC(BLOCK_ID, blockId)} in memory on " + + logDebug(log"Added ${MDC(BLOCK_ID, blockId)} in memory on " + log"${MDC(HOST_PORT, blockManagerId.hostPort)} " + log"(size: ${MDC(CURRENT_MEMORY_SIZE, Utils.bytesToString(memSize))}, " + log"free: ${MDC(FREE_MEMORY_SIZE, Utils.bytesToString(_remainingMem))})") @@ -1075,12 +1075,12 @@ private[spark] class BlockManagerInfo( blockStatus = BlockStatus(storageLevel, memSize = 0, diskSize = diskSize) _blocks.put(blockId, blockStatus) if (blockExists) { - logInfo(log"Updated ${MDC(BLOCK_ID, blockId)} on disk on " + + logDebug(log"Updated ${MDC(BLOCK_ID, blockId)} on disk on " + log"${MDC(HOST_PORT, blockManagerId.hostPort)} " + log"(current size: ${MDC(CURRENT_DISK_SIZE, Utils.bytesToString(diskSize))}," + log" original size: ${MDC(ORIGINAL_DISK_SIZE, Utils.bytesToString(originalDiskSize))})") } else { - logInfo(log"Added ${MDC(BLOCK_ID, blockId)} on disk on " + + logDebug(log"Added ${MDC(BLOCK_ID, blockId)} on disk on " + log"${MDC(HOST_PORT, blockManagerId.hostPort)} (size: " + log"${MDC(CURRENT_DISK_SIZE, Utils.bytesToString(diskSize))})") } @@ -1098,13 +1098,13 @@ private[spark] class BlockManagerInfo( blockStatus.remove(blockId) } if (originalLevel.useMemory) { - logInfo(log"Removed ${MDC(BLOCK_ID, blockId)} on " + + logDebug(log"Removed ${MDC(BLOCK_ID, blockId)} on " + log"${MDC(HOST_PORT, blockManagerId.hostPort)} in memory " + log"(size: ${MDC(ORIGINAL_MEMORY_SIZE, Utils.bytesToString(originalMemSize))}, " + log"free: ${MDC(FREE_MEMORY_SIZE, Utils.bytesToString(_remainingMem))})") } if (originalLevel.useDisk) { - logInfo(log"Removed ${MDC(BLOCK_ID, blockId)} on " + + logDebug(log"Removed ${MDC(BLOCK_ID, blockId)} on " + log"${MDC(HOST_PORT, blockManagerId.hostPort)} on disk" + log" (size: ${MDC(ORIGINAL_DISK_SIZE, Utils.bytesToString(originalDiskSize))})") } diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala index 7a2b7d9caec42..fc7a4675429aa 100644 --- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -35,7 +35,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { // Carriage return private val CR = '\r' // Update period of progress bar, in milliseconds - private val updatePeriodMSec = 
sc.getConf.get(UI_CONSOLE_PROGRESS_UPDATE_INTERVAL) + private val updatePeriodMSec = sc.conf.get(UI_CONSOLE_PROGRESS_UPDATE_INTERVAL) // Delay to show up a progress bar, in milliseconds private val firstDelayMSec = 500L diff --git a/core/src/main/scala/org/apache/spark/util/LazyTry.scala b/core/src/main/scala/org/apache/spark/util/LazyTry.scala new file mode 100644 index 0000000000000..7edc08672c26b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/LazyTry.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.util.Try + +/** + * Wrapper utility for a lazy val, with two differences compared to Scala behavior: + * + * 1. Non-retrying in case of failure. This wrapper stores the exception in a Try, and will re-throw + * it on access to `get`. + * In Scala, when a `lazy val` field initialization throws an exception, the field remains + * uninitialized, and initialization will be re-attempted on the next access. This can lead + * to performance issues, needlessly computing something towards a failure, and can also lead to + * duplicated side effects. + * + * 2. Resolving locking issues. + * In Scala, when a `lazy val` field is initialized, it grabs the synchronized lock on the + * enclosing object instance. This can lead both to performance issues and deadlocks. + * For example: + * a) Thread 1 entered a synchronized method, grabbing a coarse lock on the parent object. + * b) Thread 2 gets spawned off, and tries to initialize a lazy value on the same parent object. + * This causes Scala to also try to grab a lock on the parent object. + * c) If thread 1 waits for thread 2 to join, a deadlock occurs. + * This wrapper will only grab a lock on the wrapper itself, and not the parent object. + * + * @param initialize The block of code to initialize the lazy value. + * @tparam T type of the lazy value. + */ +private[spark] class LazyTry[T](initialize: => T) extends Serializable { + private lazy val tryT: Try[T] = Utils.doTryWithCallerStacktrace { initialize } + + /** + * Get the lazy value. If the initialization block threw an exception, it will be re-thrown here. + * The exception will be re-thrown with the current caller's stacktrace. + * An exception with stack trace from when the exception was first thrown can be accessed with + * ``` + * ex.getSuppressed.find { e => + * e.getMessage == org.apache.spark.util.Utils.TRY_WITH_CALLER_STACKTRACE_FULL_STACKTRACE + * } + * ``` + */ + def get: T = Utils.getTryWithCallerStacktrace(tryT) +} + +private[spark] object LazyTry { + /** + * Create a new LazyTry instance. + * + * @param initialize The block of code to initialize the lazy value. + * @tparam T type of the lazy value. + * @return a new LazyTry instance. 
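+ * + * A minimal usage sketch (`loadExpensiveConfig` below is only a stand-in for any costly or + * failure-prone initialization, not an existing API): + * {{{ + * val config = LazyTry { loadExpensiveConfig() } + * config.get // runs the block once; a failure is cached and re-thrown on every later access + * }}}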
+ */ + def apply[T](initialize: => T): LazyTry[T] = new LazyTry(initialize) +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d8392cd8043de..5703128aacbb9 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1351,6 +1351,86 @@ private[spark] object Utils } } + val TRY_WITH_CALLER_STACKTRACE_FULL_STACKTRACE = + "Full stacktrace of original doTryWithCallerStacktrace caller" + + val TRY_WITH_CALLER_STACKTRACE_TRY_STACKTRACE = + "Stacktrace under doTryWithCallerStacktrace" + + /** + * Use Try with stacktrace substitution for the caller retrieving the error. + * + * Normally in case of failure, the exception would have the stacktrace of the caller that + * originally called doTryWithCallerStacktrace. However, we want to replace the part above + * this function with the stacktrace of the caller who calls getTryWithCallerStacktrace. + * So here we save the part of the stacktrace below doTryWithCallerStacktrace, and + * getTryWithCallerStacktrace will stitch it with the new stack trace of the caller. + * The full original stack trace is kept in ex.getSuppressed. + * + * @param f Code block to be wrapped in Try + * @return Try with Success or Failure of the code block. Use with getTryWithCallerStacktrace. + */ + def doTryWithCallerStacktrace[T](f: => T): Try[T] = { + val t = Try { + f + } + t match { + case Failure(ex) => + // Note: we remove the common suffix instead of e.g. finding the call to this function, to + // account for recursive calls with multiple doTryWithCallerStacktrace on the stack trace. + val origStackTrace = ex.getStackTrace + val currentStackTrace = Thread.currentThread().getStackTrace + val commonSuffixLen = origStackTrace.reverse.zip(currentStackTrace.reverse).takeWhile { + case (exElem, currentElem) => exElem == currentElem + }.length + val belowEx = new Exception(TRY_WITH_CALLER_STACKTRACE_TRY_STACKTRACE) + belowEx.setStackTrace(origStackTrace.dropRight(commonSuffixLen)) + ex.addSuppressed(belowEx) + + // keep the full original stack trace in a suppressed exception. + val fullEx = new Exception(TRY_WITH_CALLER_STACKTRACE_FULL_STACKTRACE) + fullEx.setStackTrace(origStackTrace) + ex.addSuppressed(fullEx) + case Success(_) => // nothing + } + t + } + + /** + * Retrieve the result of Try that was created by doTryWithCallerStacktrace. + * + * In case of failure, the resulting exception has a stack trace that combines the stack trace + * below the original doTryWithCallerStacktrace which triggered it, with the caller stack trace + * of the current caller of getTryWithCallerStacktrace. + * + * Full stack trace of the original doTryWithCallerStacktrace caller can be retrieved with + * ``` + * ex.getSuppressed.find { e => + * e.getMessage == Utils.TRY_WITH_CALLER_STACKTRACE_FULL_STACKTRACE + * } + * ``` + * + * + * @param t Try from doTryWithCallerStacktrace + * @return Result of the Try or rethrows the failure exception with modified stacktrace. + */ + def getTryWithCallerStacktrace[T](t: Try[T]): T = t match { + case Failure(ex) => + val belowStacktrace = ex.getSuppressed.find { e => + // added in doTryWithCallerStacktrace + e.getMessage == TRY_WITH_CALLER_STACKTRACE_TRY_STACKTRACE + }.getOrElse { + // If we don't have the expected stacktrace information, just rethrow + throw ex + }.getStackTrace + // We are modifying and throwing the original exception. 
It would be better if we could + // return a copy, but we can't easily clone it and preserve its type and suppressed exceptions. If this is accessed from + // multiple threads that then look at the stack trace, this could break. + ex.setStackTrace(belowStacktrace ++ Thread.currentThread().getStackTrace.drop(1)) + throw ex + case Success(s) => s + } + // A regular expression to match classes of the internal Spark API's // that we want to skip when finding the call site of a method. private val SPARK_CORE_CLASS_REGEX = @@ -2593,6 +2673,19 @@ private[spark] object Utils } } + /** + * Utility function to enable or disable structured logging based on SparkConf. + * This is designed for code paths where the logging system may be initialized before + * loading SparkConf. + */ + def resetStructuredLogging(sparkConf: SparkConf): Unit = { + if (sparkConf.getBoolean(STRUCTURED_LOGGING_ENABLED.key, defaultValue = true)) { + Logging.enableStructuredLogging() + } else { + Logging.disableStructuredLogging() + } + } + /** * Return the jar files pointed by the "spark.jars" property. Spark internally will distribute * these jars through file server. In the YARN mode, it will return an empty list, since YARN diff --git a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala index 46e311d8b0476..ec43666898fa7 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala @@ -208,7 +208,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) /** * Re-hash a value to deal better with hash functions that don't differ in the lower bits. */ - private def rehash(h: Int): Int = Hashing.murmur3_32().hashInt(h).asInt() + private def rehash(h: Int): Int = Hashing.murmur3_32_fixed().hashInt(h).asInt() /** Double the table's size and re-hash everything */ protected def growTable(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index a42fa9ba6bc85..3d1eb5788c707 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -266,7 +266,7 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( /** * Re-hash a value to deal better with hash functions that don't differ in the lower bits. 
*/ - private def hashcode(h: Int): Int = Hashing.murmur3_32().hashInt(h).asInt() + private def hashcode(h: Int): Int = Hashing.murmur3_32_fixed().hashInt(h).asInt() private def nextPowerOf2(n: Int): Int = { if (n == 0) { diff --git a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java index 11bd2b2a3312c..802cb2667cc88 100644 --- a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java @@ -960,7 +960,7 @@ public void textFiles() throws IOException { rdd.saveAsTextFile(outputDir); // Read the plain text file and check it's OK File outputFile = new File(outputDir, "part-00000"); - String content = Files.toString(outputFile, StandardCharsets.UTF_8); + String content = Files.asCharSource(outputFile, StandardCharsets.UTF_8).read(); assertEquals("1\n2\n3\n4\n", content); // Also try reading it in as a text file RDD List expected = Arrays.asList("1", "2", "3", "4"); diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 5651dc9b2dbdc..5f9912cbd021d 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -334,8 +334,8 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { for (i <- 0 until 8) { val tempFile = new File(tempDir, s"part-0000$i") - Files.write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1", tempFile, - StandardCharsets.UTF_8) + Files.asCharSink(tempFile, StandardCharsets.UTF_8) + .write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1") } for (p <- Seq(1, 2, 8)) { diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 380231ce97c0b..ca51e61f5ed44 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -288,7 +288,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft sem.acquire(1) sc.cancelJobGroupAndFutureJobs(s"job-group-$idx") ThreadUtils.awaitReady(job, Duration.Inf).failed.foreach { case e: SparkException => - assert(e.getErrorClass == "SPARK_JOB_CANCELLED") + assert(e.getCondition == "SPARK_JOB_CANCELLED") } } // submit a job with the 0 job group that was evicted from cancelledJobGroups set, it should run diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 12f9d2f83c777..44b2da603a1f6 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -119,8 +119,8 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu val absolutePath2 = file2.getAbsolutePath try { - Files.write("somewords1", file1, StandardCharsets.UTF_8) - Files.write("somewords2", file2, StandardCharsets.UTF_8) + Files.asCharSink(file1, StandardCharsets.UTF_8).write("somewords1") + Files.asCharSink(file2, StandardCharsets.UTF_8).write("somewords2") val length1 = file1.length() val length2 = file2.length() @@ -178,10 +178,10 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu s"${jarFile.getParent}/../${jarFile.getParentFile.getName}/${jarFile.getName}#zoo" try { - Files.write("somewords1", file1, StandardCharsets.UTF_8) - Files.write("somewords22", file2, 
StandardCharsets.UTF_8) - Files.write("somewords333", file3, StandardCharsets.UTF_8) - Files.write("somewords4444", file4, StandardCharsets.UTF_8) + Files.asCharSink(file1, StandardCharsets.UTF_8).write("somewords1") + Files.asCharSink(file2, StandardCharsets.UTF_8).write("somewords22") + Files.asCharSink(file3, StandardCharsets.UTF_8).write("somewords333") + Files.asCharSink(file4, StandardCharsets.UTF_8).write("somewords4444") val length1 = file1.length() val length2 = file2.length() val length3 = file1.length() @@ -373,8 +373,8 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu assert(subdir2.mkdir()) val file1 = new File(subdir1, "file") val file2 = new File(subdir2, "file") - Files.write("old", file1, StandardCharsets.UTF_8) - Files.write("new", file2, StandardCharsets.UTF_8) + Files.asCharSink(file1, StandardCharsets.UTF_8).write("old") + Files.asCharSink(file2, StandardCharsets.UTF_8).write("new") sc = new SparkContext("local-cluster[1,1,1024]", "test") sc.addFile(file1.getAbsolutePath) def getAddedFileContents(): String = { @@ -503,12 +503,15 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu try { // Create 5 text files. - Files.write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1", file1, - StandardCharsets.UTF_8) - Files.write("someline1 in file2\nsomeline2 in file2", file2, StandardCharsets.UTF_8) - Files.write("someline1 in file3", file3, StandardCharsets.UTF_8) - Files.write("someline1 in file4\nsomeline2 in file4", file4, StandardCharsets.UTF_8) - Files.write("someline1 in file2\nsomeline2 in file5", file5, StandardCharsets.UTF_8) + Files.asCharSink(file1, StandardCharsets.UTF_8) + .write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1") + Files.asCharSink(file2, StandardCharsets.UTF_8) + .write("someline1 in file2\nsomeline2 in file2") + Files.asCharSink(file3, StandardCharsets.UTF_8).write("someline1 in file3") + Files.asCharSink(file4, StandardCharsets.UTF_8) + .write("someline1 in file4\nsomeline2 in file4") + Files.asCharSink(file5, StandardCharsets.UTF_8) + .write("someline1 in file2\nsomeline2 in file5") sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 9f310c06ac5ae..e38efc27b78f9 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -343,7 +343,7 @@ abstract class SparkFunSuite parameters: Map[String, String] = Map.empty, matchPVals: Boolean = false, queryContext: Array[ExpectedContext] = Array.empty): Unit = { - assert(exception.getErrorClass === condition) + assert(exception.getCondition === condition) sqlState.foreach(state => assert(exception.getSqlState === state)) val expectedParameters = exception.getMessageParameters.asScala if (matchPVals) { diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala index 946ea75686e32..ea845c0f93a4b 100644 --- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala @@ -47,7 +47,7 @@ class SparkThrowableSuite extends SparkFunSuite { }}} */ private val regenerateCommand = "SPARK_GENERATE_GOLDEN_FILES=1 build/sbt " + - "\"core/testOnly *SparkThrowableSuite -- -t \\\"Error classes match with document\\\"\"" + "\"core/testOnly 
*SparkThrowableSuite -- -t \\\"Error conditions are correctly formatted\\\"\"" private val errorJsonFilePath = getWorkspaceFilePath( "common", "utils", "src", "main", "resources", "error", "error-conditions.json") @@ -199,7 +199,7 @@ class SparkThrowableSuite extends SparkFunSuite { val e = intercept[SparkException] { getMessage("UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", Map.empty[String, String]) } - assert(e.getErrorClass === "INTERNAL_ERROR") + assert(e.getCondition === "INTERNAL_ERROR") assert(e.getMessageParameters().get("message").contains("Undefined error message parameter")) } @@ -245,7 +245,7 @@ class SparkThrowableSuite extends SparkFunSuite { throw new SparkException("Arbitrary legacy message") } catch { case e: SparkThrowable => - assert(e.getErrorClass == null) + assert(e.getCondition == null) assert(!e.isInternalError) assert(e.getSqlState == null) case _: Throwable => @@ -262,7 +262,7 @@ class SparkThrowableSuite extends SparkFunSuite { cause = null) } catch { case e: SparkThrowable => - assert(e.getErrorClass == "CANNOT_PARSE_DECIMAL") + assert(e.getCondition == "CANNOT_PARSE_DECIMAL") assert(!e.isInternalError) assert(e.getSqlState == "22018") case _: Throwable => @@ -357,7 +357,7 @@ class SparkThrowableSuite extends SparkFunSuite { |}""".stripMargin) // Legacy mode when an exception does not have any error class class LegacyException extends Throwable with SparkThrowable { - override def getErrorClass: String = null + override def getCondition: String = null override def getMessage: String = "Test message" } val e3 = new LegacyException @@ -452,7 +452,7 @@ class SparkThrowableSuite extends SparkFunSuite { val e = intercept[SparkException] { new ErrorClassesJsonReader(Seq(errorJsonFilePath.toUri.toURL, json.toURI.toURL)) } - assert(e.getErrorClass === "INTERNAL_ERROR") + assert(e.getCondition === "INTERNAL_ERROR") assert(e.getMessage.contains("DIVIDE.BY_ZERO")) } @@ -478,7 +478,7 @@ class SparkThrowableSuite extends SparkFunSuite { val e = intercept[SparkException] { new ErrorClassesJsonReader(Seq(errorJsonFilePath.toUri.toURL, json.toURI.toURL)) } - assert(e.getErrorClass === "INTERNAL_ERROR") + assert(e.getCondition === "INTERNAL_ERROR") assert(e.getMessage.contains("BY.ZERO")) } } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 1efef3383b821..b0f36b9744fa8 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -317,13 +317,13 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio // Instead, crash the driver by directly accessing the broadcast value. 
val e1 = intercept[SparkException] { broadcast.value } assert(e1.isInternalError) - assert(e1.getErrorClass == "INTERNAL_ERROR_BROADCAST") + assert(e1.getCondition == "INTERNAL_ERROR_BROADCAST") val e2 = intercept[SparkException] { broadcast.unpersist(blocking = true) } assert(e2.isInternalError) - assert(e2.getErrorClass == "INTERNAL_ERROR_BROADCAST") + assert(e2.getCondition == "INTERNAL_ERROR_BROADCAST") val e3 = intercept[SparkException] { broadcast.destroy(blocking = true) } assert(e3.isInternalError) - assert(e3.getErrorClass == "INTERNAL_ERROR_BROADCAST") + assert(e3.getCondition == "INTERNAL_ERROR_BROADCAST") } else { val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum)) assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet) @@ -339,7 +339,7 @@ package object testPackage extends Assertions { val thrown = intercept[SparkException] { broadcast.value } assert(thrown.getMessage.contains("BroadcastSuite.scala")) assert(thrown.isInternalError) - assert(thrown.getErrorClass == "INTERNAL_ERROR_BROADCAST") + assert(thrown.getCondition == "INTERNAL_ERROR_BROADCAST") } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileReadersSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileReadersSuite.scala index f34f792881f90..7501a98a1a573 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileReadersSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileReadersSuite.scala @@ -221,7 +221,7 @@ class SingleFileEventLogFileReaderSuite extends EventLogFileReadersSuite { val entry = is.getNextEntry assert(entry != null) val actual = new String(ByteStreams.toByteArray(is), StandardCharsets.UTF_8) - val expected = Files.toString(new File(logPath.toString), StandardCharsets.UTF_8) + val expected = Files.asCharSource(new File(logPath.toString), StandardCharsets.UTF_8).read() assert(actual === expected) assert(is.getNextEntry === null) } @@ -368,8 +368,8 @@ class RollingEventLogFilesReaderSuite extends EventLogFileReadersSuite { assert(allFileNames.contains(fileName)) val actual = new String(ByteStreams.toByteArray(is), StandardCharsets.UTF_8) - val expected = Files.toString(new File(logPath.toString, fileName), - StandardCharsets.UTF_8) + val expected = Files.asCharSource( + new File(logPath.toString, fileName), StandardCharsets.UTF_8).read() assert(actual === expected) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 3013a5bf4a294..852f94bda870d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -708,7 +708,8 @@ abstract class FsHistoryProviderSuite extends SparkFunSuite with Matchers with P while (entry != null) { val actual = new String(ByteStreams.toByteArray(inputStream), StandardCharsets.UTF_8) val expected = - Files.toString(logs.find(_.getName == entry.getName).get, StandardCharsets.UTF_8) + Files.asCharSource(logs.find(_.getName == entry.getName).get, StandardCharsets.UTF_8) + .read() actual should be (expected) totalEntries += 1 entry = inputStream.getNextEntry diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala index 
2b9b110a41424..807e5ec3e823e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala @@ -45,8 +45,8 @@ class HistoryServerArgumentsSuite extends SparkFunSuite { test("Properties File Arguments Parsing --properties-file") { withTempDir { tmpDir => val outFile = File.createTempFile("test-load-spark-properties", "test", tmpDir) - Files.write("spark.test.CustomPropertyA blah\n" + - "spark.test.CustomPropertyB notblah\n", outFile, UTF_8) + Files.asCharSink(outFile, UTF_8).write("spark.test.CustomPropertyA blah\n" + + "spark.test.CustomPropertyB notblah\n") val argStrings = Array("--properties-file", outFile.getAbsolutePath) val hsa = new HistoryServerArguments(conf, argStrings) assert(conf.get("spark.test.CustomPropertyA") === "blah") diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index abb5ae720af07..6b2bd90cd4314 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -283,7 +283,7 @@ abstract class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with val expectedFile = { new File(logDir, entry.getName) } - val expected = Files.toString(expectedFile, StandardCharsets.UTF_8) + val expected = Files.asCharSource(expectedFile, StandardCharsets.UTF_8).read() val actual = new String(ByteStreams.toByteArray(zipStream), StandardCharsets.UTF_8) actual should be (expected) filesCompared += 1 diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 9eb5172583120..f2807f258f2d1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -87,8 +87,6 @@ class SubmitRestProtocolSuite extends SparkFunSuite { message.clientSparkVersion = "1.2.3" message.appResource = "honey-walnut-cherry.jar" message.mainClass = "org.apache.spark.examples.SparkPie" - message.appArgs = Array("two slices") - message.environmentVariables = Map("PATH" -> "/dev/null") val conf = new SparkConf(false) conf.set("spark.app.name", "SparkPie") message.sparkProperties = conf.getAll.toMap diff --git a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala index 7bd84c810c42e..8b98df103c014 100644 --- a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala @@ -37,7 +37,7 @@ import org.scalatestplus.mockito.MockitoSugar import org.apache.spark._ import org.apache.spark.TestUtils._ import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin} -import org.apache.spark.internal.config.PLUGINS +import org.apache.spark.internal.config.{EXECUTOR_MEMORY, PLUGINS} import org.apache.spark.resource._ import org.apache.spark.resource.ResourceUtils._ import org.apache.spark.resource.TestResourceIDs._ @@ -581,7 +581,8 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite */ test("SPARK-40320 Executor should exit when initialization failed for fatal 
error") { val conf = createSparkConf() - .setMaster("local-cluster[1, 1, 1024]") + .setMaster("local-cluster[1, 1, 512]") + .set(EXECUTOR_MEMORY.key, "512m") .set(PLUGINS, Seq(classOf[TestFatalErrorPlugin].getName)) .setAppName("test") sc = new SparkContext(conf) @@ -599,7 +600,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite } try { sc.addSparkListener(listener) - eventually(timeout(15.seconds)) { + eventually(timeout(30.seconds)) { assert(executorAddCounter.get() >= 2) assert(executorRemovedCounter.get() >= 2) } diff --git a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala index 79fa8d21bf3f1..fc8f48df2cb7d 100644 --- a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala @@ -383,7 +383,7 @@ object NonLocalModeSparkPlugin { resources: Map[String, ResourceInformation]): Unit = { val path = conf.get(TEST_PATH_CONF) val strToWrite = createFileStringWithGpuAddrs(id, resources) - Files.write(strToWrite, new File(path, s"$filePrefix$id"), StandardCharsets.UTF_8) + Files.asCharSink(new File(path, s"$filePrefix$id"), StandardCharsets.UTF_8).write(strToWrite) } def reset(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/resource/ResourceDiscoveryPluginSuite.scala b/core/src/test/scala/org/apache/spark/resource/ResourceDiscoveryPluginSuite.scala index ff7d680352177..edf138df9e207 100644 --- a/core/src/test/scala/org/apache/spark/resource/ResourceDiscoveryPluginSuite.scala +++ b/core/src/test/scala/org/apache/spark/resource/ResourceDiscoveryPluginSuite.scala @@ -148,7 +148,7 @@ object TestResourceDiscoveryPlugin { def writeFile(conf: SparkConf, id: String): Unit = { val path = conf.get(TEST_PATH_CONF) val fileName = s"$id - ${UUID.randomUUID.toString}" - Files.write(id, new File(path, fileName), StandardCharsets.UTF_8) + Files.asCharSink(new File(path, fileName), StandardCharsets.UTF_8).write(id) } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 3ef382573517b..66b1ee7b58ac8 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -868,23 +868,23 @@ abstract class RpcEnvSuite extends SparkFunSuite { val conf = createSparkConf() val file = new File(tempDir, "file") - Files.write(UUID.randomUUID().toString(), file, UTF_8) + Files.asCharSink(file, UTF_8).write(UUID.randomUUID().toString) val fileWithSpecialChars = new File(tempDir, "file name") - Files.write(UUID.randomUUID().toString(), fileWithSpecialChars, UTF_8) + Files.asCharSink(fileWithSpecialChars, UTF_8).write(UUID.randomUUID().toString) val empty = new File(tempDir, "empty") - Files.write("", empty, UTF_8); + Files.asCharSink(empty, UTF_8).write("") val jar = new File(tempDir, "jar") - Files.write(UUID.randomUUID().toString(), jar, UTF_8) + Files.asCharSink(jar, UTF_8).write(UUID.randomUUID().toString) val dir1 = new File(tempDir, "dir1") assert(dir1.mkdir()) val subFile1 = new File(dir1, "file1") - Files.write(UUID.randomUUID().toString(), subFile1, UTF_8) + Files.asCharSink(subFile1, UTF_8).write(UUID.randomUUID().toString) val dir2 = new File(tempDir, "dir2") assert(dir2.mkdir()) val subFile2 = new File(dir2, "file2") - Files.write(UUID.randomUUID().toString(), subFile2, UTF_8) + Files.asCharSink(subFile2, 
UTF_8).write(UUID.randomUUID().toString) val fileUri = env.fileServer.addFile(file) val fileWithSpecialCharsUri = env.fileServer.addFile(fileWithSpecialChars) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 1fbc900727c4c..f5fca56e5ef77 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -38,6 +38,8 @@ import org.apache.spark.internal.config.Tests._ import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService +import org.apache.spark.network.shuffle.ExternalBlockStoreClient +import org.apache.spark.network.util.{MapConfigProvider, TransportConf} import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{KryoSerializer, SerializerManager} @@ -296,6 +298,41 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite } } + test("Test block location after replication with SHUFFLE_SERVICE_FETCH_RDD_ENABLED enabled") { + val newConf = conf.clone() + newConf.set(SHUFFLE_SERVICE_ENABLED, true) + newConf.set(SHUFFLE_SERVICE_FETCH_RDD_ENABLED, true) + newConf.set(Tests.TEST_SKIP_ESS_REGISTER, true) + val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]() + val shuffleClient = Some(new ExternalBlockStoreClient( + new TransportConf("shuffle", MapConfigProvider.EMPTY), + null, false, 5000)) + master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager-2", + new BlockManagerMasterEndpoint(rpcEnv, true, newConf, + new LiveListenerBus(newConf), shuffleClient, blockManagerInfo, mapOutputTracker, + sc.env.shuffleManager, isDriver = true)), + rpcEnv.setupEndpoint("blockmanagerHeartbeat-2", + new BlockManagerMasterHeartbeatEndpoint(rpcEnv, true, blockManagerInfo)), newConf, true) + + val shuffleServicePort = newConf.get(SHUFFLE_SERVICE_PORT) + val store1 = makeBlockManager(10000, "host-1") + val store2 = makeBlockManager(10000, "host-2") + assert(master.getPeers(store1.blockManagerId).toSet === Set(store2.blockManagerId)) + + val blockId = RDDBlockId(1, 2) + val message = new Array[Byte](1000) + + // if SHUFFLE_SERVICE_FETCH_RDD_ENABLED is enabled, then shuffle port should be present. + store1.putSingle(blockId, message, StorageLevel.DISK_ONLY) + assert(master.getLocations(blockId).contains( + BlockManagerId("host-1", "localhost", shuffleServicePort, None))) + + // after block is removed, shuffle port should be removed. 
+ store1.removeBlock(blockId, true) + assert(!master.getLocations(blockId).contains( + BlockManagerId("host-1", "localhost", shuffleServicePort, None))) + } + test("block replication - addition and deletion of block managers") { val blockSize = 1000 val storeSize = 10000 diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 17dff20dd993b..9fbe15402c8b3 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -33,7 +33,7 @@ import scala.reflect.classTag import com.esotericsoftware.kryo.KryoException import org.mockito.{ArgumentCaptor, ArgumentMatchers => mc} -import org.mockito.Mockito.{doAnswer, mock, never, spy, times, verify, when} +import org.mockito.Mockito.{atLeastOnce, doAnswer, mock, never, spy, times, verify, when} import org.scalatest.PrivateMethodTester import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.Eventually._ @@ -698,7 +698,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe removedFromMemory: Boolean, removedFromDisk: Boolean): Unit = { def assertSizeReported(captor: ArgumentCaptor[Long], expectRemoved: Boolean): Unit = { - assert(captor.getAllValues().size() === 1) + assert(captor.getAllValues().size() >= 1) if (expectRemoved) { assert(captor.getValue() > 0) } else { @@ -708,15 +708,18 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe val memSizeCaptor = ArgumentCaptor.forClass(classOf[Long]).asInstanceOf[ArgumentCaptor[Long]] val diskSizeCaptor = ArgumentCaptor.forClass(classOf[Long]).asInstanceOf[ArgumentCaptor[Long]] - verify(master).updateBlockInfo(mc.eq(store.blockManagerId), mc.eq(blockId), - mc.eq(StorageLevel.NONE), memSizeCaptor.capture(), diskSizeCaptor.capture()) + val storageLevelCaptor = + ArgumentCaptor.forClass(classOf[StorageLevel]).asInstanceOf[ArgumentCaptor[StorageLevel]] + verify(master, atLeastOnce()).updateBlockInfo(mc.eq(store.blockManagerId), mc.eq(blockId), + storageLevelCaptor.capture(), memSizeCaptor.capture(), diskSizeCaptor.capture()) assertSizeReported(memSizeCaptor, removedFromMemory) assertSizeReported(diskSizeCaptor, removedFromDisk) + assert(storageLevelCaptor.getValue.replication == 0) } private def assertUpdateBlockInfoNotReported(store: BlockManager, blockId: BlockId): Unit = { verify(master, never()).updateBlockInfo(mc.eq(store.blockManagerId), mc.eq(blockId), - mc.eq(StorageLevel.NONE), mc.anyInt(), mc.anyInt()) + mc.any[StorageLevel](), mc.anyInt(), mc.anyInt()) } test("reregistration on heart beat") { diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 35ef0587b9b4c..4497ea1b2b798 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -54,11 +54,11 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter { val inputStream = new ByteArrayInputStream(testString.getBytes(StandardCharsets.UTF_8)) // The `header` should not be covered val header = "Add header" - Files.write(header, testFile, StandardCharsets.UTF_8) + Files.asCharSink(testFile, StandardCharsets.UTF_8).write(header) val appender = new FileAppender(inputStream, testFile) inputStream.close() appender.awaitTermination() - 
assert(Files.toString(testFile, StandardCharsets.UTF_8) === header + testString) + assert(Files.asCharSource(testFile, StandardCharsets.UTF_8).read() === header + testString) } test("SPARK-35027: basic file appender - close stream") { @@ -392,7 +392,7 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter { IOUtils.closeQuietly(inputStream) } } else { - Files.toString(file, StandardCharsets.UTF_8) + Files.asCharSource(file, StandardCharsets.UTF_8).read() } }.mkString("") assert(allText === expectedText) diff --git a/core/src/test/scala/org/apache/spark/util/LazyTrySuite.scala b/core/src/test/scala/org/apache/spark/util/LazyTrySuite.scala new file mode 100644 index 0000000000000..79c07f8fbfead --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/LazyTrySuite.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +import org.apache.spark.SparkFunSuite + +class LazyTrySuite extends SparkFunSuite{ + test("LazyTry should initialize only once") { + var count = 0 + val lazyVal = LazyTry { + count += 1 + count + } + assert(count == 0) + assert(lazyVal.get == 1) + assert(count == 1) + assert(lazyVal.get == 1) + assert(count == 1) + } + + test("LazyTry should re-throw exceptions") { + val lazyVal = LazyTry { + throw new RuntimeException("test") + } + intercept[RuntimeException] { + lazyVal.get + } + intercept[RuntimeException] { + lazyVal.get + } + } + + test("LazyTry should re-throw exceptions with current caller stack-trace") { + val fileName = Thread.currentThread().getStackTrace()(1).getFileName + val lineNo = Thread.currentThread().getStackTrace()(1).getLineNumber + val lazyVal = LazyTry { + throw new RuntimeException("test") + } + + val e1 = intercept[RuntimeException] { + lazyVal.get // lineNo + 6 + } + assert(e1.getStackTrace + .exists(elem => elem.getFileName == fileName && elem.getLineNumber == lineNo + 6)) + + val e2 = intercept[RuntimeException] { + lazyVal.get // lineNo + 12 + } + assert(e2.getStackTrace + .exists(elem => elem.getFileName == fileName && elem.getLineNumber == lineNo + 12)) + } + + test("LazyTry does not lock containing object") { + class LazyContainer() { + @volatile var aSet = 0 + + val a: LazyTry[Int] = LazyTry { + aSet = 1 + aSet + } + + val b: LazyTry[Int] = LazyTry { + val t = new Thread(new Runnable { + override def run(): Unit = { + assert(a.get == 1) + } + }) + t.start() + t.join() + aSet + } + } + val container = new LazyContainer() + // Nothing is lazy initialized yet + assert(container.aSet == 0) + // This will not deadlock, thread t will initialize a, and update aSet + assert(container.b.get == 1) + assert(container.aSet == 1) + } + + // Scala lazy val tests are added to test for potential changes in the semantics of scala lazy val + + 
test("Scala lazy val initializing multiple times on error") { + class LazyValError() { + var counter = 0 + lazy val a = { + counter += 1 + throw new RuntimeException("test") + } + } + val lazyValError = new LazyValError() + intercept[RuntimeException] { + lazyValError.a + } + assert(lazyValError.counter == 1) + intercept[RuntimeException] { + lazyValError.a + } + assert(lazyValError.counter == 2) + } + + test("Scala lazy val locking containing object and deadlocking") { + // Note: this will change in scala 3, with different lazy vals not deadlocking with each other. + // https://docs.scala-lang.org/scala3/reference/changed-features/lazy-vals-init.html + class LazyValContainer() { + @volatile var aSet = 0 + @volatile var t: Thread = _ + + lazy val a = { + aSet = 1 + aSet + } + + lazy val b = { + t = new Thread(new Runnable { + override def run(): Unit = { + assert(a == 1) + } + }) + t.start() + t.join(1000) + aSet + } + } + val container = new LazyValContainer() + // Nothing is lazy initialized yet + assert(container.aSet == 0) + // This will deadlock, because b will take monitor on LazyValContainer, and then thread t + // will wait on that monitor, not able to initialize a. + // b will therefore see aSet == 0. + assert(container.b == 0) + // However, after b finishes initializing, the monitor will be released, and then thread t + // will finish initializing a, and set aSet to 1. + container.t.join() + assert(container.aSet == 1) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 4fe6fcf17f49f..a6e3345fc600c 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -28,7 +28,7 @@ import java.util.concurrent.TimeUnit import java.util.zip.GZIPOutputStream import scala.collection.mutable.ListBuffer -import scala.util.Random +import scala.util.{Random, Try} import com.google.common.io.Files import org.apache.commons.io.IOUtils @@ -735,8 +735,8 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties { withTempDir { tmpDir => val outFile = File.createTempFile("test-load-spark-properties", "test", tmpDir) System.setProperty("spark.test.fileNameLoadB", "2") - Files.write("spark.test.fileNameLoadA true\n" + - "spark.test.fileNameLoadB 1\n", outFile, UTF_8) + Files.asCharSink(outFile, UTF_8).write("spark.test.fileNameLoadA true\n" + + "spark.test.fileNameLoadB 1\n") val properties = Utils.getPropertiesFromFile(outFile.getAbsolutePath) properties .filter { case (k, v) => k.startsWith("spark.")} @@ -765,7 +765,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties { val innerSourceDir = Utils.createTempDir(root = sourceDir.getPath) val sourceFile = File.createTempFile("someprefix", "somesuffix", innerSourceDir) val targetDir = new File(tempDir, "target-dir") - Files.write("some text", sourceFile, UTF_8) + Files.asCharSink(sourceFile, UTF_8).write("some text") val path = if (Utils.isWindows) { @@ -1523,6 +1523,116 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties { conf.set(SERIALIZER, "org.apache.spark.serializer.JavaSerializer") assert(Utils.isPushBasedShuffleEnabled(conf, isDriver = true) === false) } + + + private def throwException(): String = { + throw new Exception("test") + } + + private def callDoTry(): Try[String] = { + Utils.doTryWithCallerStacktrace { + throwException() + } + } + + private def callGetTry(t: Try[String]): String = { + Utils.getTryWithCallerStacktrace(t) + 
} + + private def callGetTryAgain(t: Try[String]): String = { + Utils.getTryWithCallerStacktrace(t) + } + + test("doTryWithCallerStacktrace and getTryWithCallerStacktrace") { + val t = callDoTry() + + val e1 = intercept[Exception] { + callGetTry(t) + } + // Uncomment for manual inspection + // e1.printStackTrace() + // Example: + // java.lang.Exception: test + // at org.apache.spark.util.UtilsSuite.throwException(UtilsSuite.scala:1640) + // at org.apache.spark.util.UtilsSuite.$anonfun$callDoTry$1(UtilsSuite.scala:1645) + // at scala.util.Try$.apply(Try.scala:213) + // at org.apache.spark.util.Utils$.doTryWithCallerStacktrace(Utils.scala:1586) + // at org.apache.spark.util.Utils$.getTryWithCallerStacktrace(Utils.scala:1639) + // at org.apache.spark.util.UtilsSuite.callGetTry(UtilsSuite.scala:1650) + // at org.apache.spark.util.UtilsSuite.$anonfun$new$165(UtilsSuite.scala:1661) + // <- callGetTry is seen as calling getTryWithCallerStacktrace + + val st1 = e1.getStackTrace + // throwException should be on the stack trace + assert(st1.exists(_.getMethodName == "throwException")) + // callDoTry shouldn't be on the stack trace, but callGetTry should be. + assert(!st1.exists(_.getMethodName == "callDoTry")) + assert(st1.exists(_.getMethodName == "callGetTry")) + + // The original stack trace with callDoTry should be in the suppressed exceptions. + // Example: + // scalastyle:off line.size.limit + // Suppressed: java.lang.Exception: Full stacktrace of original doTryWithCallerStacktrace caller + // at org.apache.spark.util.UtilsSuite.throwException(UtilsSuite.scala:1640) + // at org.apache.spark.util.UtilsSuite.$anonfun$callDoTry$1(UtilsSuite.scala:1645) + // at scala.util.Try$.apply(Try.scala:213) + // at org.apache.spark.util.Utils$.doTryWithCallerStacktrace(Utils.scala:1586) + // at org.apache.spark.util.UtilsSuite.callDoTry(UtilsSuite.scala:1645) + // at org.apache.spark.util.UtilsSuite.$anonfun$new$165(UtilsSuite.scala:1658) + // ... 56 more + // scalastyle:on line.size.limit + val origSt = e1.getSuppressed.find( + _.getMessage == Utils.TRY_WITH_CALLER_STACKTRACE_FULL_STACKTRACE) + assert(origSt.isDefined) + assert(origSt.get.getStackTrace.exists(_.getMethodName == "throwException")) + assert(origSt.get.getStackTrace.exists(_.getMethodName == "callDoTry")) + + // The stack trace under Try should be in the suppressed exceptions. + // Example: + // Suppressed: java.lang.Exception: Stacktrace under doTryWithCallerStacktrace + // at org.apache.spark.util.UtilsSuite.throwException(UtilsSuite.scala: 1640) + // at org.apache.spark.util.UtilsSuite.$anonfun$callDoTry$1(UtilsSuite.scala: 1645) + // at scala.util.Try$.apply(Try.scala: 213) + // at org.apache.spark.util.Utils$.doTryWithCallerStacktrace(Utils.scala: 1586) + val trySt = e1.getSuppressed.find( + _.getMessage == Utils.TRY_WITH_CALLER_STACKTRACE_TRY_STACKTRACE) + assert(trySt.isDefined) + // calls under callDoTry should be present. + assert(trySt.get.getStackTrace.exists(_.getMethodName == "throwException")) + // callDoTry should be removed. 
+ assert(!trySt.get.getStackTrace.exists(_.getMethodName == "callDoTry")) + + val e2 = intercept[Exception] { + callGetTryAgain(t) + } + // Uncomment for manual inspection + // e2.printStackTrace() + // Example: + // java.lang.Exception: test + // at org.apache.spark.util.UtilsSuite.throwException(UtilsSuite.scala:1640) + // at org.apache.spark.util.UtilsSuite.$anonfun$callDoTry$1(UtilsSuite.scala:1645) + // at scala.util.Try$.apply(Try.scala:213) + // at org.apache.spark.util.Utils$.doTryWithCallerStacktrace(Utils.scala:1586) + // at org.apache.spark.util.Utils$.getTryWithCallerStacktrace(Utils.scala:1639) + // at org.apache.spark.util.UtilsSuite.callGetTryAgain(UtilsSuite.scala:1654) + // at org.apache.spark.util.UtilsSuite.$anonfun$new$165(UtilsSuite.scala:1711) + // <- callGetTryAgain is seen as calling getTryWithCallerStacktrace + + val st2 = e2.getStackTrace + // throwException should be on the stack trace + assert(st2.exists(_.getMethodName == "throwException")) + // callDoTry shouldn't be on the stack trace, but callGetTryAgain should be. + assert(!st2.exists(_.getMethodName == "callDoTry")) + assert(st2.exists(_.getMethodName == "callGetTryAgain")) + // callGetTry that we called before shouldn't be on the stack trace. + assert(!st2.exists(_.getMethodName == "callGetTry")) + + // Unfortunately, this utility is not able to clone the exception, but modifies it in place, + // so now e1 is also pointing to "callGetTryAgain" instead of "callGetTry". + val st1Again = e1.getStackTrace + assert(st1Again.exists(_.getMethodName == "callGetTryAgain")) + assert(!st1Again.exists(_.getMethodName == "callGetTry")) + } } private class SimpleExtension diff --git a/dev/.rat-excludes b/dev/.rat-excludes index f38fd7e2012a5..b82cb7078c9f3 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -140,3 +140,4 @@ ui-test/package.json ui-test/package-lock.json core/src/main/resources/org/apache/spark/ui/static/package.json .*\.har +.nojekyll diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 9871cc0bca04f..91e84b0780798 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -44,7 +44,7 @@ commons-compiler/3.1.9//commons-compiler-3.1.9.jar commons-compress/1.27.1//commons-compress-1.27.1.jar commons-crypto/1.1.0//commons-crypto-1.1.0.jar commons-dbcp/1.4//commons-dbcp-1.4.jar -commons-io/2.16.1//commons-io-2.16.1.jar +commons-io/2.17.0//commons-io-2.17.0.jar commons-lang/2.6//commons-lang-2.6.jar commons-lang3/3.17.0//commons-lang3-3.17.0.jar commons-math3/3.6.1//commons-math3-3.6.1.jar @@ -67,7 +67,7 @@ error_prone_annotations/2.26.1//error_prone_annotations-2.26.1.jar esdk-obs-java/3.20.4.2//esdk-obs-java-3.20.4.2.jar failureaccess/1.0.2//failureaccess-1.0.2.jar flatbuffers-java/24.3.25//flatbuffers-java-24.3.25.jar -gcs-connector/hadoop3-2.2.21/shaded/gcs-connector-hadoop3-2.2.21-shaded.jar +gcs-connector/hadoop3-2.2.25/shaded/gcs-connector-hadoop3-2.2.25-shaded.jar gmetric4j/1.0.10//gmetric4j-1.0.10.jar gson/2.11.0//gson-2.11.0.jar guava/33.2.1-jre//guava-33.2.1-jre.jar @@ -105,16 +105,16 @@ ini4j/0.5.4//ini4j-0.5.4.jar istack-commons-runtime/3.0.8//istack-commons-runtime-3.0.8.jar ivy/2.5.2//ivy-2.5.2.jar j2objc-annotations/3.0.0//j2objc-annotations-3.0.0.jar -jackson-annotations/2.17.2//jackson-annotations-2.17.2.jar +jackson-annotations/2.18.0//jackson-annotations-2.18.0.jar jackson-core-asl/1.9.13//jackson-core-asl-1.9.13.jar -jackson-core/2.17.2//jackson-core-2.17.2.jar 
-jackson-databind/2.17.2//jackson-databind-2.17.2.jar -jackson-dataformat-cbor/2.17.2//jackson-dataformat-cbor-2.17.2.jar -jackson-dataformat-yaml/2.17.2//jackson-dataformat-yaml-2.17.2.jar +jackson-core/2.18.0//jackson-core-2.18.0.jar +jackson-databind/2.18.0//jackson-databind-2.18.0.jar +jackson-dataformat-cbor/2.18.0//jackson-dataformat-cbor-2.18.0.jar +jackson-dataformat-yaml/2.18.0//jackson-dataformat-yaml-2.18.0.jar jackson-datatype-jdk8/2.17.0//jackson-datatype-jdk8-2.17.0.jar -jackson-datatype-jsr310/2.17.2//jackson-datatype-jsr310-2.17.2.jar +jackson-datatype-jsr310/2.18.0//jackson-datatype-jsr310-2.18.0.jar jackson-mapper-asl/1.9.13//jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.13/2.17.2//jackson-module-scala_2.13-2.17.2.jar +jackson-module-scala_2.13/2.18.0//jackson-module-scala_2.13-2.18.0.jar jakarta.annotation-api/2.0.0//jakarta.annotation-api-2.0.0.jar jakarta.inject-api/2.0.1//jakarta.inject-api-2.0.1.jar jakarta.servlet-api/5.0.0//jakarta.servlet-api-5.0.0.jar @@ -144,7 +144,7 @@ jetty-util-ajax/11.0.23//jetty-util-ajax-11.0.23.jar jetty-util/11.0.23//jetty-util-11.0.23.jar jjwt-api/0.12.6//jjwt-api-0.12.6.jar jline/2.14.6//jline-2.14.6.jar -jline/3.25.1//jline-3.25.1.jar +jline/3.26.3//jline-3.26.3.jar jna/5.14.0//jna-5.14.0.jar joda-time/2.13.0//joda-time-2.13.0.jar jodd-core/3.5.2//jodd-core-3.5.2.jar @@ -159,48 +159,48 @@ jsr305/3.0.0//jsr305-3.0.0.jar jta/1.1//jta-1.1.jar jul-to-slf4j/2.0.16//jul-to-slf4j-2.0.16.jar kryo-shaded/4.0.2//kryo-shaded-4.0.2.jar -kubernetes-client-api/6.13.3//kubernetes-client-api-6.13.3.jar -kubernetes-client/6.13.3//kubernetes-client-6.13.3.jar -kubernetes-httpclient-okhttp/6.13.3//kubernetes-httpclient-okhttp-6.13.3.jar -kubernetes-model-admissionregistration/6.13.3//kubernetes-model-admissionregistration-6.13.3.jar -kubernetes-model-apiextensions/6.13.3//kubernetes-model-apiextensions-6.13.3.jar -kubernetes-model-apps/6.13.3//kubernetes-model-apps-6.13.3.jar -kubernetes-model-autoscaling/6.13.3//kubernetes-model-autoscaling-6.13.3.jar -kubernetes-model-batch/6.13.3//kubernetes-model-batch-6.13.3.jar -kubernetes-model-certificates/6.13.3//kubernetes-model-certificates-6.13.3.jar -kubernetes-model-common/6.13.3//kubernetes-model-common-6.13.3.jar -kubernetes-model-coordination/6.13.3//kubernetes-model-coordination-6.13.3.jar -kubernetes-model-core/6.13.3//kubernetes-model-core-6.13.3.jar -kubernetes-model-discovery/6.13.3//kubernetes-model-discovery-6.13.3.jar -kubernetes-model-events/6.13.3//kubernetes-model-events-6.13.3.jar -kubernetes-model-extensions/6.13.3//kubernetes-model-extensions-6.13.3.jar -kubernetes-model-flowcontrol/6.13.3//kubernetes-model-flowcontrol-6.13.3.jar -kubernetes-model-gatewayapi/6.13.3//kubernetes-model-gatewayapi-6.13.3.jar -kubernetes-model-metrics/6.13.3//kubernetes-model-metrics-6.13.3.jar -kubernetes-model-networking/6.13.3//kubernetes-model-networking-6.13.3.jar -kubernetes-model-node/6.13.3//kubernetes-model-node-6.13.3.jar -kubernetes-model-policy/6.13.3//kubernetes-model-policy-6.13.3.jar -kubernetes-model-rbac/6.13.3//kubernetes-model-rbac-6.13.3.jar -kubernetes-model-resource/6.13.3//kubernetes-model-resource-6.13.3.jar -kubernetes-model-scheduling/6.13.3//kubernetes-model-scheduling-6.13.3.jar -kubernetes-model-storageclass/6.13.3//kubernetes-model-storageclass-6.13.3.jar +kubernetes-client-api/6.13.4//kubernetes-client-api-6.13.4.jar +kubernetes-client/6.13.4//kubernetes-client-6.13.4.jar +kubernetes-httpclient-okhttp/6.13.4//kubernetes-httpclient-okhttp-6.13.4.jar 
+kubernetes-model-admissionregistration/6.13.4//kubernetes-model-admissionregistration-6.13.4.jar +kubernetes-model-apiextensions/6.13.4//kubernetes-model-apiextensions-6.13.4.jar +kubernetes-model-apps/6.13.4//kubernetes-model-apps-6.13.4.jar +kubernetes-model-autoscaling/6.13.4//kubernetes-model-autoscaling-6.13.4.jar +kubernetes-model-batch/6.13.4//kubernetes-model-batch-6.13.4.jar +kubernetes-model-certificates/6.13.4//kubernetes-model-certificates-6.13.4.jar +kubernetes-model-common/6.13.4//kubernetes-model-common-6.13.4.jar +kubernetes-model-coordination/6.13.4//kubernetes-model-coordination-6.13.4.jar +kubernetes-model-core/6.13.4//kubernetes-model-core-6.13.4.jar +kubernetes-model-discovery/6.13.4//kubernetes-model-discovery-6.13.4.jar +kubernetes-model-events/6.13.4//kubernetes-model-events-6.13.4.jar +kubernetes-model-extensions/6.13.4//kubernetes-model-extensions-6.13.4.jar +kubernetes-model-flowcontrol/6.13.4//kubernetes-model-flowcontrol-6.13.4.jar +kubernetes-model-gatewayapi/6.13.4//kubernetes-model-gatewayapi-6.13.4.jar +kubernetes-model-metrics/6.13.4//kubernetes-model-metrics-6.13.4.jar +kubernetes-model-networking/6.13.4//kubernetes-model-networking-6.13.4.jar +kubernetes-model-node/6.13.4//kubernetes-model-node-6.13.4.jar +kubernetes-model-policy/6.13.4//kubernetes-model-policy-6.13.4.jar +kubernetes-model-rbac/6.13.4//kubernetes-model-rbac-6.13.4.jar +kubernetes-model-resource/6.13.4//kubernetes-model-resource-6.13.4.jar +kubernetes-model-scheduling/6.13.4//kubernetes-model-scheduling-6.13.4.jar +kubernetes-model-storageclass/6.13.4//kubernetes-model-storageclass-6.13.4.jar lapack/3.0.3//lapack-3.0.3.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar libthrift/0.16.0//libthrift-0.16.0.jar listenablefuture/9999.0-empty-to-avoid-conflict-with-guava//listenablefuture-9999.0-empty-to-avoid-conflict-with-guava.jar -log4j-1.2-api/2.22.1//log4j-1.2-api-2.22.1.jar -log4j-api/2.22.1//log4j-api-2.22.1.jar -log4j-core/2.22.1//log4j-core-2.22.1.jar -log4j-layout-template-json/2.22.1//log4j-layout-template-json-2.22.1.jar -log4j-slf4j2-impl/2.22.1//log4j-slf4j2-impl-2.22.1.jar +log4j-1.2-api/2.24.1//log4j-1.2-api-2.24.1.jar +log4j-api/2.24.1//log4j-api-2.24.1.jar +log4j-core/2.24.1//log4j-core-2.24.1.jar +log4j-layout-template-json/2.24.1//log4j-layout-template-json-2.24.1.jar +log4j-slf4j2-impl/2.24.1//log4j-slf4j2-impl-2.24.1.jar logging-interceptor/3.12.12//logging-interceptor-3.12.12.jar lz4-java/1.8.0//lz4-java-1.8.0.jar -metrics-core/4.2.27//metrics-core-4.2.27.jar -metrics-graphite/4.2.27//metrics-graphite-4.2.27.jar -metrics-jmx/4.2.27//metrics-jmx-4.2.27.jar -metrics-json/4.2.27//metrics-json-4.2.27.jar -metrics-jvm/4.2.27//metrics-jvm-4.2.27.jar +metrics-core/4.2.28//metrics-core-4.2.28.jar +metrics-graphite/4.2.28//metrics-graphite-4.2.28.jar +metrics-jmx/4.2.28//metrics-jmx-4.2.28.jar +metrics-json/4.2.28//metrics-json-4.2.28.jar +metrics-jvm/4.2.28//metrics-jvm-4.2.28.jar minlog/1.3.0//minlog-1.3.0.jar netty-all/4.1.110.Final//netty-all-4.1.110.Final.jar netty-buffer/4.1.110.Final//netty-buffer-4.1.110.Final.jar @@ -241,22 +241,22 @@ orc-shims/2.0.2//orc-shims-2.0.2.jar oro/2.0.8//oro-2.0.8.jar osgi-resource-locator/1.0.3//osgi-resource-locator-1.0.3.jar paranamer/2.8//paranamer-2.8.jar -parquet-column/1.14.2//parquet-column-1.14.2.jar -parquet-common/1.14.2//parquet-common-1.14.2.jar -parquet-encoding/1.14.2//parquet-encoding-1.14.2.jar -parquet-format-structures/1.14.2//parquet-format-structures-1.14.2.jar 
-parquet-hadoop/1.14.2//parquet-hadoop-1.14.2.jar -parquet-jackson/1.14.2//parquet-jackson-1.14.2.jar +parquet-column/1.14.3//parquet-column-1.14.3.jar +parquet-common/1.14.3//parquet-common-1.14.3.jar +parquet-encoding/1.14.3//parquet-encoding-1.14.3.jar +parquet-format-structures/1.14.3//parquet-format-structures-1.14.3.jar +parquet-hadoop/1.14.3//parquet-hadoop-1.14.3.jar +parquet-jackson/1.14.3//parquet-jackson-1.14.3.jar pickle/1.5//pickle-1.5.jar py4j/0.10.9.7//py4j-0.10.9.7.jar remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar rocksdbjni/9.5.2//rocksdbjni-9.5.2.jar scala-collection-compat_2.13/2.7.0//scala-collection-compat_2.13-2.7.0.jar -scala-compiler/2.13.14//scala-compiler-2.13.14.jar -scala-library/2.13.14//scala-library-2.13.14.jar +scala-compiler/2.13.15//scala-compiler-2.13.15.jar +scala-library/2.13.15//scala-library-2.13.15.jar scala-parallel-collections_2.13/1.0.4//scala-parallel-collections_2.13-1.0.4.jar scala-parser-combinators_2.13/2.4.0//scala-parser-combinators_2.13-2.4.0.jar -scala-reflect/2.13.14//scala-reflect-2.13.14.jar +scala-reflect/2.13.15//scala-reflect-2.13.15.jar scala-xml_2.13/2.3.0//scala-xml_2.13-2.3.0.jar slf4j-api/2.0.16//slf4j-api-2.0.16.jar snakeyaml-engine/2.7//snakeyaml-engine-2.7.jar @@ -274,10 +274,10 @@ tink/1.15.0//tink-1.15.0.jar transaction-api/1.1//transaction-api-1.1.jar univocity-parsers/2.9.1//univocity-parsers-2.9.1.jar wildfly-openssl/1.1.3.Final//wildfly-openssl-1.1.3.Final.jar -xbean-asm9-shaded/4.25//xbean-asm9-shaded-4.25.jar +xbean-asm9-shaded/4.26//xbean-asm9-shaded-4.26.jar xmlschema-core/2.3.1//xmlschema-core-2.3.1.jar xz/1.10//xz-1.10.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar zookeeper-jute/3.9.2//zookeeper-jute-3.9.2.jar zookeeper/3.9.2//zookeeper-3.9.2.jar -zstd-jni/1.5.6-5//zstd-jni-1.5.6-5.jar +zstd-jni/1.5.6-6//zstd-jni-1.5.6-6.jar diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index 5939e429b2f35..1619b009e9364 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -24,7 +24,7 @@ LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image" # Overwrite this label to avoid exposing the underlying Ubuntu OS version label LABEL org.opencontainers.image.version="" -ENV FULL_REFRESH_DATE 20240903 +ENV FULL_REFRESH_DATE 20241002 ENV DEBIAN_FRONTEND noninteractive ENV DEBCONF_NONINTERACTIVE_SEEN true @@ -91,10 +91,10 @@ RUN mkdir -p /usr/local/pypy/pypy3.9 && \ ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3.9 && \ ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3 RUN curl -sS https://bootstrap.pypa.io/get-pip.py | pypy3 -RUN pypy3 -m pip install 'numpy==1.26.4' 'six==1.16.0' 'pandas==2.2.2' scipy coverage matplotlib lxml +RUN pypy3 -m pip install 'numpy==1.26.4' 'six==1.16.0' 'pandas==2.2.3' scipy coverage matplotlib lxml -ARG BASIC_PIP_PKGS="numpy==1.26.4 pyarrow>=15.0.0 six==1.16.0 pandas==2.2.2 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2" +ARG BASIC_PIP_PKGS="numpy==1.26.4 pyarrow>=15.0.0 six==1.16.0 pandas==2.2.3 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2" # Python deps for Spark Connect ARG CONNECT_PIP_PKGS="grpcio==1.62.0 grpcio-status==1.62.0 protobuf==4.25.1 googleapis-common-protos==1.56.4 graphviz==0.20.3" @@ -135,13 +135,22 @@ RUN apt-get update && apt-get install -y \ python3.12 \ && rm -rf /var/lib/apt/lists/* RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.12 -# TODO(SPARK-46647) Add unittest-xml-reporting into Python 3.12 image 
when it supports Python 3.12 RUN python3.12 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this -RUN python3.12 -m pip install $BASIC_PIP_PKGS $CONNECT_PIP_PKGS lxml && \ +RUN python3.12 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS lxml && \ python3.12 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \ python3.12 -m pip install torcheval && \ python3.12 -m pip cache purge +# Install Python 3.13 at the last stage to avoid breaking the existing Python installations +RUN apt-get update && apt-get install -y \ + python3.13 \ + && rm -rf /var/lib/apt/lists/* +RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.13 +# TODO(SPARK-49862) Add BASIC_PIP_PKGS and CONNECT_PIP_PKGS to Python 3.13 image when it supports Python 3.13 +RUN python3.13 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this +RUN python3.13 -m pip install lxml numpy>=2.1 && \ + python3.13 -m pip cache purge + # Remove unused installation packages to free up disk space RUN apt-get remove --purge -y 'gfortran-11' 'humanity-icon-theme' 'nodejs-doc' || true RUN apt-get autoremove --purge -y diff --git a/dev/py-cleanup b/dev/py-cleanup new file mode 100755 index 0000000000000..6a2edd1040171 --- /dev/null +++ b/dev/py-cleanup @@ -0,0 +1,31 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility for temporary files cleanup in 'python'. 
+# usage: ./dev/py-cleanup + +set -ex + +SPARK_HOME="$(cd "`dirname $0`"/..; pwd)" +cd "$SPARK_HOME" + +rm -rf python/target +rm -rf python/lib/pyspark.zip +rm -rf python/docs/build +rm -rf python/docs/source/reference/*/api diff --git a/dev/requirements.txt b/dev/requirements.txt index 5486c98ab8f8f..cafc73405aaa8 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -7,7 +7,7 @@ pyarrow>=10.0.0 six==1.16.0 pandas>=2.0.0 scipy -plotly +plotly>=4.8 mlflow>=2.3.1 scikit-learn matplotlib diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 34fbb8450d544..d2c000b702a64 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -520,6 +520,7 @@ def __hash__(self): "pyspark.sql.tests.test_errors", "pyspark.sql.tests.test_functions", "pyspark.sql.tests.test_group", + "pyspark.sql.tests.test_sql", "pyspark.sql.tests.pandas.test_pandas_cogrouped_map", "pyspark.sql.tests.pandas.test_pandas_grouped_map", "pyspark.sql.tests.pandas.test_pandas_grouped_map_with_state", @@ -548,6 +549,9 @@ def __hash__(self): "pyspark.sql.tests.test_udtf", "pyspark.sql.tests.test_utils", "pyspark.sql.tests.test_resources", + "pyspark.sql.tests.plot.test_frame_plot", + "pyspark.sql.tests.plot.test_frame_plot_plotly", + "pyspark.sql.tests.test_connect_compatibility", ], ) @@ -1029,6 +1033,7 @@ def __hash__(self): "pyspark.sql.tests.connect.test_parity_serde", "pyspark.sql.tests.connect.test_parity_functions", "pyspark.sql.tests.connect.test_parity_group", + "pyspark.sql.tests.connect.test_parity_sql", "pyspark.sql.tests.connect.test_parity_dataframe", "pyspark.sql.tests.connect.test_parity_collection", "pyspark.sql.tests.connect.test_parity_creation", @@ -1051,6 +1056,8 @@ def __hash__(self): "pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map", "pyspark.sql.tests.connect.test_parity_python_datasource", "pyspark.sql.tests.connect.test_parity_python_streaming_datasource", + "pyspark.sql.tests.connect.test_parity_frame_plot", + "pyspark.sql.tests.connect.test_parity_frame_plot_plotly", "pyspark.sql.tests.connect.test_utils", "pyspark.sql.tests.connect.client.test_artifact", "pyspark.sql.tests.connect.client.test_artifact_localcluster", diff --git a/docs/_config.yml b/docs/_config.yml index e74eda0470417..089d6bf2097b8 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -22,7 +22,7 @@ include: SPARK_VERSION: 4.0.0-SNAPSHOT SPARK_VERSION_SHORT: 4.0.0 SCALA_BINARY_VERSION: "2.13" -SCALA_VERSION: "2.13.14" +SCALA_VERSION: "2.13.15" SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK SPARK_GITHUB_URL: https://github.com/apache/spark # Before a new release, we should: diff --git a/docs/configuration.md b/docs/configuration.md index 73d57b687ca2a..3c83ed92c1280 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1232,6 +1232,19 @@ Apart from these, the following properties are also available, and may be useful 2.2.1 + + spark.shuffle.accurateBlockSkewedFactor + -1.0 + + A shuffle block is considered as skewed and will be accurately recorded in + HighlyCompressedMapStatus if its size is larger than this factor multiplying + the median shuffle block size or spark.shuffle.accurateBlockThreshold. It is + recommended to set this parameter to be the same as + spark.sql.adaptive.skewJoin.skewedPartitionFactor. Set to -1.0 to disable this + feature by default. 
+ + 3.3.0 + spark.shuffle.registration.timeout 5000 diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index d8be32e047717..e7e9eb9ea4ffc 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -44,7 +44,7 @@ Cluster administrators should use [Pod Security Policies](https://kubernetes.io/ # Prerequisites -* A running Kubernetes cluster at version >= 1.28 with access configured to it using +* A running Kubernetes cluster at version >= 1.29 with access configured to it using [kubectl](https://kubernetes.io/docs/reference/kubectl/). If you do not already have a working Kubernetes cluster, you may set up a test cluster on your local machine using [minikube](https://kubernetes.io/docs/getting-started-guides/minikube/). @@ -1191,6 +1191,15 @@ See the [configuration page](configuration.html) for information on Spark config 4.0.0 + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].annotation.[AnnotationName] + (none) + + Configure Kubernetes Volume annotations passed to the Kubernetes with AnnotationName as key having specified value, must conform with Kubernetes annotations format. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.annotation.foo=bar. + + 4.0.0 + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.path (none) @@ -1236,6 +1245,15 @@ See the [configuration page](configuration.html) for information on Spark config 4.0.0 + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].annotation.[AnnotationName] + (none) + + Configure Kubernetes Volume annotations passed to the Kubernetes with AnnotationName as key having specified value, must conform with Kubernetes annotations format. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.annotation.foo=bar. + + 4.0.0 + spark.kubernetes.local.dirs.tmpfs false diff --git a/docs/security.md b/docs/security.md index b97abfeacf240..c7d3fd5f8c36f 100644 --- a/docs/security.md +++ b/docs/security.md @@ -947,7 +947,7 @@ mechanism (see `java.util.ServiceLoader`). Implementations of `org.apache.spark.security.HadoopDelegationTokenProvider` can be made available to Spark by listing their names in the corresponding file in the jar's `META-INF/services` directory. -Delegation token support is currently only supported in YARN mode. Consult the +Delegation token support is currently only supported in YARN and Kubernetes mode. Consult the deployment-specific page for more information. The following options provides finer-grained control for this feature: diff --git a/docs/sql-data-sources-avro.md b/docs/sql-data-sources-avro.md index 3721f92d93266..c06e1fd46d2da 100644 --- a/docs/sql-data-sources-avro.md +++ b/docs/sql-data-sources-avro.md @@ -353,6 +353,13 @@ Data source options of Avro can be set via: read 4.0.0 + + recursiveFieldMaxDepth + -1 + If this option is specified to negative or is set to 0, recursive fields are not permitted. Setting it to 1 drops all recursive fields, 2 allows recursive fields to be recursed once, and 3 allows it to be recursed twice and so on, up to 15. Values larger than 15 are not allowed in order to avoid inadvertently creating very large schemas. If an avro message has depth beyond this limit, the Spark struct returned is truncated after the recursion limit. 
An example of usage can be found in section Handling circular references of Avro fields + read + 4.0.0 + ## Configuration @@ -628,3 +635,41 @@ You can also specify the whole output Avro schema with the option `avroSchema`, decimal + +## Handling circular references of Avro fields +In Avro, a circular reference occurs when the type of a field is defined in one of the parent records. This can cause issues when parsing the data, as it can result in infinite loops or other unexpected behavior. +To read Avro data with schema that has circular reference, users can use the `recursiveFieldMaxDepth` option to specify the maximum number of levels of recursion to allow when parsing the schema. By default, Spark Avro data source will not permit recursive fields by setting `recursiveFieldMaxDepth` to -1. However, you can set this option to 1 to 15 if needed. + +Setting `recursiveFieldMaxDepth` to 1 drops all recursive fields, setting it to 2 allows it to be recursed once, and setting it to 3 allows it to be recursed twice. A `recursiveFieldMaxDepth` value greater than 15 is not allowed, as it can lead to performance issues and even stack overflows. + +SQL Schema for the below Avro message will vary based on the value of `recursiveFieldMaxDepth`. + +
+
+This div is only used to make markdown editor/viewer happy and does not display on web + +```avro +
+
+{% highlight avro %}
+{
+  "type": "record",
+  "name": "Node",
+  "fields": [
+    {"name": "Id", "type": "int"},
+    {"name": "Next", "type": ["null", "Node"]}
+  ]
+}
+
+// The Avro schema defined above would be converted into a Spark SQL column with the following
+// structure, based on the `recursiveFieldMaxDepth` value:
+
+1: struct<Id: int>
+2: struct<Id: int, Next: struct<Id: int>>
+3: struct<Id: int, Next: struct<Id: int, Next: struct<Id: int>>>
+
+{% endhighlight %}
+
+``` +
+
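To make the option concrete, a minimal usage sketch follows (assuming the spark-avro module is on the classpath and an Avro file containing the `Node` records shown above; the path, app name, and object name are placeholders, while `recursiveFieldMaxDepth` is the documented option):

```scala
import org.apache.spark.sql.SparkSession

// Illustrative only: reading Avro data whose schema contains the recursive `Node` record.
object RecursiveAvroReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("recursive-avro-sketch").getOrCreate()

    // With recursiveFieldMaxDepth = 2, the recursive `Next` field is kept one level deep,
    // i.e. struct<Id: int, Next: struct<Id: int>> per the mapping shown above.
    val df = spark.read
      .format("avro")
      .option("recursiveFieldMaxDepth", "2")
      .load("/path/to/node-records.avro") // placeholder path

    df.printSchema()
    df.show(truncate = false)
    spark.stop()
  }
}
```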
diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index fff6906457f7d..b4446b1538cd6 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -426,6 +426,7 @@ Below is a list of all the keywords in Spark SQL. |BY|non-reserved|non-reserved|reserved| |BYTE|non-reserved|non-reserved|non-reserved| |CACHE|non-reserved|non-reserved|non-reserved| +|CALL|reserved|non-reserved|reserved| |CALLED|non-reserved|non-reserved|non-reserved| |CASCADE|non-reserved|non-reserved|non-reserved| |CASE|reserved|non-reserved|reserved| @@ -580,6 +581,7 @@ Below is a list of all the keywords in Spark SQL. |LOCKS|non-reserved|non-reserved|non-reserved| |LOGICAL|non-reserved|non-reserved|non-reserved| |LONG|non-reserved|non-reserved|non-reserved| +|LOOP|non-reserved|non-reserved|non-reserved| |MACRO|non-reserved|non-reserved|non-reserved| |MAP|non-reserved|non-reserved|non-reserved| |MATCHED|non-reserved|non-reserved|non-reserved| diff --git a/docs/sql-ref-functions-builtin.md b/docs/sql-ref-functions-builtin.md index c5f4e44dec0d9..b6572609a34b8 100644 --- a/docs/sql-ref-functions-builtin.md +++ b/docs/sql-ref-functions-builtin.md @@ -116,3 +116,13 @@ license: | {% include_api_gen generated-generator-funcs-table.html %} #### Examples {% include_api_gen generated-generator-funcs-examples.html %} + +### Table Functions +{% include_api_gen generated-table-funcs-table.html %} +#### Examples +{% include_api_gen generated-table-funcs-examples.html %} + +### Variant Functions +{% include_api_gen generated-variant-funcs-table.html %} +#### Examples +{% include_api_gen generated-variant-funcs-examples.html %} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java index 0c11c40cfe7ed..1052f47ea496e 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.regex.Pattern; +import com.google.common.io.FileWriteMode; import scala.Tuple2; import com.google.common.io.Files; @@ -152,7 +153,8 @@ private static JavaStreamingContext createContext(String ip, System.out.println(output); System.out.println("Dropped " + droppedWordsCounter.value() + " word(s) totally"); System.out.println("Appending to " + outputFile.getAbsolutePath()); - Files.append(output + "\n", outputFile, Charset.defaultCharset()); + Files.asCharSink(outputFile, Charset.defaultCharset(), FileWriteMode.APPEND) + .write(output + "\n"); }); return ssc; diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 98539d6494231..1ec6ee4abd327 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -21,7 +21,7 @@ package org.apache.spark.examples.streaming import java.io.File import java.nio.charset.Charset -import com.google.common.io.Files +import com.google.common.io.{Files, FileWriteMode} import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.broadcast.Broadcast @@ -134,7 +134,8 @@ object RecoverableNetworkWordCount { println(output) 
println(s"Dropped ${droppedWordsCounter.value} word(s) totally") println(s"Appending to ${outputFile.getAbsolutePath}") - Files.append(output + "\n", outputFile, Charset.defaultCharset()) + Files.asCharSink(outputFile, Charset.defaultCharset(), FileWriteMode.APPEND) + .write(output + "\n") } ssc } diff --git a/launcher/pom.xml b/launcher/pom.xml index c47244ff887a6..e8feb7b684555 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -57,6 +57,16 @@ mockito-core test
+ + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.slf4j jul-to-slf4j diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index ecfe45f046f2b..3b35a481adb1b 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -52,6 +52,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.apache.spark spark-tags_${scala.binary.version} diff --git a/mllib/pom.xml b/mllib/pom.xml index 4f983a325a0c1..c342519ca428a 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -117,6 +117,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.apache.spark spark-streaming_${scala.binary.version} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 30f3e4c4af021..5486c39034fd3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -204,7 +204,7 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) val inputType = try { SchemaUtils.getSchemaFieldType(schema, inputColName) } catch { - case e: SparkIllegalArgumentException if e.getErrorClass == "FIELD_NOT_FOUND" => + case e: SparkIllegalArgumentException if e.getCondition == "FIELD_NOT_FOUND" => throw new SparkException(s"Input column $inputColName does not exist.") case e: Exception => throw e diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 8e64f60427d90..20b03edf23c4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -127,7 +127,7 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi validateAndTransformField(schema, inputColName, dtype, outputColName) ) } catch { - case e: SparkIllegalArgumentException if e.getErrorClass == "FIELD_NOT_FOUND" => + case e: SparkIllegalArgumentException if e.getCondition == "FIELD_NOT_FOUND" => if (skipNonExistsCol) { None } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 1a004f71749e1..5899bf891ec9d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -517,7 +517,7 @@ class ALSModel private[ml] ( ) ratings.groupBy(srcOutputColumn) - .agg(collect_top_k(struct(ratingColumn, dstOutputColumn), num, false)) + .agg(ALSModel.collect_top_k(struct(ratingColumn, dstOutputColumn), num, false)) .as[(Int, Seq[(Float, Int)])] .map(t => (t._1, t._2.map(p => (p._2, p._1)))) .toDF(srcOutputColumn, recommendColumn) @@ -546,6 +546,9 @@ object ALSModel extends MLReadable[ALSModel] { private val Drop = "drop" private[recommendation] final val supportedColdStartStrategies = Array(NaN, Drop) + private[recommendation] def collect_top_k(e: Column, num: Int, reverse: Boolean): Column = + Column.internalFn("collect_top_k", e, lit(num), lit(reverse)) + @Since("1.6.0") override def read: MLReader[ALSModel] = new ALSModelReader diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 3b306eff99689..ff132e2a29a89 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -17,12 +17,13 @@ package org.apache.spark.ml.util +import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.VectorUDT -import org.apache.spark.sql.catalyst.util.AttributeNameParser +import org.apache.spark.sql.catalyst.util.{AttributeNameParser, QuotingUtils} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ - /** * Utils for handling schemas. */ @@ -206,6 +207,10 @@ private[spark] object SchemaUtils { checkColumnTypes(schema, colName, typeCandidates) } + def toSQLId(parts: String): String = { + AttributeNameParser.parseAttributeName(parts).map(QuotingUtils.quoteIdentifier).mkString(".") + } + /** * Get schema field. * @param schema input schema @@ -213,11 +218,16 @@ private[spark] object SchemaUtils { */ def getSchemaField(schema: StructType, colName: String): StructField = { val colSplits = AttributeNameParser.parseAttributeName(colName) - var field = schema(colSplits(0)) - for (colSplit <- colSplits.slice(1, colSplits.length)) { - field = field.dataType.asInstanceOf[StructType](colSplit) + val fieldOpt = schema.findNestedField(colSplits, resolver = SQLConf.get.resolver) + if (fieldOpt.isEmpty) { + throw new SparkIllegalArgumentException( + errorClass = "FIELD_NOT_FOUND", + messageParameters = Map( + "fieldName" -> toSQLId(colName), + "fields" -> schema.fields.map(f => toSQLId(f.name)).mkString(", ")) + ) } - field + fieldOpt.get._2 } /** diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java index c3038fa9e1f8f..5f0d22ea2a8aa 100644 --- a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java @@ -50,7 +50,7 @@ public void setUp() throws IOException { tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); File file = new File(tempDir, "part-00000"); String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0"; - Files.write(s, file, StandardCharsets.UTF_8); + Files.asCharSink(file, StandardCharsets.UTF_8).write(s); path = tempDir.toURI().toString(); } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/CollectTopKSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/CollectTopKSuite.scala index b79e10d0d267e..bd83d5498ae6f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/CollectTopKSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/CollectTopKSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.ml.recommendation +import org.apache.spark.ml.recommendation.ALSModel.collect_top_k import org.apache.spark.ml.util.MLTest import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions.{col, collect_top_k, struct} +import org.apache.spark.sql.functions.{col, struct} class CollectTopKSuite extends MLTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index f2bb145614725..6a0d7b1237ee4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -65,9 +65,9 
@@ class LibSVMRelationSuite val succ = new File(dir, "_SUCCESS") val file0 = new File(dir, "part-00000") val file1 = new File(dir, "part-00001") - Files.write("", succ, StandardCharsets.UTF_8) - Files.write(lines0, file0, StandardCharsets.UTF_8) - Files.write(lines1, file1, StandardCharsets.UTF_8) + Files.asCharSink(succ, StandardCharsets.UTF_8).write("") + Files.asCharSink(file0, StandardCharsets.UTF_8).write(lines0) + Files.asCharSink(file1, StandardCharsets.UTF_8).write(lines1) path = dir.getPath } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index a90c9c80d4959..1a02e26b9260c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -93,7 +93,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { """.stripMargin val tempDir = Utils.createTempDir() val file = new File(tempDir.getPath, "part-00000") - Files.write(lines, file, StandardCharsets.UTF_8) + Files.asCharSink(file, StandardCharsets.UTF_8).write(lines) val path = tempDir.toURI.toString val pointsWithNumFeatures = loadLibSVMFile(sc, path, 6).collect() @@ -126,7 +126,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { """.stripMargin val tempDir = Utils.createTempDir() val file = new File(tempDir.getPath, "part-00000") - Files.write(lines, file, StandardCharsets.UTF_8) + Files.asCharSink(file, StandardCharsets.UTF_8).write(lines) val path = tempDir.toURI.toString intercept[SparkException] { @@ -143,7 +143,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { """.stripMargin val tempDir = Utils.createTempDir() val file = new File(tempDir.getPath, "part-00000") - Files.write(lines, file, StandardCharsets.UTF_8) + Files.asCharSink(file, StandardCharsets.UTF_8).write(lines) val path = tempDir.toURI.toString intercept[SparkException] { diff --git a/pom.xml b/pom.xml index 694ea31e6f377..2b89454873782 100644 --- a/pom.xml +++ b/pom.xml @@ -84,6 +84,7 @@ common/utils common/variant common/tags + sql/connect/shims core graphx mllib @@ -118,13 +119,13 @@ 3.9.9 3.2.0 spark - 9.7 + 9.7.1 2.0.16 - 2.22.1 + 2.24.1 3.4.0 - 3.25.4 + 3.25.5 3.11.4 ${hadoop.version} 3.9.2 @@ -137,7 +138,7 @@ 3.8.0 10.16.1.1 - 1.14.2 + 1.14.3 2.0.2 shaded-protobuf 11.0.23 @@ -151,7 +152,7 @@ If you change codahale.metrics.version, you also need to change the link to metrics.dropwizard.io in docs/monitoring.md. --> - 4.2.27 + 4.2.28 1.12.0 1.12.0 @@ -161,7 +162,7 @@ 0.12.8 - hadoop3-2.2.21 + hadoop3-2.2.25 4.5.14 4.4.16 @@ -169,7 +170,7 @@ 3.2.2 4.4 - 2.13.14 + 2.13.15 2.13 2.2.0 4.9.1 @@ -180,14 +181,14 @@ true true 1.9.13 - 2.17.2 - 2.17.2 + 2.18.0 + 2.18.0 2.3.1 1.1.10.7 3.0.3 1.17.1 1.27.1 - 2.16.1 + 2.17.0 2.6 @@ -226,12 +227,12 @@ and ./python/packaging/connect/setup.py too. 
--> 17.0.0 - 3.0.0-M2 + 3.0.0 0.12.6 org.fusesource.leveldbjni - 6.13.3 + 6.13.4 1.17.6 ${java.home} @@ -490,7 +491,7 @@ org.apache.xbean xbean-asm9-shaded - 4.25 + 4.26 -Wconf:cat=deprecation&msg=it will become a keyword in Scala 3:e + + -Wconf:cat=deprecation&msg=method getErrorClass in trait SparkThrowable is deprecated:e -Xss128m diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index dfe7b14e2ec66..f31a29788aafe 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -125,26 +125,6 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation$"), - // SPARK-49414: Remove Logging from DataFrameReader. - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.DataFrameReader"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logName"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.log"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logInfo"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logDebug"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logTrace"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logWarning"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logError"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logInfo"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logDebug"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logTrace"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logWarning"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logError"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.isTraceEnabled"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.initializeLogIfNecessary"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.initializeLogIfNecessary"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.initializeLogIfNecessary$default$2"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.initializeForcefully"), - // SPARK-49425: Create a shared DataFrameWriter interface. ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameWriter"), @@ -183,7 +163,44 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryStartedEvent"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminatedEvent"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryStatus"), - ) + + // SPARK-49415: Shared SQLImplicits. 
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DatasetHolder"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DatasetHolder$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.LowPrioritySQLImplicits"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.SQLContext$implicits$"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.SQLImplicits"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLImplicits.StringToColumn"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.this"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLImplicits$StringToColumn"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.SparkSession$implicits$"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.SQLImplicits.session"), + + // SPARK-49282: Shared SparkSessionBuilder + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.SparkSession$Builder"), + + // SPARK-49286: Avro/Protobuf functions in sql/api + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.avro.functions"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.avro.functions$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.protobuf.functions"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.protobuf.functions$"), + + // SPARK-49434: Move aggregators to sql/api + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.javalang.typed"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.scalalang.typed"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.scalalang.typed$"), + + // SPARK-49418: Consolidate thread local handling in sql/api + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.SparkSession.setActiveSession"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.SparkSession.setDefaultSession"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.api.SparkSessionCompanion.clearActiveSession"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.api.SparkSessionCompanion.clearDefaultSession"), + + // SPARK-49748: Add getCondition and deprecate getErrorClass in SparkThrowable + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkThrowable.getCondition"), + ) ++ loggingExcludes("org.apache.spark.sql.DataFrameReader") ++ + loggingExcludes("org.apache.spark.sql.streaming.DataStreamReader") ++ + loggingExcludes("org.apache.spark.sql.SparkSession#Builder") // Default exclude rules lazy val defaultExcludes = Seq( @@ -201,6 +218,8 @@ object MimaExcludes { ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.errors.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.classic.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.connect.*"), // DSv2 catalog and expression APIs are unstable yet. We should enable this back. 
ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.catalog.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.expressions.*"), @@ -222,6 +241,26 @@ object MimaExcludes { } ) + private def loggingExcludes(fqn: String) = { + Seq( + ProblemFilters.exclude[MissingTypesProblem](fqn), + missingMethod(fqn, "logName"), + missingMethod(fqn, "log"), + missingMethod(fqn, "logInfo"), + missingMethod(fqn, "logDebug"), + missingMethod(fqn, "logTrace"), + missingMethod(fqn, "logWarning"), + missingMethod(fqn, "logError"), + missingMethod(fqn, "isTraceEnabled"), + missingMethod(fqn, "initializeLogIfNecessary"), + missingMethod(fqn, "initializeLogIfNecessary$default$2"), + missingMethod(fqn, "initializeForcefully")) + } + + private def missingMethod(names: String*) = { + ProblemFilters.exclude[DirectMissingMethodProblem](names.mkString(".")) + } + def excludes(version: String): Seq[Problem => Boolean] = version match { case v if v.startsWith("4.0") => v40excludes case _ => Seq() diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 4a8214b2e20a3..737efa8f7846b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -45,24 +45,24 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, tokenProviderKafka010, sqlKafka010, avro, protobuf) = Seq( - "catalyst", "sql", "hive", "hive-thriftserver", "token-provider-kafka-0-10", "sql-kafka-0-10", "avro", "protobuf" - ).map(ProjectRef(buildLocation, _)) + val sqlProjects@Seq(sqlApi, catalyst, sql, hive, hiveThriftServer, tokenProviderKafka010, sqlKafka010, avro, protobuf) = + Seq("sql-api", "catalyst", "sql", "hive", "hive-thriftserver", "token-provider-kafka-0-10", + "sql-kafka-0-10", "avro", "protobuf").map(ProjectRef(buildLocation, _)) val streamingProjects@Seq(streaming, streamingKafka010) = Seq("streaming", "streaming-kafka-0-10").map(ProjectRef(buildLocation, _)) - val connectCommon = ProjectRef(buildLocation, "connect-common") - val connect = ProjectRef(buildLocation, "connect") - val connectClient = ProjectRef(buildLocation, "connect-client-jvm") + val connectProjects@Seq(connectCommon, connect, connectClient, connectShims) = + Seq("connect-common", "connect", "connect-client-jvm", "connect-shims") + .map(ProjectRef(buildLocation, _)) val allProjects@Seq( core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, kvstore, - commonUtils, sqlApi, variant, _* + commonUtils, variant, _* ) = Seq( "core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe", - "tags", "sketch", "kvstore", "common-utils", "sql-api", "variant" - ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects ++ Seq(connectCommon, connect, connectClient) + "tags", "sketch", "kvstore", "common-utils", "variant" + ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects ++ connectProjects val optionallyEnabledProjects@Seq(kubernetes, yarn, sparkGangliaLgpl, streamingKinesisAsl, @@ -89,7 +89,7 @@ object BuildCommons { // Google Protobuf version used for generating the protobuf. // SPARK-41247: needs to be consistent with `protobuf.version` in `pom.xml`. - val protoVersion = "3.25.4" + val protoVersion = "3.25.5" // GRPC version used for Spark Connect. 
val grpcVersion = "1.62.2" } @@ -234,7 +234,11 @@ object SparkBuild extends PomBuild { // replace -Xfatal-warnings with fine-grained configuration, since 2.13.2 // verbose warning on deprecation, error on all others // see `scalac -Wconf:help` for details - "-Wconf:cat=deprecation:wv,any:e", + // since 2.13.15, "-Wconf:cat=deprecation:wv,any:e" no longer takes effect and needs to + // be changed to "-Wconf:any:e", "-Wconf:cat=deprecation:wv", + // please refer to the details: https://github.com/scala/scala/pull/10708 + "-Wconf:any:e", + "-Wconf:cat=deprecation:wv", // 2.13-specific warning hits to be muted (as narrowly as possible) and addressed separately "-Wunused:imports", "-Wconf:msg=^(?=.*?method|value|type|object|trait|inheritance)(?=.*?deprecated)(?=.*?since 2.13).+$:e", @@ -250,7 +254,9 @@ object SparkBuild extends PomBuild { // reduce the cost of migration in subsequent versions. "-Wconf:cat=deprecation&msg=it will become a keyword in Scala 3:e", // SPARK-46938 to prevent enum scan on pmml-model, under spark-mllib module. - "-Wconf:cat=other&site=org.dmg.pmml.*:w" + "-Wconf:cat=other&site=org.dmg.pmml.*:w", + // SPARK-49937 ban call the method `SparkThrowable#getErrorClass` + "-Wconf:cat=deprecation&msg=method getErrorClass in trait SparkThrowable is deprecated:e" ) } ) @@ -356,7 +362,7 @@ object SparkBuild extends PomBuild { /* Enable shared settings on all projects */ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ copyJarsProjects ++ Seq(spark, tools)) .foreach(enable(sharedSettings ++ DependencyOverrides.settings ++ - ExcludedDependencies.settings ++ Checkstyle.settings)) + ExcludedDependencies.settings ++ Checkstyle.settings ++ ExcludeShims.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) @@ -365,7 +371,7 @@ object SparkBuild extends PomBuild { Seq( spark, hive, hiveThriftServer, repl, networkCommon, networkShuffle, networkYarn, unsafe, tags, tokenProviderKafka010, sqlKafka010, connectCommon, connect, connectClient, - variant + variant, connectShims ).contains(x) } @@ -1083,6 +1089,36 @@ object ExcludedDependencies { ) } +/** + * This excludes the spark-connect-shims module from a module when it is not part of the connect + * client dependencies. + */ +object ExcludeShims { + val shimmedProjects = Set("spark-sql-api", "spark-connect-common", "spark-connect-client-jvm") + val classPathFilter = TaskKey[Classpath => Classpath]("filter for classpath") + lazy val settings = Seq( + classPathFilter := { + if (!shimmedProjects(moduleName.value)) { + cp => cp.filterNot(_.data.name.contains("spark-connect-shims")) + } else { + identity _ + } + }, + Compile / internalDependencyClasspath := + classPathFilter.value((Compile / internalDependencyClasspath).value), + Compile / internalDependencyAsJars := + classPathFilter.value((Compile / internalDependencyAsJars).value), + Runtime / internalDependencyClasspath := + classPathFilter.value((Runtime / internalDependencyClasspath).value), + Runtime / internalDependencyAsJars := + classPathFilter.value((Runtime / internalDependencyAsJars).value), + Test / internalDependencyClasspath := + classPathFilter.value((Test / internalDependencyClasspath).value), + Test / internalDependencyAsJars := + classPathFilter.value((Test / internalDependencyAsJars).value), + ) +} + /** * Project to pull previous artifacts of Spark for generating Mima excludes. 
*/ @@ -1202,7 +1238,7 @@ object YARN { genConfigProperties := { val file = (Compile / classDirectory).value / s"org/apache/spark/deploy/yarn/$propFileName" val isHadoopProvided = SbtPomKeys.effectivePom.value.getProperties.get(hadoopProvidedProp) - IO.write(file, s"$hadoopProvidedProp = $isHadoopProvided") + sbt.IO.write(file, s"$hadoopProvidedProp = $isHadoopProvided") }, Compile / copyResources := (Def.taskDyn { val c = (Compile / copyResources).value @@ -1352,6 +1388,7 @@ trait SharedUnidocSettings { .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/kvstore"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/connect/"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/classic/"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/execution"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalog/v2/utils"))) @@ -1451,10 +1488,12 @@ object SparkUnidoc extends SharedUnidocSettings { lazy val settings = baseSettings ++ Seq( (ScalaUnidoc / unidoc / unidocProjectFilter) := inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes, - yarn, tags, streamingKafka010, sqlKafka010, connectCommon, connect, connectClient, protobuf), + yarn, tags, streamingKafka010, sqlKafka010, connectCommon, connect, connectClient, + connectShims, protobuf), (JavaUnidoc / unidoc / unidocProjectFilter) := inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes, - yarn, tags, streamingKafka010, sqlKafka010, connectCommon, connect, connectClient, protobuf), + yarn, tags, streamingKafka010, sqlKafka010, connectCommon, connect, connectClient, + connectShims, protobuf), ) } diff --git a/project/plugins.sbt b/project/plugins.sbt index 151af24440c05..67d739452d8da 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -36,9 +36,9 @@ addSbtPlugin("com.github.sbt" % "sbt-unidoc" % "0.5.0") addSbtPlugin("io.spray" % "sbt-revolver" % "0.10.0") -libraryDependencies += "org.ow2.asm" % "asm" % "9.7" +libraryDependencies += "org.ow2.asm" % "asm" % "9.7.1" -libraryDependencies += "org.ow2.asm" % "asm-commons" % "9.7" +libraryDependencies += "org.ow2.asm" % "asm-commons" % "9.7.1" addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.8.3") diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst index 549656bea103e..88c0a8c26cc94 100644 --- a/python/docs/source/getting_started/install.rst +++ b/python/docs/source/getting_started/install.rst @@ -183,6 +183,7 @@ Package Supported version Note Additional libraries that enhance functionality but are not included in the installation packages: - **memory-profiler**: Used for PySpark UDF memory profiling, ``spark.profile.show(...)`` and ``spark.sql.pyspark.udf.profiler``. +- **plotly**: Used for PySpark plotting, ``DataFrame.plot``. Note that PySpark requires Java 17 or later with ``JAVA_HOME`` properly set and refer to |downloading|_. 
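The install.rst note above documents plotly as an optional dependency backing the new ``DataFrame.plot`` accessor added later in this patch. A minimal usage sketch, not part of the diff: it assumes plotly is installed, a running SparkSession named ``spark``, and that the accessor exposes a plotly-backed ``line`` method returning a plotly Figure.

# Hypothetical usage of the DataFrame.plot accessor documented above.
# Assumes plotly is installed and the accessor provides a `line` method.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 2.0), (2, 4.0), (3, 8.0)], ["x", "y"])

# Dispatches to the plotly backend and returns a plotly Figure.
fig = df.plot.line(x="x", y="y")
fig.show()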
diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index 4910a5b59273b..6248e71331656 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -148,6 +148,7 @@ Mathematical Functions try_multiply try_subtract unhex + uniform width_bucket @@ -189,6 +190,7 @@ String Functions overlay position printf + randstr regexp_count regexp_extract regexp_extract_all diff --git a/python/packaging/classic/setup.py b/python/packaging/classic/setup.py index 79b74483f00dd..76fd638c4aa03 100755 --- a/python/packaging/classic/setup.py +++ b/python/packaging/classic/setup.py @@ -288,6 +288,7 @@ def run(self): "pyspark.sql.connect.streaming.worker", "pyspark.sql.functions", "pyspark.sql.pandas", + "pyspark.sql.plot", "pyspark.sql.protobuf", "pyspark.sql.streaming", "pyspark.sql.worker", @@ -373,6 +374,7 @@ def run(self): "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", "Typing :: Typed", diff --git a/python/packaging/connect/setup.py b/python/packaging/connect/setup.py index ab166c79747df..6ae16e9a9ad3a 100755 --- a/python/packaging/connect/setup.py +++ b/python/packaging/connect/setup.py @@ -77,6 +77,7 @@ "pyspark.sql.tests.connect.client", "pyspark.sql.tests.connect.shell", "pyspark.sql.tests.pandas", + "pyspark.sql.tests.plot", "pyspark.sql.tests.streaming", "pyspark.ml.tests.connect", "pyspark.pandas.tests", @@ -161,6 +162,7 @@ "pyspark.sql.connect.streaming.worker", "pyspark.sql.functions", "pyspark.sql.pandas", + "pyspark.sql.plot", "pyspark.sql.protobuf", "pyspark.sql.streaming", "pyspark.sql.worker", diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index 4061d024a83cd..ab01d386645b2 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -94,9 +94,9 @@ "Could not get batch id from ." ] }, - "CANNOT_INFER_ARRAY_TYPE": { + "CANNOT_INFER_ARRAY_ELEMENT_TYPE": { "message": [ - "Can not infer Array Type from a list with None as the first element." + "Can not infer the element data type, an non-empty list starting with an non-None value is required." ] }, "CANNOT_INFER_EMPTY_SCHEMA": { @@ -802,11 +802,21 @@ " >= must be installed; however, it was not found." ] }, + "PANDAS_UDF_OUTPUT_EXCEEDS_INPUT_ROWS" : { + "message": [ + "The Pandas SCALAR_ITER UDF outputs more rows than input rows." + ] + }, "PIPE_FUNCTION_EXITED": { "message": [ "Pipe function `` exited with error code ." ] }, + "PLOT_NOT_NUMERIC_COLUMN": { + "message": [ + "Argument must be a numerical column for plotting, got ." + ] + }, "PYTHON_HASH_SEED_NOT_SET": { "message": [ "Randomness of hash of string should be disabled via PYTHONHASHSEED." @@ -1088,6 +1098,16 @@ "Function `` should use only POSITIONAL or POSITIONAL OR KEYWORD arguments." ] }, + "UNSUPPORTED_PLOT_BACKEND": { + "message": [ + "`` is not supported, it should be one of the values from " + ] + }, + "UNSUPPORTED_PLOT_BACKEND_PARAM": { + "message": [ + "`` does not support `` set to , it should be one of the values from " + ] + }, "UNSUPPORTED_SIGNATURE": { "message": [ "Unsupported signature: ." 
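Among the changes above, the functions reference gains entries for ``uniform`` and ``randstr``, whose Spark Connect implementations appear later in this patch. A hedged usage sketch, assuming the matching classic implementations are registered elsewhere in the change; column names and seed values are illustrative only.

# Sketch of the new random-generation functions listed in functions.rst above.
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.range(5)

result = df.select(
    F.uniform(0, 10, 42).alias("u"),   # pseudo-random value in [0, 10) with a fixed seed
    F.randstr(8, 42).alias("s"),       # pseudo-random 8-character string with a fixed seed
)
result.show()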
diff --git a/python/pyspark/install.py b/python/pyspark/install.py index 90b0150b0a8ca..ba67a157e964d 100644 --- a/python/pyspark/install.py +++ b/python/pyspark/install.py @@ -163,7 +163,7 @@ def install_spark(dest, spark_version, hadoop_version, hive_version): tar.close() if os.path.exists(package_local_path): os.remove(package_local_path) - raise IOError("Unable to download %s." % pretty_pkg_name) + raise OSError("Unable to download %s." % pretty_pkg_name) def get_preferred_mirrors(): diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index cba4219a0694b..72fcfccf19e4c 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -213,7 +213,6 @@ class FPGrowth( | [q]| 2| +---------+----+ only showing top 5 rows - ... >>> fpm.associationRules.sort("antecedent", "consequent").show(5) +----------+----------+----------+----+------------------+ |antecedent|consequent|confidence|lift| support| @@ -225,7 +224,6 @@ class FPGrowth( | [q]| [t]| 1.0| 2.0|0.3333333333333333| +----------+----------+----------+----+------------------+ only showing top 5 rows - ... >>> new_data = spark.createDataFrame([(["t", "s"], )], ["items"]) >>> sorted(fpm.transform(new_data).first().newPrediction) ['x', 'y', 'z'] diff --git a/python/pyspark/pandas/data_type_ops/datetime_ops.py b/python/pyspark/pandas/data_type_ops/datetime_ops.py index 9b4cc72fa2e45..dc2f68232e730 100644 --- a/python/pyspark/pandas/data_type_ops/datetime_ops.py +++ b/python/pyspark/pandas/data_type_ops/datetime_ops.py @@ -34,6 +34,7 @@ ) from pyspark.sql.utils import pyspark_column_op from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex +from pyspark.pandas.spark import functions as SF from pyspark.pandas.base import IndexOpsMixin from pyspark.pandas.data_type_ops.base import ( DataTypeOps, @@ -150,10 +151,7 @@ class DatetimeNTZOps(DatetimeOps): """ def _cast_spark_column_timestamp_to_long(self, scol: Column) -> Column: - from pyspark import SparkContext - - jvm = SparkContext._active_spark_context._jvm - return Column(jvm.PythonSQLUtils.castTimestampNTZToLong(scol._jc)) + return SF.timestamp_ntz_to_long(scol) def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike: dtype, spark_type = pandas_on_spark_type(dtype) diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py index 4be345201ba65..6063641e22e3b 100644 --- a/python/pyspark/pandas/internal.py +++ b/python/pyspark/pandas/internal.py @@ -902,11 +902,10 @@ def attach_default_index( @staticmethod def attach_sequence_column(sdf: PySparkDataFrame, column_name: str) -> PySparkDataFrame: - scols = [scol_for(sdf, column) for column in sdf.columns] sequential_index = ( F.row_number().over(Window.orderBy(F.monotonically_increasing_id())).cast("long") - 1 ) - return sdf.select(sequential_index.alias(column_name), *scols) + return sdf.select(sequential_index.alias(column_name), "*") @staticmethod def attach_distributed_column(sdf: PySparkDataFrame, column_name: str) -> PySparkDataFrame: diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py index 7630ecc398954..7333fae1ad432 100644 --- a/python/pyspark/pandas/plot/core.py +++ b/python/pyspark/pandas/plot/core.py @@ -215,7 +215,7 @@ def compute_hist(psdf, bins): # refers to org.apache.spark.ml.feature.Bucketizer#binarySearchForBuckets def binary_search_for_buckets(value: Column): - index = SF.binary_search(F.lit(bins), value) + index = SF.array_binary_search(F.lit(bins), value) bucket = F.when(index >= 0, 
index).otherwise(-index - 2) unboundErrMsg = F.lit(f"value %s out of the bins bounds: [{bins[0]}, {bins[-1]}]") return ( @@ -479,7 +479,7 @@ class PandasOnSparkPlotAccessor(PandasObject): "pie": TopNPlotBase().get_top_n, "bar": TopNPlotBase().get_top_n, "barh": TopNPlotBase().get_top_n, - "scatter": TopNPlotBase().get_top_n, + "scatter": SampledPlotBase().get_sampled, "area": SampledPlotBase().get_sampled, "line": SampledPlotBase().get_sampled, } @@ -756,10 +756,10 @@ def barh(self, x=None, y=None, **kwargs): Parameters ---------- - x : label or position, default DataFrame.index - Column to be used for categories. - y : label or position, default All numeric columns in dataframe + x : label or position, default All numeric columns in dataframe Columns to be plotted from the DataFrame. + y : label or position, default DataFrame.index + Column to be used for categories. **kwds Keyword arguments to pass on to :meth:`pyspark.pandas.DataFrame.plot` or :meth:`pyspark.pandas.Series.plot`. @@ -770,6 +770,13 @@ def barh(self, x=None, y=None, **kwargs): Return an custom object when ``backend!=plotly``. Return an ndarray when ``subplots=True`` (matplotlib-only). + Notes + ----- + In Plotly and Matplotlib, the interpretation of `x` and `y` for `barh` plots differs. + In Plotly, `x` refers to the values and `y` refers to the categories. + In Matplotlib, `x` refers to the categories and `y` refers to the values. + Ensure correct axis labeling based on the backend used. + See Also -------- plotly.express.bar : Plot a vertical bar plot using plotly. diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py index 4bcf07f6f6503..bdd11559df3b6 100644 --- a/python/pyspark/pandas/spark/functions.py +++ b/python/pyspark/pandas/spark/functions.py @@ -19,197 +19,76 @@ """ from pyspark.sql import Column, functions as F from pyspark.sql.utils import is_remote -from typing import Union +from typing import Union, TYPE_CHECKING +if TYPE_CHECKING: + from pyspark.sql._typing import ColumnOrName -def product(col: Column, dropna: bool) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit - - return _invoke_function_over_columns( - "pandas_product", - col, - lit(dropna), - ) - - else: - from pyspark import SparkContext - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.pandasProduct(col._jc, dropna)) - - -def stddev(col: Column, ddof: int) -> Column: +def _invoke_internal_function_over_columns(name: str, *cols: "ColumnOrName") -> Column: if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit + from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns - return _invoke_function_over_columns( - "pandas_stddev", - col, - lit(ddof), - ) + return _invoke_function_over_columns(name, *cols) else: + from pyspark.sql.classic.column import _to_seq, _to_java_column from pyspark import SparkContext sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof)) + return Column(sc._jvm.PythonSQLUtils.internalFn(name, _to_seq(sc, cols, _to_java_column))) -def var(col: Column, ddof: int) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit +def timestamp_ntz_to_long(col: Column) -> Column: + return _invoke_internal_function_over_columns("timestamp_ntz_to_long", col) - return _invoke_function_over_columns( - "pandas_var", - col, - 
lit(ddof), - ) - else: - from pyspark import SparkContext +def product(col: Column, dropna: bool) -> Column: + return _invoke_internal_function_over_columns("pandas_product", col, F.lit(dropna)) - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.pandasVariance(col._jc, ddof)) +def stddev(col: Column, ddof: int) -> Column: + return _invoke_internal_function_over_columns("pandas_stddev", col, F.lit(ddof)) -def skew(col: Column) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns - return _invoke_function_over_columns( - "pandas_skew", - col, - ) +def var(col: Column, ddof: int) -> Column: + return _invoke_internal_function_over_columns("pandas_var", col, F.lit(ddof)) - else: - from pyspark import SparkContext - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.pandasSkewness(col._jc)) +def skew(col: Column) -> Column: + return _invoke_internal_function_over_columns("pandas_skew", col) def kurt(col: Column) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns - - return _invoke_function_over_columns( - "pandas_kurt", - col, - ) - - else: - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.pandasKurtosis(col._jc)) + return _invoke_internal_function_over_columns("pandas_kurt", col) def mode(col: Column, dropna: bool) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit - - return _invoke_function_over_columns( - "pandas_mode", - col, - lit(dropna), - ) - - else: - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.pandasMode(col._jc, dropna)) + return _invoke_internal_function_over_columns("pandas_mode", col, F.lit(dropna)) def covar(col1: Column, col2: Column, ddof: int) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit + return _invoke_internal_function_over_columns("pandas_covar", col1, col2, F.lit(ddof)) - return _invoke_function_over_columns( - "pandas_covar", - col1, - col2, - lit(ddof), - ) - else: - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.pandasCovar(col1._jc, col2._jc, ddof)) - - -def ewm(col: Column, alpha: float, ignore_na: bool) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit - - return _invoke_function_over_columns( - "ewm", - col, - lit(alpha), - lit(ignore_na), - ) - - else: - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.ewm(col._jc, alpha, ignore_na)) +def ewm(col: Column, alpha: float, ignorena: bool) -> Column: + return _invoke_internal_function_over_columns("ewm", col, F.lit(alpha), F.lit(ignorena)) def null_index(col: Column) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns - - return _invoke_function_over_columns( - "null_index", - col, - ) - - else: - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc)) + return _invoke_internal_function_over_columns("null_index", col) def distributed_sequence_id() -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import 
_invoke_function - - return _invoke_function("distributed_sequence_id") - else: - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.distributed_sequence_id()) + return _invoke_internal_function_over_columns("distributed_sequence_id") def collect_top_k(col: Column, num: int, reverse: bool) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns + return _invoke_internal_function_over_columns("collect_top_k", col, F.lit(num), F.lit(reverse)) - return _invoke_function_over_columns("collect_top_k", col, F.lit(num), F.lit(reverse)) - else: - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.collect_top_k(col._jc, num, reverse)) - - -def binary_search(col: Column, value: Column) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns - - return _invoke_function_over_columns("array_binary_search", col, value) - - else: - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.binary_search(col._jc, value._jc)) +def array_binary_search(col: Column, value: Column) -> Column: + return _invoke_internal_function_over_columns("array_binary_search", col, value) def make_interval(unit: str, e: Union[Column, int, float]) -> Column: diff --git a/python/pyspark/pandas/supported_api_gen.py b/python/pyspark/pandas/supported_api_gen.py index bbf0b3cbc3d67..f2a73cb1c1adf 100644 --- a/python/pyspark/pandas/supported_api_gen.py +++ b/python/pyspark/pandas/supported_api_gen.py @@ -38,7 +38,7 @@ MAX_MISSING_PARAMS_SIZE = 5 COMMON_PARAMETER_SET = {"kwargs", "args", "cls"} MODULE_GROUP_MATCH = [(pd, ps), (pdw, psw), (pdg, psg)] -PANDAS_LATEST_VERSION = "2.2.2" +PANDAS_LATEST_VERSION = "2.2.3" RST_HEADER = """ ===================== diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py index 37469db2c8f51..8d197649aaebe 100644 --- a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py @@ -105,9 +105,10 @@ def check_barh_plot_with_x_y(pdf, psdf, x, y): self.assertEqual(pdf.plot.barh(x=x, y=y), psdf.plot.barh(x=x, y=y)) # this is testing plot with specified x and y - pdf1 = pd.DataFrame({"lab": ["A", "B", "C"], "val": [10, 30, 20]}) + pdf1 = pd.DataFrame({"lab": ["A", "B", "C"], "val": [10, 30, 20], "val2": [1.1, 2.2, 3.3]}) psdf1 = ps.from_pandas(pdf1) - check_barh_plot_with_x_y(pdf1, psdf1, x="lab", y="val") + check_barh_plot_with_x_y(pdf1, psdf1, x="val", y="lab") + check_barh_plot_with_x_y(pdf1, psdf1, x=["val", "val2"], y="lab") def test_barh_plot(self): def check_barh_plot(pdf, psdf): diff --git a/python/pyspark/pandas/window.py b/python/pyspark/pandas/window.py index 0aaeb7df89be5..fb5dd29169e91 100644 --- a/python/pyspark/pandas/window.py +++ b/python/pyspark/pandas/window.py @@ -2434,7 +2434,8 @@ def _compute_unified_alpha(self) -> float: if opt_count != 1: raise ValueError("com, span, halflife, and alpha are mutually exclusive") - return unified_alpha + # convert possible numpy.float64 to float for lit function + return float(unified_alpha) @abstractmethod def _apply_as_series_or_frame(self, func: Callable[[Column], Column]) -> FrameLike: diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py index 91b9591625904..91dec609e522a 
100644 --- a/python/pyspark/sql/classic/dataframe.py +++ b/python/pyspark/sql/classic/dataframe.py @@ -55,6 +55,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.column import Column +from pyspark.sql.functions import builtin as F from pyspark.sql.classic.column import _to_seq, _to_list, _to_java_column from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2 from pyspark.sql.merge import MergeIntoWriter @@ -73,6 +74,11 @@ from pyspark.sql.pandas.conversion import PandasConversionMixin from pyspark.sql.pandas.map_ops import PandasMapOpsMixin +try: + from pyspark.sql.plot import PySparkPlotAccessor +except ImportError: + PySparkPlotAccessor = None # type: ignore + if TYPE_CHECKING: from py4j.java_gateway import JavaObject import pyarrow as pa @@ -354,8 +360,13 @@ def checkpoint(self, eager: bool = True) -> ParentDataFrame: jdf = self._jdf.checkpoint(eager) return DataFrame(jdf, self.sparkSession) - def localCheckpoint(self, eager: bool = True) -> ParentDataFrame: - jdf = self._jdf.localCheckpoint(eager) + def localCheckpoint( + self, eager: bool = True, storageLevel: Optional[StorageLevel] = None + ) -> ParentDataFrame: + if storageLevel is None: + jdf = self._jdf.localCheckpoint(eager) + else: + jdf = self._jdf.localCheckpoint(eager, self._sc._getJavaStorageLevel(storageLevel)) return DataFrame(jdf, self.sparkSession) def withWatermark(self, eventTime: str, delayThreshold: str) -> ParentDataFrame: @@ -595,44 +606,8 @@ def sample( # type: ignore[misc] fraction: Optional[Union[int, float]] = None, seed: Optional[int] = None, ) -> ParentDataFrame: - # For the cases below: - # sample(True, 0.5 [, seed]) - # sample(True, fraction=0.5 [, seed]) - # sample(withReplacement=False, fraction=0.5 [, seed]) - is_withReplacement_set = type(withReplacement) == bool and isinstance(fraction, float) - - # For the case below: - # sample(faction=0.5 [, seed]) - is_withReplacement_omitted_kwargs = withReplacement is None and isinstance(fraction, float) - - # For the case below: - # sample(0.5 [, seed]) - is_withReplacement_omitted_args = isinstance(withReplacement, float) - - if not ( - is_withReplacement_set - or is_withReplacement_omitted_kwargs - or is_withReplacement_omitted_args - ): - argtypes = [type(arg).__name__ for arg in [withReplacement, fraction, seed]] - raise PySparkTypeError( - errorClass="NOT_BOOL_OR_FLOAT_OR_INT", - messageParameters={ - "arg_name": "withReplacement (optional), " - + "fraction (required) and seed (optional)", - "arg_type": ", ".join(argtypes), - }, - ) - - if is_withReplacement_omitted_args: - if fraction is not None: - seed = cast(int, fraction) - fraction = withReplacement - withReplacement = None - - seed = int(seed) if seed is not None else None - args = [arg for arg in [withReplacement, fraction, seed] if arg is not None] - jdf = self._jdf.sample(*args) + _w, _f, _s = self._preapare_args_for_sample(withReplacement, fraction, seed) + jdf = self._jdf.sample(*[_w, _f, _s]) return DataFrame(jdf, self.sparkSession) def sampleBy( @@ -868,7 +843,8 @@ def sortWithinPartitions( *cols: Union[int, str, Column, List[Union[int, str, Column]]], **kwargs: Any, ) -> ParentDataFrame: - jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs)) + _cols = self._preapare_cols_for_sort(F.col, cols, kwargs) + jdf = self._jdf.sortWithinPartitions(self._jseq(_cols, _to_java_column)) return DataFrame(jdf, self.sparkSession) def sort( @@ -876,7 +852,8 @@ def sort( *cols: Union[int, str, Column, 
List[Union[int, str, Column]]], **kwargs: Any, ) -> ParentDataFrame: - jdf = self._jdf.sort(self._sort_cols(cols, kwargs)) + _cols = self._preapare_cols_for_sort(F.col, cols, kwargs) + jdf = self._jdf.sort(self._jseq(_cols, _to_java_column)) return DataFrame(jdf, self.sparkSession) orderBy = sort @@ -923,51 +900,6 @@ def _jcols_ordinal(self, *cols: "ColumnOrNameOrOrdinal") -> "JavaObject": _cols.append(c) # type: ignore[arg-type] return self._jseq(_cols, _to_java_column) - def _sort_cols( - self, - cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]], - kwargs: Dict[str, Any], - ) -> "JavaObject": - """Return a JVM Seq of Columns that describes the sort order""" - if not cols: - raise PySparkValueError( - errorClass="CANNOT_BE_EMPTY", - messageParameters={"item": "column"}, - ) - if len(cols) == 1 and isinstance(cols[0], list): - cols = cols[0] - - jcols = [] - for c in cols: - if isinstance(c, int) and not isinstance(c, bool): - # ordinal is 1-based - if c > 0: - _c = self[c - 1] - # negative ordinal means sort by desc - elif c < 0: - _c = self[-c - 1].desc() - else: - raise PySparkIndexError( - errorClass="ZERO_INDEX", - messageParameters={}, - ) - else: - _c = c # type: ignore[assignment] - jcols.append(_to_java_column(cast("ColumnOrName", _c))) - - ascending = kwargs.get("ascending", True) - if isinstance(ascending, (bool, int)): - if not ascending: - jcols = [jc.desc() for jc in jcols] - elif isinstance(ascending, list): - jcols = [jc if asc else jc.desc() for asc, jc in zip(ascending, jcols)] - else: - raise PySparkTypeError( - errorClass="NOT_BOOL_OR_LIST", - messageParameters={"arg_name": "ascending", "arg_type": type(ascending).__name__}, - ) - return self._jseq(jcols) - def describe(self, *cols: Union[str, List[str]]) -> ParentDataFrame: if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] # type: ignore[assignment] @@ -1063,7 +995,7 @@ def selectExpr(self, *expr: Union[str, List[str]]) -> ParentDataFrame: jdf = self._jdf.selectExpr(self._jseq(expr)) return DataFrame(jdf, self.sparkSession) - def filter(self, condition: "ColumnOrName") -> ParentDataFrame: + def filter(self, condition: Union[Column, str]) -> ParentDataFrame: if isinstance(condition, str): jdf = self._jdf.filter(condition) elif isinstance(condition, Column): @@ -1782,7 +1714,7 @@ def semanticHash(self) -> int: def inputFiles(self) -> List[str]: return list(self._jdf.inputFiles()) - def where(self, condition: "ColumnOrName") -> ParentDataFrame: + def where(self, condition: Union[Column, str]) -> ParentDataFrame: return self.filter(condition) # Two aliases below were added for pandas compatibility many years ago. 
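The classic DataFrame diff above adds an optional ``storageLevel`` argument to ``localCheckpoint`` (mirrored for Spark Connect later in this patch). A minimal caller-side sketch, assuming a running SparkSession named ``spark``; the data is illustrative.

# Sketch of the new optional storageLevel argument for localCheckpoint.
from pyspark import StorageLevel
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.range(1000)

# Omitting storageLevel keeps the previous default behaviour; passing one
# persists the checkpointed data at the requested level.
checkpointed = df.localCheckpoint(eager=True, storageLevel=StorageLevel.DISK_ONLY)
print(checkpointed.count())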
@@ -1804,10 +1736,10 @@ def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": # type: ign def drop_duplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame: return self.dropDuplicates(subset) - def writeTo(self, table: str) -> DataFrameWriterV2: + def writeTo(self, table: str) -> "DataFrameWriterV2": return DataFrameWriterV2(self, table) - def mergeInto(self, table: str, condition: Column) -> MergeIntoWriter: + def mergeInto(self, table: str, condition: Column) -> "MergeIntoWriter": return MergeIntoWriter(self, table, condition) def pandas_api( @@ -1862,6 +1794,10 @@ def executionInfo(self) -> Optional["ExecutionInfo"]: messageParameters={"member": "queryExecution"}, ) + @property + def plot(self) -> PySparkPlotAccessor: + return PySparkPlotAccessor(self) + class DataFrameNaFunctions(ParentDataFrameNaFunctions): def __init__(self, df: ParentDataFrame): diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index ea6788e858317..e0c7cc448933d 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -24,7 +24,7 @@ import uuid from collections.abc import Generator from typing import Optional, Any, Iterator, Iterable, Tuple, Callable, cast, Type, ClassVar -from multiprocessing.pool import ThreadPool +from concurrent.futures import ThreadPoolExecutor import os import grpc @@ -58,19 +58,18 @@ class ExecutePlanResponseReattachableIterator(Generator): # Lock to manage the pool _lock: ClassVar[RLock] = RLock() - _release_thread_pool_instance: Optional[ThreadPool] = None + _release_thread_pool_instance: Optional[ThreadPoolExecutor] = None @classmethod # type: ignore[misc] @property - def _release_thread_pool(cls) -> ThreadPool: + def _release_thread_pool(cls) -> ThreadPoolExecutor: # Perform a first check outside the critical path. 
if cls._release_thread_pool_instance is not None: return cls._release_thread_pool_instance with cls._lock: if cls._release_thread_pool_instance is None: - cls._release_thread_pool_instance = ThreadPool( - os.cpu_count() if os.cpu_count() else 8 - ) + max_workers = os.cpu_count() or 8 + cls._release_thread_pool_instance = ThreadPoolExecutor(max_workers=max_workers) return cls._release_thread_pool_instance @classmethod @@ -81,8 +80,7 @@ def shutdown(cls: Type["ExecutePlanResponseReattachableIterator"]) -> None: """ with cls._lock: if cls._release_thread_pool_instance is not None: - cls._release_thread_pool.close() # type: ignore[attr-defined] - cls._release_thread_pool.join() # type: ignore[attr-defined] + cls._release_thread_pool.shutdown() # type: ignore[attr-defined] cls._release_thread_pool_instance = None def __init__( @@ -212,7 +210,7 @@ def target() -> None: with self._lock: if self._release_thread_pool_instance is not None: - self._release_thread_pool.apply_async(target) + self._release_thread_pool.submit(target) def _release_all(self) -> None: """ @@ -237,7 +235,7 @@ def target() -> None: with self._lock: if self._release_thread_pool_instance is not None: - self._release_thread_pool.apply_async(target) + self._release_thread_pool.submit(target) self._result_complete = True def _call_iter(self, iter_fun: Callable) -> Any: diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 768abd655d497..3d5b845fcd24c 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -86,6 +86,10 @@ from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema from pyspark.sql.pandas.functions import _validate_pandas_udf # type: ignore[attr-defined] +try: + from pyspark.sql.plot import PySparkPlotAccessor +except ImportError: + PySparkPlotAccessor = None # type: ignore if TYPE_CHECKING: from pyspark.sql.connect._typing import ( @@ -531,7 +535,7 @@ def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": def groupby(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData": ... - def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> GroupedData: + def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] @@ -566,7 +570,7 @@ def rollup(self, *cols: "ColumnOrName") -> "GroupedData": def rollup(self, __cols: Union[List[Column], List[str]]) -> "GroupedData": ... 
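The reattach.py hunks above swap multiprocessing.pool.ThreadPool for concurrent.futures.ThreadPoolExecutor behind the same lazily-created, lock-protected class attribute. A standalone sketch of that pattern; class and method names are illustrative and not the actual client code.

# Standalone illustration of the lazily-initialized, lock-protected executor
# pattern used in reattach.py above.
import os
from concurrent.futures import ThreadPoolExecutor
from threading import RLock
from typing import Optional


class ReleasePool:
    _lock = RLock()
    _executor: Optional[ThreadPoolExecutor] = None

    @classmethod
    def get(cls) -> ThreadPoolExecutor:
        # Double-checked locking: cheap check first, then create under the lock.
        if cls._executor is None:
            with cls._lock:
                if cls._executor is None:
                    cls._executor = ThreadPoolExecutor(max_workers=os.cpu_count() or 8)
        return cls._executor

    @classmethod
    def shutdown(cls) -> None:
        with cls._lock:
            if cls._executor is not None:
                # shutdown() replaces the old close()/join() pair of ThreadPool.
                cls._executor.shutdown()
                cls._executor = None


# submit() is the ThreadPoolExecutor counterpart of ThreadPool.apply_async().
ReleasePool.get().submit(print, "release request sent")
ReleasePool.shutdown()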
- def rollup(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] + def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": # type: ignore[misc] _cols: List[Column] = [] for c in cols: if isinstance(c, Column): @@ -727,70 +731,24 @@ def _convert_col(df: ParentDataFrame, col: "ColumnOrName") -> Column: session=self._session, ) - def limit(self, n: int) -> ParentDataFrame: - res = DataFrame(plan.Limit(child=self._plan, limit=n), session=self._session) + def limit(self, num: int) -> ParentDataFrame: + res = DataFrame(plan.Limit(child=self._plan, limit=num), session=self._session) res._cached_schema = self._cached_schema return res def tail(self, num: int) -> List[Row]: return DataFrame(plan.Tail(child=self._plan, limit=num), session=self._session).collect() - def _sort_cols( - self, - cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]], - kwargs: Dict[str, Any], - ) -> List[Column]: - """Return a JVM Seq of Columns that describes the sort order""" - if cols is None: - raise PySparkValueError( - errorClass="CANNOT_BE_EMPTY", - messageParameters={"item": "cols"}, - ) - - if len(cols) == 1 and isinstance(cols[0], list): - cols = cols[0] - - _cols: List[Column] = [] - for c in cols: - if isinstance(c, int) and not isinstance(c, bool): - # ordinal is 1-based - if c > 0: - _c = self[c - 1] - # negative ordinal means sort by desc - elif c < 0: - _c = self[-c - 1].desc() - else: - raise PySparkIndexError( - errorClass="ZERO_INDEX", - messageParameters={}, - ) - else: - _c = c # type: ignore[assignment] - _cols.append(F._to_col(cast("ColumnOrName", _c))) - - ascending = kwargs.get("ascending", True) - if isinstance(ascending, (bool, int)): - if not ascending: - _cols = [c.desc() for c in _cols] - elif isinstance(ascending, list): - _cols = [c if asc else c.desc() for asc, c in zip(ascending, _cols)] - else: - raise PySparkTypeError( - errorClass="NOT_BOOL_OR_LIST", - messageParameters={"arg_name": "ascending", "arg_type": type(ascending).__name__}, - ) - - return [F._sort_col(c) for c in _cols] - def sort( self, *cols: Union[int, str, Column, List[Union[int, str, Column]]], **kwargs: Any, ) -> ParentDataFrame: + _cols = self._preapare_cols_for_sort(F.col, cols, kwargs) res = DataFrame( plan.Sort( self._plan, - columns=self._sort_cols(cols, kwargs), + columns=[F._sort_col(c) for c in _cols], is_global=True, ), session=self._session, @@ -805,10 +763,11 @@ def sortWithinPartitions( *cols: Union[int, str, Column, List[Union[int, str, Column]]], **kwargs: Any, ) -> ParentDataFrame: + _cols = self._preapare_cols_for_sort(F.col, cols, kwargs) res = DataFrame( plan.Sort( self._plan, - columns=self._sort_cols(cols, kwargs), + columns=[F._sort_col(c) for c in _cols], is_global=False, ), session=self._session, @@ -822,53 +781,14 @@ def sample( fraction: Optional[Union[int, float]] = None, seed: Optional[int] = None, ) -> ParentDataFrame: - # For the cases below: - # sample(True, 0.5 [, seed]) - # sample(True, fraction=0.5 [, seed]) - # sample(withReplacement=False, fraction=0.5 [, seed]) - is_withReplacement_set = type(withReplacement) == bool and isinstance(fraction, float) - - # For the case below: - # sample(faction=0.5 [, seed]) - is_withReplacement_omitted_kwargs = withReplacement is None and isinstance(fraction, float) - - # For the case below: - # sample(0.5 [, seed]) - is_withReplacement_omitted_args = isinstance(withReplacement, float) - - if not ( - is_withReplacement_set - or is_withReplacement_omitted_kwargs - or is_withReplacement_omitted_args - ): - 
argtypes = [type(arg).__name__ for arg in [withReplacement, fraction, seed]] - raise PySparkTypeError( - errorClass="NOT_BOOL_OR_FLOAT_OR_INT", - messageParameters={ - "arg_name": "withReplacement (optional), " - + "fraction (required) and seed (optional)", - "arg_type": ", ".join(argtypes), - }, - ) - - if is_withReplacement_omitted_args: - if fraction is not None: - seed = cast(int, fraction) - fraction = withReplacement - withReplacement = None - - if withReplacement is None: - withReplacement = False - - seed = int(seed) if seed is not None else random.randint(0, sys.maxsize) - + _w, _f, _s = self._preapare_args_for_sample(withReplacement, fraction, seed) res = DataFrame( plan.Sample( child=self._plan, lower_bound=0.0, - upper_bound=fraction, # type: ignore[arg-type] - with_replacement=withReplacement, # type: ignore[arg-type] - seed=seed, + upper_bound=_f, + with_replacement=_w, + seed=_s, ), session=self._session, ) @@ -927,7 +847,11 @@ def _show_string( )._to_table() return table[0][0].as_py() - def withColumns(self, colsMap: Dict[str, Column]) -> ParentDataFrame: + def withColumns(self, *colsMap: Dict[str, Column]) -> ParentDataFrame: + # Below code is to help enable kwargs in future. + assert len(colsMap) == 1 + colsMap = colsMap[0] # type: ignore[assignment] + if not isinstance(colsMap, dict): raise PySparkTypeError( errorClass="NOT_DICT", @@ -2189,7 +2113,7 @@ def cb(ei: "ExecutionInfo") -> None: return DataFrameWriterV2(self._plan, self._session, table, cb) - def mergeInto(self, table: str, condition: Column) -> MergeIntoWriter: + def mergeInto(self, table: str, condition: Column) -> "MergeIntoWriter": def cb(ei: "ExecutionInfo") -> None: self._execution_info = ei @@ -2197,10 +2121,10 @@ def cb(ei: "ExecutionInfo") -> None: self._plan, self._session, table, condition, cb # type: ignore[arg-type] ) - def offset(self, n: int) -> ParentDataFrame: - return DataFrame(plan.Offset(child=self._plan, offset=n), session=self._session) + def offset(self, num: int) -> ParentDataFrame: + return DataFrame(plan.Offset(child=self._plan, offset=num), session=self._session) - def checkpoint(self, eager: bool = True) -> "DataFrame": + def checkpoint(self, eager: bool = True) -> ParentDataFrame: cmd = plan.Checkpoint(child=self._plan, local=False, eager=eager) _, properties, self._execution_info = self._session.client.execute_command( cmd.command(self._session.client) @@ -2210,8 +2134,10 @@ def checkpoint(self, eager: bool = True) -> "DataFrame": assert isinstance(checkpointed._plan, plan.CachedRemoteRelation) return checkpointed - def localCheckpoint(self, eager: bool = True) -> "DataFrame": - cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager) + def localCheckpoint( + self, eager: bool = True, storageLevel: Optional[StorageLevel] = None + ) -> ParentDataFrame: + cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager, storage_level=storageLevel) _, properties, self._execution_info = self._session.client.execute_command( cmd.command(self._session.client) ) @@ -2239,6 +2165,10 @@ def rdd(self) -> "RDD[Row]": def executionInfo(self) -> Optional["ExecutionInfo"]: return self._execution_info + @property + def plot(self) -> PySparkPlotAccessor: + return PySparkPlotAccessor(self) + class DataFrameNaFunctions(ParentDataFrameNaFunctions): def __init__(self, df: ParentDataFrame): diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index db1cd1c013be5..203b6ce371a5c 100644 --- a/python/pyspark/sql/connect/expressions.py +++ 
b/python/pyspark/sql/connect/expressions.py @@ -301,7 +301,7 @@ def _infer_type(cls, value: Any) -> DataType: return NullType() elif isinstance(value, (bytes, bytearray)): return BinaryType() - elif isinstance(value, bool): + elif isinstance(value, (bool, np.bool_)): return BooleanType() elif isinstance(value, int): if JVM_INT_MIN <= value <= JVM_INT_MAX: @@ -323,10 +323,8 @@ def _infer_type(cls, value: Any) -> DataType: return StringType() elif isinstance(value, decimal.Decimal): return DecimalType() - elif isinstance(value, datetime.datetime) and is_timestamp_ntz_preferred(): - return TimestampNTZType() elif isinstance(value, datetime.datetime): - return TimestampType() + return TimestampNTZType() if is_timestamp_ntz_preferred() else TimestampType() elif isinstance(value, datetime.date): return DateType() elif isinstance(value, datetime.timedelta): @@ -335,23 +333,15 @@ def _infer_type(cls, value: Any) -> DataType: dt = _from_numpy_type(value.dtype) if dt is not None: return dt - elif isinstance(value, np.bool_): - return BooleanType() elif isinstance(value, list): # follow the 'infer_array_from_first_element' strategy in 'sql.types._infer_type' # right now, it's dedicated for pyspark.ml params like array<...>, array> - if len(value) == 0: - raise PySparkValueError( - errorClass="CANNOT_BE_EMPTY", - messageParameters={"item": "value"}, - ) - first = value[0] - if first is None: + if len(value) == 0 or value[0] is None: raise PySparkTypeError( - errorClass="CANNOT_INFER_ARRAY_TYPE", + errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE", messageParameters={}, ) - return ArrayType(LiteralExpression._infer_type(first), True) + return ArrayType(LiteralExpression._infer_type(value[0]), True) raise PySparkTypeError( errorClass="UNSUPPORTED_DATA_TYPE", @@ -477,8 +467,30 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": def __repr__(self) -> str: if self._value is None: return "NULL" - else: - return f"{self._value}" + elif isinstance(self._dataType, DateType): + dt = DateType().fromInternal(self._value) + if dt is not None and isinstance(dt, datetime.date): + return dt.strftime("%Y-%m-%d") + elif isinstance(self._dataType, TimestampType): + ts = TimestampType().fromInternal(self._value) + if ts is not None and isinstance(ts, datetime.datetime): + return ts.strftime("%Y-%m-%d %H:%M:%S.%f") + elif isinstance(self._dataType, TimestampNTZType): + ts = TimestampNTZType().fromInternal(self._value) + if ts is not None and isinstance(ts, datetime.datetime): + return ts.strftime("%Y-%m-%d %H:%M:%S.%f") + elif isinstance(self._dataType, DayTimeIntervalType): + delta = DayTimeIntervalType().fromInternal(self._value) + if delta is not None and isinstance(delta, datetime.timedelta): + import pandas as pd + + # Note: timedelta itself does not provide isoformat method. + # Both Pandas and java.time.Duration provide it, but the format + # is sightly different: + # java.time.Duration only applies HOURS, MINUTES, SECONDS units, + # while Pandas applies all supported units. 
+ return pd.Timedelta(delta).isoformat() # type: ignore[attr-defined] + return f"{self._value}" class ColumnReference(Expression): @@ -787,7 +799,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression: return expr def __repr__(self) -> str: - return f"WithField({self._structExpr}, {self._fieldName}, {self._valueExpr})" + return f"update_field({self._structExpr}, {self._fieldName}, {self._valueExpr})" class DropField(Expression): @@ -811,7 +823,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression: return expr def __repr__(self) -> str: - return f"DropField({self._structExpr}, {self._fieldName})" + return f"drop_field({self._structExpr}, {self._fieldName})" class UnresolvedExtractValue(Expression): @@ -835,7 +847,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression: return expr def __repr__(self) -> str: - return f"UnresolvedExtractValue({str(self._child)}, {str(self._extraction)})" + return f"{self._child}['{self._extraction}']" class UnresolvedRegex(Expression): diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index 031e7c22542d2..db12e085468a0 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -65,7 +65,6 @@ from pyspark.sql.types import ( _from_numpy_type, DataType, - LongType, StructType, ArrayType, StringType, @@ -1008,6 +1007,22 @@ def unhex(col: "ColumnOrName") -> Column: unhex.__doc__ = pysparkfuncs.unhex.__doc__ +def uniform( + min: Union[Column, int, float], + max: Union[Column, int, float], + seed: Optional[Union[Column, int]] = None, +) -> Column: + if seed is None: + return _invoke_function_over_columns( + "uniform", lit(min), lit(max), lit(random.randint(0, sys.maxsize)) + ) + else: + return _invoke_function_over_columns("uniform", lit(min), lit(max), lit(seed)) + + +uniform.__doc__ = pysparkfuncs.uniform.__doc__ + + def approxCountDistinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Column: warnings.warn("Deprecated in 3.4, use approx_count_distinct instead.", FutureWarning) return approx_count_distinct(col, rsd) @@ -1126,11 +1141,12 @@ def grouping_id(*cols: "ColumnOrName") -> Column: def count_min_sketch( col: "ColumnOrName", - eps: "ColumnOrName", - confidence: "ColumnOrName", - seed: "ColumnOrName", + eps: Union[Column, float], + confidence: Union[Column, float], + seed: Optional[Union[Column, int]] = None, ) -> Column: - return _invoke_function_over_columns("count_min_sketch", col, eps, confidence, seed) + _seed = lit(random.randint(0, sys.maxsize)) if seed is None else lit(seed) + return _invoke_function_over_columns("count_min_sketch", col, lit(eps), lit(confidence), _seed) count_min_sketch.__doc__ = pysparkfuncs.count_min_sketch.__doc__ @@ -2204,12 +2220,9 @@ def schema_of_xml(xml: Union[str, Column], options: Optional[Mapping[str, str]] schema_of_xml.__doc__ = pysparkfuncs.schema_of_xml.__doc__ -def shuffle(col: "ColumnOrName") -> Column: - return _invoke_function( - "shuffle", - _to_col(col), - LiteralExpression(random.randint(0, sys.maxsize), LongType()), - ) +def shuffle(col: "ColumnOrName", seed: Optional[Union[Column, int]] = None) -> Column: + _seed = lit(random.randint(0, sys.maxsize)) if seed is None else lit(seed) + return _invoke_function("shuffle", _to_col(col), _seed) shuffle.__doc__ = pysparkfuncs.shuffle.__doc__ @@ -2381,22 +2394,31 @@ def unbase64(col: "ColumnOrName") -> Column: unbase64.__doc__ = pysparkfuncs.unbase64.__doc__ -def ltrim(col: 
"ColumnOrName") -> Column: - return _invoke_function_over_columns("ltrim", col) +def ltrim(col: "ColumnOrName", trim: Optional["ColumnOrName"] = None) -> Column: + if trim is not None: + return _invoke_function_over_columns("ltrim", trim, col) + else: + return _invoke_function_over_columns("ltrim", col) ltrim.__doc__ = pysparkfuncs.ltrim.__doc__ -def rtrim(col: "ColumnOrName") -> Column: - return _invoke_function_over_columns("rtrim", col) +def rtrim(col: "ColumnOrName", trim: Optional["ColumnOrName"] = None) -> Column: + if trim is not None: + return _invoke_function_over_columns("rtrim", trim, col) + else: + return _invoke_function_over_columns("rtrim", col) rtrim.__doc__ = pysparkfuncs.rtrim.__doc__ -def trim(col: "ColumnOrName") -> Column: - return _invoke_function_over_columns("trim", col) +def trim(col: "ColumnOrName", trim: Optional["ColumnOrName"] = None) -> Column: + if trim is not None: + return _invoke_function_over_columns("trim", trim, col) + else: + return _invoke_function_over_columns("trim", col) trim.__doc__ = pysparkfuncs.trim.__doc__ @@ -2488,8 +2510,14 @@ def sentences( sentences.__doc__ = pysparkfuncs.sentences.__doc__ -def substring(str: "ColumnOrName", pos: int, len: int) -> Column: - return _invoke_function("substring", _to_col(str), lit(pos), lit(len)) +def substring( + str: "ColumnOrName", + pos: Union["ColumnOrName", int], + len: Union["ColumnOrName", int], +) -> Column: + _pos = lit(pos) if isinstance(pos, int) else _to_col(pos) + _len = lit(len) if isinstance(len, int) else _to_col(len) + return _invoke_function("substring", _to_col(str), _pos, _len) substring.__doc__ = pysparkfuncs.substring.__doc__ @@ -2578,6 +2606,18 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: regexp_like.__doc__ = pysparkfuncs.regexp_like.__doc__ +def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = None) -> Column: + if seed is None: + return _invoke_function_over_columns( + "randstr", lit(length), lit(random.randint(0, sys.maxsize)) + ) + else: + return _invoke_function_over_columns("randstr", lit(length), lit(seed)) + + +randstr.__doc__ = pysparkfuncs.randstr.__doc__ + + def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: return _invoke_function_over_columns("regexp_count", str, regexp) diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index fbed0eabc684f..b74f863db1e83 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -1868,21 +1868,29 @@ def command(self, session: "SparkConnectClient") -> proto.Command: class Checkpoint(LogicalPlan): - def __init__(self, child: Optional["LogicalPlan"], local: bool, eager: bool) -> None: + def __init__( + self, + child: Optional["LogicalPlan"], + local: bool, + eager: bool, + storage_level: Optional[StorageLevel] = None, + ) -> None: super().__init__(child) self._local = local self._eager = eager + self._storage_level = storage_level def command(self, session: "SparkConnectClient") -> proto.Command: cmd = proto.Command() assert self._child is not None - cmd.checkpoint_command.CopyFrom( - proto.CheckpointCommand( - relation=self._child.plan(session), - local=self._local, - eager=self._eager, - ) + checkpoint_command = proto.CheckpointCommand( + relation=self._child.plan(session), + local=self._local, + eager=self._eager, ) + if self._storage_level is not None: + checkpoint_command.storage_level.CopyFrom(storage_level_to_proto(self._storage_level)) + cmd.checkpoint_command.CopyFrom(checkpoint_command) 
return cmd diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py index 43390ffa36d33..562e9d817f5fe 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.py +++ b/python/pyspark/sql/connect/proto/commands_pb2.py @@ -35,7 +35,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\x90\r\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFrameViewCommandH\x00R\x13\x63reateDataframeView\x12O\n\x12write_operation_v2\x18\x04 \x01(\x0b\x32\x1f.spark.connect.WriteOperationV2H\x00R\x10writeOperationV2\x12<\n\x0bsql_command\x18\x05 \x01(\x0b\x32\x19.spark.connect.SqlCommandH\x00R\nsqlCommand\x12k\n\x1cwrite_stream_operation_start\x18\x06 \x01(\x0b\x32(.spark.connect.WriteStreamOperationStartH\x00R\x19writeStreamOperationStart\x12^\n\x17streaming_query_command\x18\x07 \x01(\x0b\x32$.spark.connect.StreamingQueryCommandH\x00R\x15streamingQueryCommand\x12X\n\x15get_resources_command\x18\x08 \x01(\x0b\x32".spark.connect.GetResourcesCommandH\x00R\x13getResourcesCommand\x12t\n\x1fstreaming_query_manager_command\x18\t \x01(\x0b\x32+.spark.connect.StreamingQueryManagerCommandH\x00R\x1cstreamingQueryManagerCommand\x12m\n\x17register_table_function\x18\n \x01(\x0b\x32\x33.spark.connect.CommonInlineUserDefinedTableFunctionH\x00R\x15registerTableFunction\x12\x81\x01\n$streaming_query_listener_bus_command\x18\x0b \x01(\x0b\x32/.spark.connect.StreamingQueryListenerBusCommandH\x00R streamingQueryListenerBusCommand\x12\x64\n\x14register_data_source\x18\x0c \x01(\x0b\x32\x30.spark.connect.CommonInlineUserDefinedDataSourceH\x00R\x12registerDataSource\x12t\n\x1f\x63reate_resource_profile_command\x18\r \x01(\x0b\x32+.spark.connect.CreateResourceProfileCommandH\x00R\x1c\x63reateResourceProfileCommand\x12Q\n\x12\x63heckpoint_command\x18\x0e \x01(\x0b\x32 .spark.connect.CheckpointCommandH\x00R\x11\x63heckpointCommand\x12\x84\x01\n%remove_cached_remote_relation_command\x18\x0f \x01(\x0b\x32\x30.spark.connect.RemoveCachedRemoteRelationCommandH\x00R!removeCachedRemoteRelationCommand\x12_\n\x18merge_into_table_command\x18\x10 \x01(\x0b\x32$.spark.connect.MergeIntoTableCommandH\x00R\x15mergeIntoTableCommand\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x0e\n\x0c\x63ommand_type"\xaa\x04\n\nSqlCommand\x12\x14\n\x03sql\x18\x01 \x01(\tB\x02\x18\x01R\x03sql\x12;\n\x04\x61rgs\x18\x02 \x03(\x0b\x32#.spark.connect.SqlCommand.ArgsEntryB\x02\x18\x01R\x04\x61rgs\x12@\n\x08pos_args\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralB\x02\x18\x01R\x07posArgs\x12Z\n\x0fnamed_arguments\x18\x04 \x03(\x0b\x32-.spark.connect.SqlCommand.NamedArgumentsEntryB\x02\x18\x01R\x0enamedArguments\x12\x42\n\rpos_arguments\x18\x05 \x03(\x0b\x32\x19.spark.connect.ExpressionB\x02\x18\x01R\x0cposArguments\x12-\n\x05input\x18\x06 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 
\x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01\x1a\\\n\x13NamedArgumentsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value:\x02\x38\x01"\x96\x01\n\x1a\x43reateDataFrameViewCommand\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x1b\n\tis_global\x18\x03 \x01(\x08R\x08isGlobal\x12\x18\n\x07replace\x18\x04 \x01(\x08R\x07replace"\xca\x08\n\x0eWriteOperation\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1b\n\x06source\x18\x02 \x01(\tH\x01R\x06source\x88\x01\x01\x12\x14\n\x04path\x18\x03 \x01(\tH\x00R\x04path\x12?\n\x05table\x18\x04 \x01(\x0b\x32\'.spark.connect.WriteOperation.SaveTableH\x00R\x05table\x12:\n\x04mode\x18\x05 \x01(\x0e\x32&.spark.connect.WriteOperation.SaveModeR\x04mode\x12*\n\x11sort_column_names\x18\x06 \x03(\tR\x0fsortColumnNames\x12\x31\n\x14partitioning_columns\x18\x07 \x03(\tR\x13partitioningColumns\x12\x43\n\tbucket_by\x18\x08 \x01(\x0b\x32&.spark.connect.WriteOperation.BucketByR\x08\x62ucketBy\x12\x44\n\x07options\x18\t \x03(\x0b\x32*.spark.connect.WriteOperation.OptionsEntryR\x07options\x12-\n\x12\x63lustering_columns\x18\n \x03(\tR\x11\x63lusteringColumns\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x82\x02\n\tSaveTable\x12\x1d\n\ntable_name\x18\x01 \x01(\tR\ttableName\x12X\n\x0bsave_method\x18\x02 \x01(\x0e\x32\x37.spark.connect.WriteOperation.SaveTable.TableSaveMethodR\nsaveMethod"|\n\x0fTableSaveMethod\x12!\n\x1dTABLE_SAVE_METHOD_UNSPECIFIED\x10\x00\x12#\n\x1fTABLE_SAVE_METHOD_SAVE_AS_TABLE\x10\x01\x12!\n\x1dTABLE_SAVE_METHOD_INSERT_INTO\x10\x02\x1a[\n\x08\x42ucketBy\x12.\n\x13\x62ucket_column_names\x18\x01 \x03(\tR\x11\x62ucketColumnNames\x12\x1f\n\x0bnum_buckets\x18\x02 \x01(\x05R\nnumBuckets"\x89\x01\n\x08SaveMode\x12\x19\n\x15SAVE_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10SAVE_MODE_APPEND\x10\x01\x12\x17\n\x13SAVE_MODE_OVERWRITE\x10\x02\x12\x1d\n\x19SAVE_MODE_ERROR_IF_EXISTS\x10\x03\x12\x14\n\x10SAVE_MODE_IGNORE\x10\x04\x42\x0b\n\tsave_typeB\t\n\x07_source"\xdc\x06\n\x10WriteOperationV2\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\ntable_name\x18\x02 \x01(\tR\ttableName\x12\x1f\n\x08provider\x18\x03 \x01(\tH\x00R\x08provider\x88\x01\x01\x12L\n\x14partitioning_columns\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13partitioningColumns\x12\x46\n\x07options\x18\x05 \x03(\x0b\x32,.spark.connect.WriteOperationV2.OptionsEntryR\x07options\x12_\n\x10table_properties\x18\x06 \x03(\x0b\x32\x34.spark.connect.WriteOperationV2.TablePropertiesEntryR\x0ftableProperties\x12\x38\n\x04mode\x18\x07 \x01(\x0e\x32$.spark.connect.WriteOperationV2.ModeR\x04mode\x12J\n\x13overwrite_condition\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x12overwriteCondition\x12-\n\x12\x63lustering_columns\x18\t \x03(\tR\x11\x63lusteringColumns\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 
\x01(\tR\x05value:\x02\x38\x01"\x9f\x01\n\x04Mode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\x0f\n\x0bMODE_CREATE\x10\x01\x12\x12\n\x0eMODE_OVERWRITE\x10\x02\x12\x1d\n\x19MODE_OVERWRITE_PARTITIONS\x10\x03\x12\x0f\n\x0bMODE_APPEND\x10\x04\x12\x10\n\x0cMODE_REPLACE\x10\x05\x12\x1a\n\x16MODE_CREATE_OR_REPLACE\x10\x06\x42\x0b\n\t_provider"\xd8\x06\n\x19WriteStreamOperationStart\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06\x66ormat\x18\x02 \x01(\tR\x06\x66ormat\x12O\n\x07options\x18\x03 \x03(\x0b\x32\x35.spark.connect.WriteStreamOperationStart.OptionsEntryR\x07options\x12:\n\x19partitioning_column_names\x18\x04 \x03(\tR\x17partitioningColumnNames\x12:\n\x18processing_time_interval\x18\x05 \x01(\tH\x00R\x16processingTimeInterval\x12%\n\ravailable_now\x18\x06 \x01(\x08H\x00R\x0c\x61vailableNow\x12\x14\n\x04once\x18\x07 \x01(\x08H\x00R\x04once\x12\x46\n\x1e\x63ontinuous_checkpoint_interval\x18\x08 \x01(\tH\x00R\x1c\x63ontinuousCheckpointInterval\x12\x1f\n\x0boutput_mode\x18\t \x01(\tR\noutputMode\x12\x1d\n\nquery_name\x18\n \x01(\tR\tqueryName\x12\x14\n\x04path\x18\x0b \x01(\tH\x01R\x04path\x12\x1f\n\ntable_name\x18\x0c \x01(\tH\x01R\ttableName\x12N\n\x0e\x66oreach_writer\x18\r \x01(\x0b\x32\'.spark.connect.StreamingForeachFunctionR\rforeachWriter\x12L\n\rforeach_batch\x18\x0e \x01(\x0b\x32\'.spark.connect.StreamingForeachFunctionR\x0c\x66oreachBatch\x12\x36\n\x17\x63lustering_column_names\x18\x0f \x03(\tR\x15\x63lusteringColumnNames\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07triggerB\x12\n\x10sink_destination"\xb3\x01\n\x18StreamingForeachFunction\x12\x43\n\x0fpython_function\x18\x01 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\x0epythonFunction\x12\x46\n\x0escala_function\x18\x02 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFH\x00R\rscalaFunctionB\n\n\x08\x66unction"\xd4\x01\n\x1fWriteStreamOperationStartResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12<\n\x18query_started_event_json\x18\x03 \x01(\tH\x00R\x15queryStartedEventJson\x88\x01\x01\x42\x1b\n\x19_query_started_event_json"A\n\x18StreamingQueryInstanceId\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x15\n\x06run_id\x18\x02 \x01(\tR\x05runId"\xf8\x04\n\x15StreamingQueryCommand\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x18\n\x06status\x18\x02 \x01(\x08H\x00R\x06status\x12%\n\rlast_progress\x18\x03 \x01(\x08H\x00R\x0clastProgress\x12)\n\x0frecent_progress\x18\x04 \x01(\x08H\x00R\x0erecentProgress\x12\x14\n\x04stop\x18\x05 \x01(\x08H\x00R\x04stop\x12\x34\n\x15process_all_available\x18\x06 \x01(\x08H\x00R\x13processAllAvailable\x12O\n\x07\x65xplain\x18\x07 \x01(\x0b\x32\x33.spark.connect.StreamingQueryCommand.ExplainCommandH\x00R\x07\x65xplain\x12\x1e\n\texception\x18\x08 \x01(\x08H\x00R\texception\x12k\n\x11\x61wait_termination\x18\t \x01(\x0b\x32<.spark.connect.StreamingQueryCommand.AwaitTerminationCommandH\x00R\x10\x61waitTermination\x1a,\n\x0e\x45xplainCommand\x12\x1a\n\x08\x65xtended\x18\x01 \x01(\x08R\x08\x65xtended\x1aL\n\x17\x41waitTerminationCommand\x12"\n\ntimeout_ms\x18\x02 \x01(\x03H\x00R\ttimeoutMs\x88\x01\x01\x42\r\n\x0b_timeout_msB\t\n\x07\x63ommand"\xf5\x08\n\x1bStreamingQueryCommandResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12Q\n\x06status\x18\x02 
\x01(\x0b\x32\x37.spark.connect.StreamingQueryCommandResult.StatusResultH\x00R\x06status\x12j\n\x0frecent_progress\x18\x03 \x01(\x0b\x32?.spark.connect.StreamingQueryCommandResult.RecentProgressResultH\x00R\x0erecentProgress\x12T\n\x07\x65xplain\x18\x04 \x01(\x0b\x32\x38.spark.connect.StreamingQueryCommandResult.ExplainResultH\x00R\x07\x65xplain\x12Z\n\texception\x18\x05 \x01(\x0b\x32:.spark.connect.StreamingQueryCommandResult.ExceptionResultH\x00R\texception\x12p\n\x11\x61wait_termination\x18\x06 \x01(\x0b\x32\x41.spark.connect.StreamingQueryCommandResult.AwaitTerminationResultH\x00R\x10\x61waitTermination\x1a\xaa\x01\n\x0cStatusResult\x12%\n\x0estatus_message\x18\x01 \x01(\tR\rstatusMessage\x12*\n\x11is_data_available\x18\x02 \x01(\x08R\x0fisDataAvailable\x12*\n\x11is_trigger_active\x18\x03 \x01(\x08R\x0fisTriggerActive\x12\x1b\n\tis_active\x18\x04 \x01(\x08R\x08isActive\x1aH\n\x14RecentProgressResult\x12\x30\n\x14recent_progress_json\x18\x05 \x03(\tR\x12recentProgressJson\x1a\'\n\rExplainResult\x12\x16\n\x06result\x18\x01 \x01(\tR\x06result\x1a\xc5\x01\n\x0f\x45xceptionResult\x12\x30\n\x11\x65xception_message\x18\x01 \x01(\tH\x00R\x10\x65xceptionMessage\x88\x01\x01\x12$\n\x0b\x65rror_class\x18\x02 \x01(\tH\x01R\nerrorClass\x88\x01\x01\x12$\n\x0bstack_trace\x18\x03 \x01(\tH\x02R\nstackTrace\x88\x01\x01\x42\x14\n\x12_exception_messageB\x0e\n\x0c_error_classB\x0e\n\x0c_stack_trace\x1a\x38\n\x16\x41waitTerminationResult\x12\x1e\n\nterminated\x18\x01 \x01(\x08R\nterminatedB\r\n\x0bresult_type"\xbd\x06\n\x1cStreamingQueryManagerCommand\x12\x18\n\x06\x61\x63tive\x18\x01 \x01(\x08H\x00R\x06\x61\x63tive\x12\x1d\n\tget_query\x18\x02 \x01(\tH\x00R\x08getQuery\x12|\n\x15\x61wait_any_termination\x18\x03 \x01(\x0b\x32\x46.spark.connect.StreamingQueryManagerCommand.AwaitAnyTerminationCommandH\x00R\x13\x61waitAnyTermination\x12+\n\x10reset_terminated\x18\x04 \x01(\x08H\x00R\x0fresetTerminated\x12n\n\x0c\x61\x64\x64_listener\x18\x05 \x01(\x0b\x32I.spark.connect.StreamingQueryManagerCommand.StreamingQueryListenerCommandH\x00R\x0b\x61\x64\x64Listener\x12t\n\x0fremove_listener\x18\x06 \x01(\x0b\x32I.spark.connect.StreamingQueryManagerCommand.StreamingQueryListenerCommandH\x00R\x0eremoveListener\x12\'\n\x0elist_listeners\x18\x07 \x01(\x08H\x00R\rlistListeners\x1aO\n\x1a\x41waitAnyTerminationCommand\x12"\n\ntimeout_ms\x18\x01 \x01(\x03H\x00R\ttimeoutMs\x88\x01\x01\x42\r\n\x0b_timeout_ms\x1a\xcd\x01\n\x1dStreamingQueryListenerCommand\x12)\n\x10listener_payload\x18\x01 \x01(\x0cR\x0flistenerPayload\x12U\n\x17python_listener_payload\x18\x02 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\x15pythonListenerPayload\x88\x01\x01\x12\x0e\n\x02id\x18\x03 \x01(\tR\x02idB\x1a\n\x18_python_listener_payloadB\t\n\x07\x63ommand"\xb4\x08\n"StreamingQueryManagerCommandResult\x12X\n\x06\x61\x63tive\x18\x01 \x01(\x0b\x32>.spark.connect.StreamingQueryManagerCommandResult.ActiveResultH\x00R\x06\x61\x63tive\x12`\n\x05query\x18\x02 \x01(\x0b\x32H.spark.connect.StreamingQueryManagerCommandResult.StreamingQueryInstanceH\x00R\x05query\x12\x81\x01\n\x15\x61wait_any_termination\x18\x03 \x01(\x0b\x32K.spark.connect.StreamingQueryManagerCommandResult.AwaitAnyTerminationResultH\x00R\x13\x61waitAnyTermination\x12+\n\x10reset_terminated\x18\x04 \x01(\x08H\x00R\x0fresetTerminated\x12#\n\x0c\x61\x64\x64_listener\x18\x05 \x01(\x08H\x00R\x0b\x61\x64\x64Listener\x12)\n\x0fremove_listener\x18\x06 \x01(\x08H\x00R\x0eremoveListener\x12{\n\x0elist_listeners\x18\x07 
\x01(\x0b\x32R.spark.connect.StreamingQueryManagerCommandResult.ListStreamingQueryListenerResultH\x00R\rlistListeners\x1a\x7f\n\x0c\x41\x63tiveResult\x12o\n\x0e\x61\x63tive_queries\x18\x01 \x03(\x0b\x32H.spark.connect.StreamingQueryManagerCommandResult.StreamingQueryInstanceR\ractiveQueries\x1as\n\x16StreamingQueryInstance\x12\x37\n\x02id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x02id\x12\x17\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x88\x01\x01\x42\x07\n\x05_name\x1a;\n\x19\x41waitAnyTerminationResult\x12\x1e\n\nterminated\x18\x01 \x01(\x08R\nterminated\x1aK\n\x1eStreamingQueryListenerInstance\x12)\n\x10listener_payload\x18\x01 \x01(\x0cR\x0flistenerPayload\x1a\x45\n ListStreamingQueryListenerResult\x12!\n\x0clistener_ids\x18\x01 \x03(\tR\x0blistenerIdsB\r\n\x0bresult_type"\xad\x01\n StreamingQueryListenerBusCommand\x12;\n\x19\x61\x64\x64_listener_bus_listener\x18\x01 \x01(\x08H\x00R\x16\x61\x64\x64ListenerBusListener\x12\x41\n\x1cremove_listener_bus_listener\x18\x02 \x01(\x08H\x00R\x19removeListenerBusListenerB\t\n\x07\x63ommand"\x83\x01\n\x1bStreamingQueryListenerEvent\x12\x1d\n\nevent_json\x18\x01 \x01(\tR\teventJson\x12\x45\n\nevent_type\x18\x02 \x01(\x0e\x32&.spark.connect.StreamingQueryEventTypeR\teventType"\xcc\x01\n"StreamingQueryListenerEventsResult\x12\x42\n\x06\x65vents\x18\x01 \x03(\x0b\x32*.spark.connect.StreamingQueryListenerEventR\x06\x65vents\x12\x42\n\x1blistener_bus_listener_added\x18\x02 \x01(\x08H\x00R\x18listenerBusListenerAdded\x88\x01\x01\x42\x1e\n\x1c_listener_bus_listener_added"\x15\n\x13GetResourcesCommand"\xd4\x01\n\x19GetResourcesCommandResult\x12U\n\tresources\x18\x01 \x03(\x0b\x32\x37.spark.connect.GetResourcesCommandResult.ResourcesEntryR\tresources\x1a`\n\x0eResourcesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x38\n\x05value\x18\x02 \x01(\x0b\x32".spark.connect.ResourceInformationR\x05value:\x02\x38\x01"X\n\x1c\x43reateResourceProfileCommand\x12\x38\n\x07profile\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ResourceProfileR\x07profile"C\n"CreateResourceProfileCommandResult\x12\x1d\n\nprofile_id\x18\x01 \x01(\x05R\tprofileId"d\n!RemoveCachedRemoteRelationCommand\x12?\n\x08relation\x18\x01 \x01(\x0b\x32#.spark.connect.CachedRemoteRelationR\x08relation"t\n\x11\x43heckpointCommand\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x14\n\x05local\x18\x02 \x01(\x08R\x05local\x12\x14\n\x05\x65\x61ger\x18\x03 \x01(\x08R\x05\x65\x61ger"\xe8\x03\n\x15MergeIntoTableCommand\x12*\n\x11target_table_name\x18\x01 \x01(\tR\x0ftargetTableName\x12\x43\n\x11source_table_plan\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x0fsourceTablePlan\x12\x42\n\x0fmerge_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0emergeCondition\x12>\n\rmatch_actions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0cmatchActions\x12I\n\x13not_matched_actions\x18\x05 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x11notMatchedActions\x12[\n\x1dnot_matched_by_source_actions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x19notMatchedBySourceActions\x12\x32\n\x15with_schema_evolution\x18\x07 \x01(\x08R\x13withSchemaEvolution*\x85\x01\n\x17StreamingQueryEventType\x12\x1e\n\x1aQUERY_PROGRESS_UNSPECIFIED\x10\x00\x12\x18\n\x14QUERY_PROGRESS_EVENT\x10\x01\x12\x1a\n\x16QUERY_TERMINATED_EVENT\x10\x02\x12\x14\n\x10QUERY_IDLE_EVENT\x10\x03\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + 
b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\x90\r\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFrameViewCommandH\x00R\x13\x63reateDataframeView\x12O\n\x12write_operation_v2\x18\x04 \x01(\x0b\x32\x1f.spark.connect.WriteOperationV2H\x00R\x10writeOperationV2\x12<\n\x0bsql_command\x18\x05 \x01(\x0b\x32\x19.spark.connect.SqlCommandH\x00R\nsqlCommand\x12k\n\x1cwrite_stream_operation_start\x18\x06 \x01(\x0b\x32(.spark.connect.WriteStreamOperationStartH\x00R\x19writeStreamOperationStart\x12^\n\x17streaming_query_command\x18\x07 \x01(\x0b\x32$.spark.connect.StreamingQueryCommandH\x00R\x15streamingQueryCommand\x12X\n\x15get_resources_command\x18\x08 \x01(\x0b\x32".spark.connect.GetResourcesCommandH\x00R\x13getResourcesCommand\x12t\n\x1fstreaming_query_manager_command\x18\t \x01(\x0b\x32+.spark.connect.StreamingQueryManagerCommandH\x00R\x1cstreamingQueryManagerCommand\x12m\n\x17register_table_function\x18\n \x01(\x0b\x32\x33.spark.connect.CommonInlineUserDefinedTableFunctionH\x00R\x15registerTableFunction\x12\x81\x01\n$streaming_query_listener_bus_command\x18\x0b \x01(\x0b\x32/.spark.connect.StreamingQueryListenerBusCommandH\x00R streamingQueryListenerBusCommand\x12\x64\n\x14register_data_source\x18\x0c \x01(\x0b\x32\x30.spark.connect.CommonInlineUserDefinedDataSourceH\x00R\x12registerDataSource\x12t\n\x1f\x63reate_resource_profile_command\x18\r \x01(\x0b\x32+.spark.connect.CreateResourceProfileCommandH\x00R\x1c\x63reateResourceProfileCommand\x12Q\n\x12\x63heckpoint_command\x18\x0e \x01(\x0b\x32 .spark.connect.CheckpointCommandH\x00R\x11\x63heckpointCommand\x12\x84\x01\n%remove_cached_remote_relation_command\x18\x0f \x01(\x0b\x32\x30.spark.connect.RemoveCachedRemoteRelationCommandH\x00R!removeCachedRemoteRelationCommand\x12_\n\x18merge_into_table_command\x18\x10 \x01(\x0b\x32$.spark.connect.MergeIntoTableCommandH\x00R\x15mergeIntoTableCommand\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x0e\n\x0c\x63ommand_type"\xaa\x04\n\nSqlCommand\x12\x14\n\x03sql\x18\x01 \x01(\tB\x02\x18\x01R\x03sql\x12;\n\x04\x61rgs\x18\x02 \x03(\x0b\x32#.spark.connect.SqlCommand.ArgsEntryB\x02\x18\x01R\x04\x61rgs\x12@\n\x08pos_args\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralB\x02\x18\x01R\x07posArgs\x12Z\n\x0fnamed_arguments\x18\x04 \x03(\x0b\x32-.spark.connect.SqlCommand.NamedArgumentsEntryB\x02\x18\x01R\x0enamedArguments\x12\x42\n\rpos_arguments\x18\x05 \x03(\x0b\x32\x19.spark.connect.ExpressionB\x02\x18\x01R\x0cposArguments\x12-\n\x05input\x18\x06 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01\x1a\\\n\x13NamedArgumentsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value:\x02\x38\x01"\x96\x01\n\x1a\x43reateDataFrameViewCommand\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x1b\n\tis_global\x18\x03 
\x01(\x08R\x08isGlobal\x12\x18\n\x07replace\x18\x04 \x01(\x08R\x07replace"\xca\x08\n\x0eWriteOperation\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1b\n\x06source\x18\x02 \x01(\tH\x01R\x06source\x88\x01\x01\x12\x14\n\x04path\x18\x03 \x01(\tH\x00R\x04path\x12?\n\x05table\x18\x04 \x01(\x0b\x32\'.spark.connect.WriteOperation.SaveTableH\x00R\x05table\x12:\n\x04mode\x18\x05 \x01(\x0e\x32&.spark.connect.WriteOperation.SaveModeR\x04mode\x12*\n\x11sort_column_names\x18\x06 \x03(\tR\x0fsortColumnNames\x12\x31\n\x14partitioning_columns\x18\x07 \x03(\tR\x13partitioningColumns\x12\x43\n\tbucket_by\x18\x08 \x01(\x0b\x32&.spark.connect.WriteOperation.BucketByR\x08\x62ucketBy\x12\x44\n\x07options\x18\t \x03(\x0b\x32*.spark.connect.WriteOperation.OptionsEntryR\x07options\x12-\n\x12\x63lustering_columns\x18\n \x03(\tR\x11\x63lusteringColumns\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x82\x02\n\tSaveTable\x12\x1d\n\ntable_name\x18\x01 \x01(\tR\ttableName\x12X\n\x0bsave_method\x18\x02 \x01(\x0e\x32\x37.spark.connect.WriteOperation.SaveTable.TableSaveMethodR\nsaveMethod"|\n\x0fTableSaveMethod\x12!\n\x1dTABLE_SAVE_METHOD_UNSPECIFIED\x10\x00\x12#\n\x1fTABLE_SAVE_METHOD_SAVE_AS_TABLE\x10\x01\x12!\n\x1dTABLE_SAVE_METHOD_INSERT_INTO\x10\x02\x1a[\n\x08\x42ucketBy\x12.\n\x13\x62ucket_column_names\x18\x01 \x03(\tR\x11\x62ucketColumnNames\x12\x1f\n\x0bnum_buckets\x18\x02 \x01(\x05R\nnumBuckets"\x89\x01\n\x08SaveMode\x12\x19\n\x15SAVE_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10SAVE_MODE_APPEND\x10\x01\x12\x17\n\x13SAVE_MODE_OVERWRITE\x10\x02\x12\x1d\n\x19SAVE_MODE_ERROR_IF_EXISTS\x10\x03\x12\x14\n\x10SAVE_MODE_IGNORE\x10\x04\x42\x0b\n\tsave_typeB\t\n\x07_source"\xdc\x06\n\x10WriteOperationV2\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\ntable_name\x18\x02 \x01(\tR\ttableName\x12\x1f\n\x08provider\x18\x03 \x01(\tH\x00R\x08provider\x88\x01\x01\x12L\n\x14partitioning_columns\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13partitioningColumns\x12\x46\n\x07options\x18\x05 \x03(\x0b\x32,.spark.connect.WriteOperationV2.OptionsEntryR\x07options\x12_\n\x10table_properties\x18\x06 \x03(\x0b\x32\x34.spark.connect.WriteOperationV2.TablePropertiesEntryR\x0ftableProperties\x12\x38\n\x04mode\x18\x07 \x01(\x0e\x32$.spark.connect.WriteOperationV2.ModeR\x04mode\x12J\n\x13overwrite_condition\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x12overwriteCondition\x12-\n\x12\x63lustering_columns\x18\t \x03(\tR\x11\x63lusteringColumns\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\x9f\x01\n\x04Mode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\x0f\n\x0bMODE_CREATE\x10\x01\x12\x12\n\x0eMODE_OVERWRITE\x10\x02\x12\x1d\n\x19MODE_OVERWRITE_PARTITIONS\x10\x03\x12\x0f\n\x0bMODE_APPEND\x10\x04\x12\x10\n\x0cMODE_REPLACE\x10\x05\x12\x1a\n\x16MODE_CREATE_OR_REPLACE\x10\x06\x42\x0b\n\t_provider"\xd8\x06\n\x19WriteStreamOperationStart\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06\x66ormat\x18\x02 \x01(\tR\x06\x66ormat\x12O\n\x07options\x18\x03 \x03(\x0b\x32\x35.spark.connect.WriteStreamOperationStart.OptionsEntryR\x07options\x12:\n\x19partitioning_column_names\x18\x04 \x03(\tR\x17partitioningColumnNames\x12:\n\x18processing_time_interval\x18\x05 
\x01(\tH\x00R\x16processingTimeInterval\x12%\n\ravailable_now\x18\x06 \x01(\x08H\x00R\x0c\x61vailableNow\x12\x14\n\x04once\x18\x07 \x01(\x08H\x00R\x04once\x12\x46\n\x1e\x63ontinuous_checkpoint_interval\x18\x08 \x01(\tH\x00R\x1c\x63ontinuousCheckpointInterval\x12\x1f\n\x0boutput_mode\x18\t \x01(\tR\noutputMode\x12\x1d\n\nquery_name\x18\n \x01(\tR\tqueryName\x12\x14\n\x04path\x18\x0b \x01(\tH\x01R\x04path\x12\x1f\n\ntable_name\x18\x0c \x01(\tH\x01R\ttableName\x12N\n\x0e\x66oreach_writer\x18\r \x01(\x0b\x32\'.spark.connect.StreamingForeachFunctionR\rforeachWriter\x12L\n\rforeach_batch\x18\x0e \x01(\x0b\x32\'.spark.connect.StreamingForeachFunctionR\x0c\x66oreachBatch\x12\x36\n\x17\x63lustering_column_names\x18\x0f \x03(\tR\x15\x63lusteringColumnNames\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07triggerB\x12\n\x10sink_destination"\xb3\x01\n\x18StreamingForeachFunction\x12\x43\n\x0fpython_function\x18\x01 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\x0epythonFunction\x12\x46\n\x0escala_function\x18\x02 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFH\x00R\rscalaFunctionB\n\n\x08\x66unction"\xd4\x01\n\x1fWriteStreamOperationStartResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12<\n\x18query_started_event_json\x18\x03 \x01(\tH\x00R\x15queryStartedEventJson\x88\x01\x01\x42\x1b\n\x19_query_started_event_json"A\n\x18StreamingQueryInstanceId\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x15\n\x06run_id\x18\x02 \x01(\tR\x05runId"\xf8\x04\n\x15StreamingQueryCommand\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x18\n\x06status\x18\x02 \x01(\x08H\x00R\x06status\x12%\n\rlast_progress\x18\x03 \x01(\x08H\x00R\x0clastProgress\x12)\n\x0frecent_progress\x18\x04 \x01(\x08H\x00R\x0erecentProgress\x12\x14\n\x04stop\x18\x05 \x01(\x08H\x00R\x04stop\x12\x34\n\x15process_all_available\x18\x06 \x01(\x08H\x00R\x13processAllAvailable\x12O\n\x07\x65xplain\x18\x07 \x01(\x0b\x32\x33.spark.connect.StreamingQueryCommand.ExplainCommandH\x00R\x07\x65xplain\x12\x1e\n\texception\x18\x08 \x01(\x08H\x00R\texception\x12k\n\x11\x61wait_termination\x18\t \x01(\x0b\x32<.spark.connect.StreamingQueryCommand.AwaitTerminationCommandH\x00R\x10\x61waitTermination\x1a,\n\x0e\x45xplainCommand\x12\x1a\n\x08\x65xtended\x18\x01 \x01(\x08R\x08\x65xtended\x1aL\n\x17\x41waitTerminationCommand\x12"\n\ntimeout_ms\x18\x02 \x01(\x03H\x00R\ttimeoutMs\x88\x01\x01\x42\r\n\x0b_timeout_msB\t\n\x07\x63ommand"\xf5\x08\n\x1bStreamingQueryCommandResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12Q\n\x06status\x18\x02 \x01(\x0b\x32\x37.spark.connect.StreamingQueryCommandResult.StatusResultH\x00R\x06status\x12j\n\x0frecent_progress\x18\x03 \x01(\x0b\x32?.spark.connect.StreamingQueryCommandResult.RecentProgressResultH\x00R\x0erecentProgress\x12T\n\x07\x65xplain\x18\x04 \x01(\x0b\x32\x38.spark.connect.StreamingQueryCommandResult.ExplainResultH\x00R\x07\x65xplain\x12Z\n\texception\x18\x05 \x01(\x0b\x32:.spark.connect.StreamingQueryCommandResult.ExceptionResultH\x00R\texception\x12p\n\x11\x61wait_termination\x18\x06 \x01(\x0b\x32\x41.spark.connect.StreamingQueryCommandResult.AwaitTerminationResultH\x00R\x10\x61waitTermination\x1a\xaa\x01\n\x0cStatusResult\x12%\n\x0estatus_message\x18\x01 \x01(\tR\rstatusMessage\x12*\n\x11is_data_available\x18\x02 
\x01(\x08R\x0fisDataAvailable\x12*\n\x11is_trigger_active\x18\x03 \x01(\x08R\x0fisTriggerActive\x12\x1b\n\tis_active\x18\x04 \x01(\x08R\x08isActive\x1aH\n\x14RecentProgressResult\x12\x30\n\x14recent_progress_json\x18\x05 \x03(\tR\x12recentProgressJson\x1a\'\n\rExplainResult\x12\x16\n\x06result\x18\x01 \x01(\tR\x06result\x1a\xc5\x01\n\x0f\x45xceptionResult\x12\x30\n\x11\x65xception_message\x18\x01 \x01(\tH\x00R\x10\x65xceptionMessage\x88\x01\x01\x12$\n\x0b\x65rror_class\x18\x02 \x01(\tH\x01R\nerrorClass\x88\x01\x01\x12$\n\x0bstack_trace\x18\x03 \x01(\tH\x02R\nstackTrace\x88\x01\x01\x42\x14\n\x12_exception_messageB\x0e\n\x0c_error_classB\x0e\n\x0c_stack_trace\x1a\x38\n\x16\x41waitTerminationResult\x12\x1e\n\nterminated\x18\x01 \x01(\x08R\nterminatedB\r\n\x0bresult_type"\xbd\x06\n\x1cStreamingQueryManagerCommand\x12\x18\n\x06\x61\x63tive\x18\x01 \x01(\x08H\x00R\x06\x61\x63tive\x12\x1d\n\tget_query\x18\x02 \x01(\tH\x00R\x08getQuery\x12|\n\x15\x61wait_any_termination\x18\x03 \x01(\x0b\x32\x46.spark.connect.StreamingQueryManagerCommand.AwaitAnyTerminationCommandH\x00R\x13\x61waitAnyTermination\x12+\n\x10reset_terminated\x18\x04 \x01(\x08H\x00R\x0fresetTerminated\x12n\n\x0c\x61\x64\x64_listener\x18\x05 \x01(\x0b\x32I.spark.connect.StreamingQueryManagerCommand.StreamingQueryListenerCommandH\x00R\x0b\x61\x64\x64Listener\x12t\n\x0fremove_listener\x18\x06 \x01(\x0b\x32I.spark.connect.StreamingQueryManagerCommand.StreamingQueryListenerCommandH\x00R\x0eremoveListener\x12\'\n\x0elist_listeners\x18\x07 \x01(\x08H\x00R\rlistListeners\x1aO\n\x1a\x41waitAnyTerminationCommand\x12"\n\ntimeout_ms\x18\x01 \x01(\x03H\x00R\ttimeoutMs\x88\x01\x01\x42\r\n\x0b_timeout_ms\x1a\xcd\x01\n\x1dStreamingQueryListenerCommand\x12)\n\x10listener_payload\x18\x01 \x01(\x0cR\x0flistenerPayload\x12U\n\x17python_listener_payload\x18\x02 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\x15pythonListenerPayload\x88\x01\x01\x12\x0e\n\x02id\x18\x03 \x01(\tR\x02idB\x1a\n\x18_python_listener_payloadB\t\n\x07\x63ommand"\xb4\x08\n"StreamingQueryManagerCommandResult\x12X\n\x06\x61\x63tive\x18\x01 \x01(\x0b\x32>.spark.connect.StreamingQueryManagerCommandResult.ActiveResultH\x00R\x06\x61\x63tive\x12`\n\x05query\x18\x02 \x01(\x0b\x32H.spark.connect.StreamingQueryManagerCommandResult.StreamingQueryInstanceH\x00R\x05query\x12\x81\x01\n\x15\x61wait_any_termination\x18\x03 \x01(\x0b\x32K.spark.connect.StreamingQueryManagerCommandResult.AwaitAnyTerminationResultH\x00R\x13\x61waitAnyTermination\x12+\n\x10reset_terminated\x18\x04 \x01(\x08H\x00R\x0fresetTerminated\x12#\n\x0c\x61\x64\x64_listener\x18\x05 \x01(\x08H\x00R\x0b\x61\x64\x64Listener\x12)\n\x0fremove_listener\x18\x06 \x01(\x08H\x00R\x0eremoveListener\x12{\n\x0elist_listeners\x18\x07 \x01(\x0b\x32R.spark.connect.StreamingQueryManagerCommandResult.ListStreamingQueryListenerResultH\x00R\rlistListeners\x1a\x7f\n\x0c\x41\x63tiveResult\x12o\n\x0e\x61\x63tive_queries\x18\x01 \x03(\x0b\x32H.spark.connect.StreamingQueryManagerCommandResult.StreamingQueryInstanceR\ractiveQueries\x1as\n\x16StreamingQueryInstance\x12\x37\n\x02id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x02id\x12\x17\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x88\x01\x01\x42\x07\n\x05_name\x1a;\n\x19\x41waitAnyTerminationResult\x12\x1e\n\nterminated\x18\x01 \x01(\x08R\nterminated\x1aK\n\x1eStreamingQueryListenerInstance\x12)\n\x10listener_payload\x18\x01 \x01(\x0cR\x0flistenerPayload\x1a\x45\n ListStreamingQueryListenerResult\x12!\n\x0clistener_ids\x18\x01 \x03(\tR\x0blistenerIdsB\r\n\x0bresult_type"\xad\x01\n 
StreamingQueryListenerBusCommand\x12;\n\x19\x61\x64\x64_listener_bus_listener\x18\x01 \x01(\x08H\x00R\x16\x61\x64\x64ListenerBusListener\x12\x41\n\x1cremove_listener_bus_listener\x18\x02 \x01(\x08H\x00R\x19removeListenerBusListenerB\t\n\x07\x63ommand"\x83\x01\n\x1bStreamingQueryListenerEvent\x12\x1d\n\nevent_json\x18\x01 \x01(\tR\teventJson\x12\x45\n\nevent_type\x18\x02 \x01(\x0e\x32&.spark.connect.StreamingQueryEventTypeR\teventType"\xcc\x01\n"StreamingQueryListenerEventsResult\x12\x42\n\x06\x65vents\x18\x01 \x03(\x0b\x32*.spark.connect.StreamingQueryListenerEventR\x06\x65vents\x12\x42\n\x1blistener_bus_listener_added\x18\x02 \x01(\x08H\x00R\x18listenerBusListenerAdded\x88\x01\x01\x42\x1e\n\x1c_listener_bus_listener_added"\x15\n\x13GetResourcesCommand"\xd4\x01\n\x19GetResourcesCommandResult\x12U\n\tresources\x18\x01 \x03(\x0b\x32\x37.spark.connect.GetResourcesCommandResult.ResourcesEntryR\tresources\x1a`\n\x0eResourcesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x38\n\x05value\x18\x02 \x01(\x0b\x32".spark.connect.ResourceInformationR\x05value:\x02\x38\x01"X\n\x1c\x43reateResourceProfileCommand\x12\x38\n\x07profile\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ResourceProfileR\x07profile"C\n"CreateResourceProfileCommandResult\x12\x1d\n\nprofile_id\x18\x01 \x01(\x05R\tprofileId"d\n!RemoveCachedRemoteRelationCommand\x12?\n\x08relation\x18\x01 \x01(\x0b\x32#.spark.connect.CachedRemoteRelationR\x08relation"\xcd\x01\n\x11\x43heckpointCommand\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x14\n\x05local\x18\x02 \x01(\x08R\x05local\x12\x14\n\x05\x65\x61ger\x18\x03 \x01(\x08R\x05\x65\x61ger\x12\x45\n\rstorage_level\x18\x04 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level"\xe8\x03\n\x15MergeIntoTableCommand\x12*\n\x11target_table_name\x18\x01 \x01(\tR\x0ftargetTableName\x12\x43\n\x11source_table_plan\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x0fsourceTablePlan\x12\x42\n\x0fmerge_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0emergeCondition\x12>\n\rmatch_actions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0cmatchActions\x12I\n\x13not_matched_actions\x18\x05 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x11notMatchedActions\x12[\n\x1dnot_matched_by_source_actions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x19notMatchedBySourceActions\x12\x32\n\x15with_schema_evolution\x18\x07 \x01(\x08R\x13withSchemaEvolution*\x85\x01\n\x17StreamingQueryEventType\x12\x1e\n\x1aQUERY_PROGRESS_UNSPECIFIED\x10\x00\x12\x18\n\x14QUERY_PROGRESS_EVENT\x10\x01\x12\x1a\n\x16QUERY_TERMINATED_EVENT\x10\x02\x12\x14\n\x10QUERY_IDLE_EVENT\x10\x03\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -71,8 +71,8 @@ _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_options = b"8\001" _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._options = None _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_options = b"8\001" - _STREAMINGQUERYEVENTTYPE._serialized_start = 11162 - _STREAMINGQUERYEVENTTYPE._serialized_end = 11295 + _STREAMINGQUERYEVENTTYPE._serialized_start = 11252 + _STREAMINGQUERYEVENTTYPE._serialized_end = 11385 _COMMAND._serialized_start = 167 _COMMAND._serialized_end = 1847 _SQLCOMMAND._serialized_start = 1850 @@ -167,8 +167,8 @@ _CREATERESOURCEPROFILECOMMANDRESULT._serialized_end = 10448 _REMOVECACHEDREMOTERELATIONCOMMAND._serialized_start = 10450 
_REMOVECACHEDREMOTERELATIONCOMMAND._serialized_end = 10550 - _CHECKPOINTCOMMAND._serialized_start = 10552 - _CHECKPOINTCOMMAND._serialized_end = 10668 - _MERGEINTOTABLECOMMAND._serialized_start = 10671 - _MERGEINTOTABLECOMMAND._serialized_end = 11159 + _CHECKPOINTCOMMAND._serialized_start = 10553 + _CHECKPOINTCOMMAND._serialized_end = 10758 + _MERGEINTOTABLECOMMAND._serialized_start = 10761 + _MERGEINTOTABLECOMMAND._serialized_end = 11249 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi index 2dedcdfc8e3e4..6192a29607cbf 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.pyi +++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi @@ -2188,6 +2188,7 @@ class CheckpointCommand(google.protobuf.message.Message): RELATION_FIELD_NUMBER: builtins.int LOCAL_FIELD_NUMBER: builtins.int EAGER_FIELD_NUMBER: builtins.int + STORAGE_LEVEL_FIELD_NUMBER: builtins.int @property def relation(self) -> pyspark.sql.connect.proto.relations_pb2.Relation: """(Required) The logical plan to checkpoint.""" @@ -2197,22 +2198,46 @@ class CheckpointCommand(google.protobuf.message.Message): """ eager: builtins.bool """(Required) Whether to checkpoint this dataframe immediately.""" + @property + def storage_level(self) -> pyspark.sql.connect.proto.common_pb2.StorageLevel: + """(Optional) For local checkpoint, the storage level to use.""" def __init__( self, *, relation: pyspark.sql.connect.proto.relations_pb2.Relation | None = ..., local: builtins.bool = ..., eager: builtins.bool = ..., + storage_level: pyspark.sql.connect.proto.common_pb2.StorageLevel | None = ..., ) -> None: ... def HasField( - self, field_name: typing_extensions.Literal["relation", b"relation"] + self, + field_name: typing_extensions.Literal[ + "_storage_level", + b"_storage_level", + "relation", + b"relation", + "storage_level", + b"storage_level", + ], ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ - "eager", b"eager", "local", b"local", "relation", b"relation" + "_storage_level", + b"_storage_level", + "eager", + b"eager", + "local", + b"local", + "relation", + b"relation", + "storage_level", + b"storage_level", ], ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_storage_level", b"_storage_level"] + ) -> typing_extensions.Literal["storage_level"] | None: ... global___CheckpointCommand = CheckpointCommand diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ef35b73332572..62f2129e5be62 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -17,6 +17,8 @@ # mypy: disable-error-code="empty-body" +import sys +import random from typing import ( Any, Callable, @@ -43,6 +45,7 @@ from pyspark.sql.types import StructType, Row from pyspark.sql.utils import dispatch_df_method + if TYPE_CHECKING: from py4j.java_gateway import JavaObject import pyarrow as pa @@ -65,6 +68,7 @@ ArrowMapIterFunction, DataFrameLike as PandasDataFrameLike, ) + from pyspark.sql.plot import PySparkPlotAccessor from pyspark.sql.metrics import ExecutionInfo @@ -1013,7 +1017,9 @@ def checkpoint(self, eager: bool = True) -> "DataFrame": """ ... - def localCheckpoint(self, eager: bool = True) -> "DataFrame": + def localCheckpoint( + self, eager: bool = True, storageLevel: Optional[StorageLevel] = None + ) -> "DataFrame": """Returns a locally checkpointed version of this :class:`DataFrame`. 
Checkpointing can be used to truncate the logical plan of this :class:`DataFrame`, which is especially useful in iterative algorithms where the plan may grow exponentially. Local checkpoints @@ -1024,12 +1030,17 @@ def localCheckpoint(self, eager: bool = True) -> "DataFrame": .. versionchanged:: 4.0.0 Supports Spark Connect. + Added storageLevel parameter. Parameters ---------- eager : bool, optional, default True Whether to checkpoint this :class:`DataFrame` immediately. + storageLevel : :class:`StorageLevel`, optional, default None + The StorageLevel with which the checkpoint will be stored. + If not specified, default for RDD local checkpoints. + Returns ------- :class:`DataFrame` @@ -2038,6 +2049,46 @@ def sample( """ ... + def _preapare_args_for_sample( + self, + withReplacement: Optional[Union[float, bool]] = None, + fraction: Optional[Union[int, float]] = None, + seed: Optional[int] = None, + ) -> Tuple[bool, float, int]: + from pyspark.errors import PySparkTypeError + + if isinstance(withReplacement, bool) and isinstance(fraction, float): + # For the cases below: + # sample(True, 0.5 [, seed]) + # sample(True, fraction=0.5 [, seed]) + # sample(withReplacement=False, fraction=0.5 [, seed]) + _seed = int(seed) if seed is not None else random.randint(0, sys.maxsize) + return withReplacement, fraction, _seed + + elif withReplacement is None and isinstance(fraction, float): + # For the case below: + # sample(faction=0.5 [, seed]) + _seed = int(seed) if seed is not None else random.randint(0, sys.maxsize) + return False, fraction, _seed + + elif isinstance(withReplacement, float): + # For the case below: + # sample(0.5 [, seed]) + _seed = int(fraction) if fraction is not None else random.randint(0, sys.maxsize) + _fraction = float(withReplacement) + return False, _fraction, _seed + + else: + argtypes = [type(arg).__name__ for arg in [withReplacement, fraction, seed]] + raise PySparkTypeError( + errorClass="NOT_BOOL_OR_FLOAT_OR_INT", + messageParameters={ + "arg_name": "withReplacement (optional), " + + "fraction (required) and seed (optional)", + "arg_type": ", ".join(argtypes), + }, + ) + @dispatch_df_method def sampleBy( self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None @@ -2889,6 +2940,62 @@ def sort( """ ... 
+ def _preapare_cols_for_sort( + self, + _to_col: Callable[[str], Column], + cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]], + kwargs: Dict[str, Any], + ) -> Sequence[Column]: + from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkIndexError + + if not cols: + raise PySparkValueError( + errorClass="CANNOT_BE_EMPTY", messageParameters={"item": "cols"} + ) + + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + + _cols: List[Column] = [] + for c in cols: + if isinstance(c, int) and not isinstance(c, bool): + # ordinal is 1-based + if c > 0: + _cols.append(self[c - 1]) + # negative ordinal means sort by desc + elif c < 0: + _cols.append(self[-c - 1].desc()) + else: + raise PySparkIndexError( + errorClass="ZERO_INDEX", + messageParameters={}, + ) + elif isinstance(c, Column): + _cols.append(c) + elif isinstance(c, str): + _cols.append(_to_col(c)) + else: + raise PySparkTypeError( + errorClass="NOT_COLUMN_OR_INT_OR_STR", + messageParameters={ + "arg_name": "col", + "arg_type": type(c).__name__, + }, + ) + + ascending = kwargs.get("ascending", True) + if isinstance(ascending, (bool, int)): + if not ascending: + _cols = [c.desc() for c in _cols] + elif isinstance(ascending, list): + _cols = [c if asc else c.desc() for asc, c in zip(ascending, _cols)] + else: + raise PySparkTypeError( + errorClass="NOT_COLUMN_OR_INT_OR_STR", + messageParameters={"arg_name": "ascending", "arg_type": type(ascending).__name__}, + ) + return _cols + orderBy = sort @dispatch_df_method @@ -3349,7 +3456,7 @@ def selectExpr(self, *expr: Union[str, List[str]]) -> "DataFrame": ... @dispatch_df_method - def filter(self, condition: "ColumnOrName") -> "DataFrame": + def filter(self, condition: Union[Column, str]) -> "DataFrame": """Filters rows using the given condition. :func:`where` is an alias for :func:`filter`. @@ -5900,7 +6007,7 @@ def inputFiles(self) -> List[str]: ... @dispatch_df_method - def where(self, condition: "ColumnOrName") -> "DataFrame": + def where(self, condition: Union[Column, str]) -> "DataFrame": """ :func:`where` is an alias for :func:`filter`. @@ -6394,6 +6501,32 @@ def executionInfo(self) -> Optional["ExecutionInfo"]: """ ... + @property + def plot(self) -> "PySparkPlotAccessor": + """ + Returns a :class:`PySparkPlotAccessor` for plotting functions. + + .. versionadded:: 4.0.0 + + Returns + ------- + :class:`PySparkPlotAccessor` + + Notes + ----- + This API is experimental. + + Examples + -------- + >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)] + >>> columns = ["category", "int_val", "float_val"] + >>> df = spark.createDataFrame(data, columns) + >>> type(df.plot) + + >>> df.plot.line(x="category", y=["int_val", "float_val"]) # doctest: +SKIP + """ + ... + class DataFrameNaFunctions: """Functionality for working with missing data in :class:`DataFrame`. 
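The DataFrame.plot accessor documented above is experimental in 4.0.0. A minimal sketch of the intended usage, assuming an active SparkSession and the default plotly backend installed:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)],
        ["category", "int_val", "float_val"],
    )
    # df.plot exposes a PySparkPlotAccessor; with the plotly backend,
    # line() is expected to return a plotly Figure that can be shown directly.
    fig = df.plot.line(x="category", y=["int_val", "float_val"])
    fig.show()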
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 781bf3d9f83a2..b75d1b2f59faf 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -4921,44 +4921,44 @@ def array_agg(col: "ColumnOrName") -> Column: >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([[1],[1],[2]], ["c"]) >>> df.agg(sf.sort_array(sf.array_agg('c'))).show() - +---------------------------------+ - |sort_array(collect_list(c), true)| - +---------------------------------+ - | [1, 1, 2]| - +---------------------------------+ + +------------------------------+ + |sort_array(array_agg(c), true)| + +------------------------------+ + | [1, 1, 2]| + +------------------------------+ Example 2: Using array_agg function on a string column >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([["apple"],["apple"],["banana"]], ["c"]) >>> df.agg(sf.sort_array(sf.array_agg('c'))).show(truncate=False) - +---------------------------------+ - |sort_array(collect_list(c), true)| - +---------------------------------+ - |[apple, apple, banana] | - +---------------------------------+ + +------------------------------+ + |sort_array(array_agg(c), true)| + +------------------------------+ + |[apple, apple, banana] | + +------------------------------+ Example 3: Using array_agg function on a column with null values >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([[1],[None],[2]], ["c"]) >>> df.agg(sf.sort_array(sf.array_agg('c'))).show() - +---------------------------------+ - |sort_array(collect_list(c), true)| - +---------------------------------+ - | [1, 2]| - +---------------------------------+ + +------------------------------+ + |sort_array(array_agg(c), true)| + +------------------------------+ + | [1, 2]| + +------------------------------+ Example 4: Using array_agg function on a column with different data types >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([[1],["apple"],[2]], ["c"]) >>> df.agg(sf.sort_array(sf.array_agg('c'))).show() - +---------------------------------+ - |sort_array(collect_list(c), true)| - +---------------------------------+ - | [1, 2, apple]| - +---------------------------------+ + +------------------------------+ + |sort_array(array_agg(c), true)| + +------------------------------+ + | [1, 2, apple]| + +------------------------------+ """ return _invoke_function_over_columns("array_agg", col) @@ -6015,9 +6015,9 @@ def grouping_id(*cols: "ColumnOrName") -> Column: @_try_remote_functions def count_min_sketch( col: "ColumnOrName", - eps: "ColumnOrName", - confidence: "ColumnOrName", - seed: "ColumnOrName", + eps: Union[Column, float], + confidence: Union[Column, float], + seed: Optional[Union[Column, int]] = None, ) -> Column: """ Returns a count-min sketch of a column with the given esp, confidence and seed. @@ -6031,13 +6031,24 @@ def count_min_sketch( ---------- col : :class:`~pyspark.sql.Column` or str target column to compute on. - eps : :class:`~pyspark.sql.Column` or str + eps : :class:`~pyspark.sql.Column` or float relative error, must be positive - confidence : :class:`~pyspark.sql.Column` or str + + .. versionchanged:: 4.0.0 + `eps` now accepts float value. + + confidence : :class:`~pyspark.sql.Column` or float confidence, must be positive and less than 1.0 - seed : :class:`~pyspark.sql.Column` or str + + .. versionchanged:: 4.0.0 + `confidence` now accepts float value. 
+ + seed : :class:`~pyspark.sql.Column` or int, optional random seed + .. versionchanged:: 4.0.0 + `seed` now accepts int value. + Returns ------- :class:`~pyspark.sql.Column` @@ -6045,12 +6056,60 @@ def count_min_sketch( Examples -------- - >>> df = spark.createDataFrame([[1], [2], [1]], ['data']) - >>> df = df.agg(count_min_sketch(df.data, lit(0.5), lit(0.5), lit(1)).alias('sketch')) - >>> df.select(hex(df.sketch).alias('r')).collect() - [Row(r='0000000100000000000000030000000100000004000000005D8D6AB90000000000000000000000000000000200000000000000010000000000000000')] - """ - return _invoke_function_over_columns("count_min_sketch", col, eps, confidence, seed) + Example 1: Using columns as arguments + + >>> from pyspark.sql import functions as sf + >>> spark.range(100).select( + ... sf.hex(sf.count_min_sketch(sf.col("id"), sf.lit(3.0), sf.lit(0.1), sf.lit(1))) + ... ).show(truncate=False) + +------------------------------------------------------------------------+ + |hex(count_min_sketch(id, 3.0, 0.1, 1)) | + +------------------------------------------------------------------------+ + |0000000100000000000000640000000100000001000000005D8D6AB90000000000000064| + +------------------------------------------------------------------------+ + + Example 2: Using numbers as arguments + + >>> from pyspark.sql import functions as sf + >>> spark.range(100).select( + ... sf.hex(sf.count_min_sketch("id", 1.0, 0.3, 2)) + ... ).show(truncate=False) + +----------------------------------------------------------------------------------------+ + |hex(count_min_sketch(id, 1.0, 0.3, 2)) | + +----------------------------------------------------------------------------------------+ + |0000000100000000000000640000000100000002000000005D96391C00000000000000320000000000000032| + +----------------------------------------------------------------------------------------+ + + Example 3: Using a long seed + + >>> from pyspark.sql import functions as sf + >>> spark.range(100).select( + ... sf.hex(sf.count_min_sketch("id", sf.lit(1.5), 0.2, 1111111111111111111)) + ... ).show(truncate=False) + +----------------------------------------------------------------------------------------+ + |hex(count_min_sketch(id, 1.5, 0.2, 1111111111111111111)) | + +----------------------------------------------------------------------------------------+ + |00000001000000000000006400000001000000020000000044078BA100000000000000320000000000000032| + +----------------------------------------------------------------------------------------+ + + Example 4: Using a random seed + + >>> from pyspark.sql import functions as sf + >>> spark.range(100).select( + ... sf.hex(sf.count_min_sketch("id", sf.lit(1.5), 0.6)) + ... 
).show(truncate=False) # doctest: +SKIP + +----------------------------------------------------------------------------------------------------------------------------------------+ + |hex(count_min_sketch(id, 1.5, 0.6, 2120704260)) | + +----------------------------------------------------------------------------------------------------------------------------------------+ + |0000000100000000000000640000000200000002000000005ADECCEE00000000153EBE090000000000000033000000000000003100000000000000320000000000000032| + +----------------------------------------------------------------------------------------------------------------------------------------+ + """ # noqa: E501 + _eps = lit(eps) + _conf = lit(confidence) + if seed is None: + return _invoke_function_over_columns("count_min_sketch", col, _eps, _conf) + else: + return _invoke_function_over_columns("count_min_sketch", col, _eps, _conf, lit(seed)) @_try_remote_functions @@ -8653,31 +8712,31 @@ def dateadd(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column: >>> spark.createDataFrame( ... [('2015-04-08', 2,)], ['dt', 'add'] ... ).select(sf.dateadd("dt", 1)).show() - +---------------+ - |date_add(dt, 1)| - +---------------+ - | 2015-04-09| - +---------------+ + +--------------+ + |dateadd(dt, 1)| + +--------------+ + | 2015-04-09| + +--------------+ >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( ... [('2015-04-08', 2,)], ['dt', 'add'] ... ).select(sf.dateadd("dt", sf.lit(2))).show() - +---------------+ - |date_add(dt, 2)| - +---------------+ - | 2015-04-10| - +---------------+ + +--------------+ + |dateadd(dt, 2)| + +--------------+ + | 2015-04-10| + +--------------+ >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( ... [('2015-04-08', 2,)], ['dt', 'add'] ... ).select(sf.dateadd("dt", -1)).show() - +----------------+ - |date_add(dt, -1)| - +----------------+ - | 2015-04-07| - +----------------+ + +---------------+ + |dateadd(dt, -1)| + +---------------+ + | 2015-04-07| + +---------------+ """ days = _enum_to_value(days) days = lit(days) if isinstance(days, int) else days @@ -9044,15 +9103,19 @@ def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: :class:`~pyspark.sql.Column` timestamp value as :class:`pyspark.sql.types.TimestampType` type. 
+ See Also + -------- + :meth:`pyspark.sql.functions.try_to_timestamp` + Examples -------- Example 1: Convert string to a timestamp >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(sf.try_to_timestamp(df.t).alias('dt')).show() + >>> df.select(sf.to_timestamp(df.t)).show() +-------------------+ - | dt| + | to_timestamp(t)| +-------------------+ |1997-02-28 10:30:00| +-------------------+ @@ -9061,12 +9124,12 @@ def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(sf.try_to_timestamp(df.t, sf.lit('yyyy-MM-dd HH:mm:ss')).alias('dt')).show() - +-------------------+ - | dt| - +-------------------+ - |1997-02-28 10:30:00| - +-------------------+ + >>> df.select(sf.to_timestamp(df.t, 'yyyy-MM-dd HH:mm:ss')).show() + +------------------------------------+ + |to_timestamp(t, yyyy-MM-dd HH:mm:ss)| + +------------------------------------+ + | 1997-02-28 10:30:00| + +------------------------------------+ """ from pyspark.sql.classic.column import _to_java_column @@ -9092,6 +9155,10 @@ def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = Non format: str, optional format to use to convert timestamp values. + See Also + -------- + :meth:`pyspark.sql.functions.to_timestamp` + Examples -------- Example 1: Convert string to a timestamp @@ -10276,11 +10343,11 @@ def current_database() -> Column: Examples -------- >>> spark.range(1).select(current_database()).show() - +----------------+ - |current_schema()| - +----------------+ - | default| - +----------------+ + +------------------+ + |current_database()| + +------------------+ + | default| + +------------------+ """ return _invoke_function("current_database") @@ -10846,7 +10913,7 @@ def unbase64(col: "ColumnOrName") -> Column: @_try_remote_functions -def ltrim(col: "ColumnOrName") -> Column: +def ltrim(col: "ColumnOrName", trim: Optional["ColumnOrName"] = None) -> Column: """ Trim the spaces from left end for the specified string value. @@ -10859,6 +10926,10 @@ def ltrim(col: "ColumnOrName") -> Column: ---------- col : :class:`~pyspark.sql.Column` or str target column to work on. + trim : :class:`~pyspark.sql.Column` or str, optional + The trim string characters to trim, the default value is a single space + + .. 
versionadded:: 4.0.0 Returns ------- @@ -10867,21 +10938,40 @@ def ltrim(col: "ColumnOrName") -> Column: Examples -------- + Example 1: Trim the spaces + + >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING") - >>> df.select(ltrim("value").alias("r")).withColumn("length", length("r")).show() - +-------+------+ - | r|length| - +-------+------+ - | Spark| 5| - |Spark | 7| - | Spark| 5| - +-------+------+ + >>> df.select("*", sf.ltrim("value")).show() + +--------+------------+ + | value|ltrim(value)| + +--------+------------+ + | Spark| Spark| + | Spark | Spark | + | Spark| Spark| + +--------+------------+ + + Example 2: Trim specified characters + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame(["***Spark", "Spark**", "*Spark"], "STRING") + >>> df.select("*", sf.ltrim("value", sf.lit("*"))).show() + +--------+--------------------------+ + | value|TRIM(LEADING * FROM value)| + +--------+--------------------------+ + |***Spark| Spark| + | Spark**| Spark**| + | *Spark| Spark| + +--------+--------------------------+ """ - return _invoke_function_over_columns("ltrim", col) + if trim is not None: + return _invoke_function_over_columns("ltrim", col, trim) + else: + return _invoke_function_over_columns("ltrim", col) @_try_remote_functions -def rtrim(col: "ColumnOrName") -> Column: +def rtrim(col: "ColumnOrName", trim: Optional["ColumnOrName"] = None) -> Column: """ Trim the spaces from right end for the specified string value. @@ -10894,6 +10984,10 @@ def rtrim(col: "ColumnOrName") -> Column: ---------- col : :class:`~pyspark.sql.Column` or str target column to work on. + trim : :class:`~pyspark.sql.Column` or str, optional + The trim string characters to trim, the default value is a single space + + .. versionadded:: 4.0.0 Returns ------- @@ -10902,21 +10996,40 @@ def rtrim(col: "ColumnOrName") -> Column: Examples -------- + Example 1: Trim the spaces + + >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING") - >>> df.select(rtrim("value").alias("r")).withColumn("length", length("r")).show() - +--------+------+ - | r|length| - +--------+------+ - | Spark| 8| - | Spark| 5| - | Spark| 6| - +--------+------+ + >>> df.select("*", sf.rtrim("value")).show() + +--------+------------+ + | value|rtrim(value)| + +--------+------------+ + | Spark| Spark| + | Spark | Spark| + | Spark| Spark| + +--------+------------+ + + Example 2: Trim specified characters + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame(["***Spark", "Spark**", "*Spark"], "STRING") + >>> df.select("*", sf.rtrim("value", sf.lit("*"))).show() + +--------+---------------------------+ + | value|TRIM(TRAILING * FROM value)| + +--------+---------------------------+ + |***Spark| ***Spark| + | Spark**| Spark| + | *Spark| *Spark| + +--------+---------------------------+ """ - return _invoke_function_over_columns("rtrim", col) + if trim is not None: + return _invoke_function_over_columns("rtrim", col, trim) + else: + return _invoke_function_over_columns("rtrim", col) @_try_remote_functions -def trim(col: "ColumnOrName") -> Column: +def trim(col: "ColumnOrName", trim: Optional["ColumnOrName"] = None) -> Column: """ Trim the spaces from both ends for the specified string column. @@ -10929,6 +11042,10 @@ def trim(col: "ColumnOrName") -> Column: ---------- col : :class:`~pyspark.sql.Column` or str target column to work on. 
+ trim : :class:`~pyspark.sql.Column` or str, optional + The trim string characters to trim, the default value is a single space + + .. versionadded:: 4.0.0 Returns ------- @@ -10937,17 +11054,36 @@ def trim(col: "ColumnOrName") -> Column: Examples -------- + Example 1: Trim the spaces + + >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING") - >>> df.select(trim("value").alias("r")).withColumn("length", length("r")).show() - +-----+------+ - | r|length| - +-----+------+ - |Spark| 5| - |Spark| 5| - |Spark| 5| - +-----+------+ + >>> df.select("*", sf.trim("value")).show() + +--------+-----------+ + | value|trim(value)| + +--------+-----------+ + | Spark| Spark| + | Spark | Spark| + | Spark| Spark| + +--------+-----------+ + + Example 2: Trim specified characters + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame(["***Spark", "Spark**", "*Spark"], "STRING") + >>> df.select("*", sf.trim("value", sf.lit("*"))).show() + +--------+-----------------------+ + | value|TRIM(BOTH * FROM value)| + +--------+-----------------------+ + |***Spark| Spark| + | Spark**| Spark| + | *Spark| Spark| + +--------+-----------------------+ """ - return _invoke_function_over_columns("trim", col) + if trim is not None: + return _invoke_function_over_columns("trim", col, trim) + else: + return _invoke_function_over_columns("trim", col) @_try_remote_functions @@ -11309,7 +11445,9 @@ def sentences( @_try_remote_functions def substring( - str: "ColumnOrName", pos: Union["ColumnOrName", int], len: Union["ColumnOrName", int] + str: "ColumnOrName", + pos: Union["ColumnOrName", int], + len: Union["ColumnOrName", int], ) -> Column: """ Substring starts at `pos` and is of length `len` when str is String type or @@ -11348,16 +11486,59 @@ def substring( Examples -------- + Example 1: Using literal integers as arguments + + >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([('abcd',)], ['s',]) - >>> df.select(substring(df.s, 1, 2).alias('s')).collect() - [Row(s='ab')] + >>> df.select('*', sf.substring(df.s, 1, 2)).show() + +----+------------------+ + | s|substring(s, 1, 2)| + +----+------------------+ + |abcd| ab| + +----+------------------+ + + Example 2: Using columns as arguments + + >>> import pyspark.sql.functions as sf + >>> df = spark.createDataFrame([('Spark', 2, 3)], ['s', 'p', 'l']) + >>> df.select('*', sf.substring(df.s, 2, df.l)).show() + +-----+---+---+------------------+ + | s| p| l|substring(s, 2, l)| + +-----+---+---+------------------+ + |Spark| 2| 3| par| + +-----+---+---+------------------+ + + >>> df.select('*', sf.substring(df.s, df.p, 3)).show() + +-----+---+---+------------------+ + | s| p| l|substring(s, p, 3)| + +-----+---+---+------------------+ + |Spark| 2| 3| par| + +-----+---+---+------------------+ + + >>> df.select('*', sf.substring(df.s, df.p, df.l)).show() + +-----+---+---+------------------+ + | s| p| l|substring(s, p, l)| + +-----+---+---+------------------+ + |Spark| 2| 3| par| + +-----+---+---+------------------+ + + Example 3: Using column names as arguments + + >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([('Spark', 2, 3)], ['s', 'p', 'l']) - >>> df.select(substring(df.s, 2, df.l).alias('s')).collect() - [Row(s='par')] - >>> df.select(substring(df.s, df.p, 3).alias('s')).collect() - [Row(s='par')] - >>> df.select(substring(df.s, df.p, df.l).alias('s')).collect() - [Row(s='par')] + >>> df.select('*', sf.substring(df.s, 2, 'l')).show() + 
+-----+---+---+------------------+ + | s| p| l|substring(s, 2, l)| + +-----+---+---+------------------+ + |Spark| 2| 3| par| + +-----+---+---+------------------+ + + >>> df.select('*', sf.substring('s', 'p', 'l')).show() + +-----+---+---+------------------+ + | s| p| l|substring(s, p, l)| + +-----+---+---+------------------+ + |Spark| 2| 3| par| + +-----+---+---+------------------+ """ pos = _enum_to_value(pos) pos = lit(pos) if isinstance(pos, int) else pos @@ -11861,6 +12042,47 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: return _invoke_function_over_columns("regexp_like", str, regexp) +@_try_remote_functions +def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = None) -> Column: + """Returns a string of the specified length whose characters are chosen uniformly at random from + the following pool of characters: 0-9, a-z, A-Z. The random seed is optional. The string length + must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively). + + .. versionadded:: 4.0.0 + + Parameters + ---------- + length : :class:`~pyspark.sql.Column` or int + Number of characters in the string to generate. + seed : :class:`~pyspark.sql.Column` or int + Optional random number seed to use. + + Returns + ------- + :class:`~pyspark.sql.Column` + The generated random string with the specified length. + + Examples + -------- + >>> spark.createDataFrame([('3',)], ['a']) \\ + ... .select(randstr(lit(5), lit(0)).alias('result')) \\ + ... .selectExpr("length(result) > 0").show() + +--------------------+ + |(length(result) > 0)| + +--------------------+ + | true| + +--------------------+ + """ + length = _enum_to_value(length) + length = lit(length) + if seed is None: + return _invoke_function_over_columns("randstr", length) + else: + seed = _enum_to_value(seed) + seed = lit(seed) + return _invoke_function_over_columns("randstr", length, seed) + + @_try_remote_functions def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: r"""Returns a count of the number of times that the Java regex pattern `regexp` is matched @@ -12227,6 +12449,57 @@ def unhex(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("unhex", col) +@_try_remote_functions +def uniform( + min: Union[Column, int, float], + max: Union[Column, int, float], + seed: Optional[Union[Column, int]] = None, +) -> Column: + """Returns a random value with independent and identically distributed (i.i.d.) values with the + specified range of numbers. The random seed is optional. The provided numbers specifying the + minimum and maximum values of the range must be constant. If both of these numbers are integers, + then the result will also be an integer. Otherwise if one or both of these are floating-point + numbers, then the result will also be a floating-point number. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + min : :class:`~pyspark.sql.Column`, int, or float + Minimum value in the range. + max : :class:`~pyspark.sql.Column`, int, or float + Maximum value in the range. + seed : :class:`~pyspark.sql.Column` or int + Optional random number seed to use. + + Returns + ------- + :class:`~pyspark.sql.Column` + The generated random number within the specified range. + + Examples + -------- + >>> spark.createDataFrame([('3',)], ['a']) \\ + ... .select(uniform(lit(0), lit(10), lit(0)).alias('result')) \\ + ... 
.selectExpr("result < 15").show() + +-------------+ + |(result < 15)| + +-------------+ + | true| + +-------------+ + """ + min = _enum_to_value(min) + min = lit(min) + max = _enum_to_value(max) + max = lit(max) + if seed is None: + return _invoke_function_over_columns("uniform", min, max) + else: + seed = _enum_to_value(seed) + seed = lit(seed) + return _invoke_function_over_columns("uniform", min, max, seed) + + @_try_remote_functions def length(col: "ColumnOrName") -> Column: """Computes the character length of string data or number of bytes of binary data. @@ -17631,7 +17904,7 @@ def array_sort( @_try_remote_functions -def shuffle(col: "ColumnOrName") -> Column: +def shuffle(col: "ColumnOrName", seed: Optional[Union[Column, int]] = None) -> Column: """ Array function: Generates a random permutation of the given array. @@ -17644,6 +17917,10 @@ def shuffle(col: "ColumnOrName") -> Column: ---------- col : :class:`~pyspark.sql.Column` or str The name of the column or expression to be shuffled. + seed : :class:`~pyspark.sql.Column` or int, optional + Seed value for the random generator. + + .. versionadded:: 4.0.0 Returns ------- @@ -17660,48 +17937,51 @@ def shuffle(col: "ColumnOrName") -> Column: Example 1: Shuffling a simple array >>> import pyspark.sql.functions as sf - >>> df = spark.createDataFrame([([1, 20, 3, 5],)], ['data']) - >>> df.select(sf.shuffle(df.data)).show() # doctest: +SKIP - +-------------+ - |shuffle(data)| - +-------------+ - |[1, 3, 20, 5]| - +-------------+ + >>> df = spark.sql("SELECT ARRAY(1, 20, 3, 5) AS data") + >>> df.select("*", sf.shuffle(df.data, sf.lit(123))).show() + +-------------+-------------+ + | data|shuffle(data)| + +-------------+-------------+ + |[1, 20, 3, 5]|[5, 1, 20, 3]| + +-------------+-------------+ Example 2: Shuffling an array with null values >>> import pyspark.sql.functions as sf - >>> df = spark.createDataFrame([([1, 20, None, 3],)], ['data']) - >>> df.select(sf.shuffle(df.data)).show() # doctest: +SKIP - +----------------+ - | shuffle(data)| - +----------------+ - |[20, 3, NULL, 1]| - +----------------+ + >>> df = spark.sql("SELECT ARRAY(1, 20, NULL, 5) AS data") + >>> df.select("*", sf.shuffle(sf.col("data"), 234)).show() + +----------------+----------------+ + | data| shuffle(data)| + +----------------+----------------+ + |[1, 20, NULL, 5]|[NULL, 5, 20, 1]| + +----------------+----------------+ Example 3: Shuffling an array with duplicate values >>> import pyspark.sql.functions as sf - >>> df = spark.createDataFrame([([1, 2, 2, 3, 3, 3],)], ['data']) - >>> df.select(sf.shuffle(df.data)).show() # doctest: +SKIP - +------------------+ - | shuffle(data)| - +------------------+ - |[3, 2, 1, 3, 2, 3]| - +------------------+ + >>> df = spark.sql("SELECT ARRAY(1, 2, 2, 3, 3, 3) AS data") + >>> df.select("*", sf.shuffle("data", 345)).show() + +------------------+------------------+ + | data| shuffle(data)| + +------------------+------------------+ + |[1, 2, 2, 3, 3, 3]|[2, 3, 3, 1, 2, 3]| + +------------------+------------------+ - Example 4: Shuffling an array with different types of elements + Example 4: Shuffling an array with random seed >>> import pyspark.sql.functions as sf - >>> df = spark.createDataFrame([(['a', 'b', 'c', 1, 2, 3],)], ['data']) - >>> df.select(sf.shuffle(df.data)).show() # doctest: +SKIP - +------------------+ - | shuffle(data)| - +------------------+ - |[1, c, 2, a, b, 3]| - +------------------+ + >>> df = spark.sql("SELECT ARRAY(1, 2, 2, 3, 3, 3) AS data") + >>> df.select("*", sf.shuffle("data")).show() # doctest: 
+SKIP + +------------------+------------------+ + | data| shuffle(data)| + +------------------+------------------+ + |[1, 2, 2, 3, 3, 3]|[3, 3, 2, 3, 2, 1]| + +------------------+------------------+ """ - return _invoke_function_over_columns("shuffle", col) + if seed is not None: + return _invoke_function_over_columns("shuffle", col, lit(seed)) + else: + return _invoke_function_over_columns("shuffle", col) @_try_remote_functions diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index 53c72304adfaa..57e46901013fe 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -53,12 +53,17 @@ ) from pyspark.errors import PySparkTypeError, UnsupportedOperationException, PySparkValueError from pyspark.loose_version import LooseVersion +from pyspark.sql.utils import has_numpy + +if has_numpy: + import numpy as np if TYPE_CHECKING: import pandas as pd import pyarrow as pa from pyspark.sql.pandas._typing import SeriesLike as PandasSeriesLike + from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike def to_arrow_type( @@ -1344,3 +1349,34 @@ def _deduplicate_field_names(dt: DataType) -> DataType: ) else: return dt + + +def _to_numpy_type(type: DataType) -> Optional["np.dtype"]: + """Convert Spark data type to NumPy type.""" + import numpy as np + + if type == ByteType(): + return np.dtype("int8") + elif type == ShortType(): + return np.dtype("int16") + elif type == IntegerType(): + return np.dtype("int32") + elif type == LongType(): + return np.dtype("int64") + elif type == FloatType(): + return np.dtype("float32") + elif type == DoubleType(): + return np.dtype("float64") + return None + + +def convert_pandas_using_numpy_type( + df: "PandasDataFrameLike", schema: StructType +) -> "PandasDataFrameLike": + for field in schema.fields: + if isinstance( + field.dataType, (ByteType, ShortType, LongType, FloatType, DoubleType, IntegerType) + ): + np_type = _to_numpy_type(field.dataType) + df[field.name] = df[field.name].astype(np_type) + return df diff --git a/python/pyspark/sql/plot/__init__.py b/python/pyspark/sql/plot/__init__.py new file mode 100644 index 0000000000000..6da07061b2a09 --- /dev/null +++ b/python/pyspark/sql/plot/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This package includes the plotting APIs for PySpark DataFrame. +""" +from pyspark.sql.plot.core import * # noqa: F403, F401 diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py new file mode 100644 index 0000000000000..4bf75474d92c3 --- /dev/null +++ b/python/pyspark/sql/plot/core.py @@ -0,0 +1,487 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, TYPE_CHECKING, List, Optional, Union +from types import ModuleType +from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError +from pyspark.sql import Column, functions as F +from pyspark.sql.types import NumericType +from pyspark.sql.utils import is_remote, require_minimum_plotly_version + + +if TYPE_CHECKING: + from pyspark.sql import DataFrame, Row + from pyspark.sql._typing import ColumnOrName + import pandas as pd + from plotly.graph_objs import Figure + + +class PySparkTopNPlotBase: + def get_top_n(self, sdf: "DataFrame") -> "pd.DataFrame": + from pyspark.sql import SparkSession + + session = SparkSession.getActiveSession() + if session is None: + raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict()) + + max_rows = int( + session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type] + ) + pdf = sdf.limit(max_rows + 1).toPandas() + + self.partial = False + if len(pdf) > max_rows: + self.partial = True + pdf = pdf.iloc[:max_rows] + + return pdf + + +class PySparkSampledPlotBase: + def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame": + from pyspark.sql import SparkSession, Observation, functions as F + + session = SparkSession.getActiveSession() + if session is None: + raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict()) + + max_rows = int( + session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type] + ) + + observation = Observation("pyspark plotting") + + rand_col_name = "__pyspark_plotting_sampled_plot_base_rand__" + id_col_name = "__pyspark_plotting_sampled_plot_base_id__" + + sampled_sdf = ( + sdf.observe(observation, F.count(F.lit(1)).alias("count")) + .select( + "*", + F.rand().alias(rand_col_name), + F.monotonically_increasing_id().alias(id_col_name), + ) + .sort(rand_col_name) + .limit(max_rows + 1) + .coalesce(1) + .sortWithinPartitions(id_col_name) + .drop(rand_col_name, id_col_name) + ) + pdf = sampled_sdf.toPandas() + + if len(pdf) > max_rows: + try: + self.fraction = float(max_rows) / observation.get["count"] + except Exception: + pass + return pdf[:max_rows] + else: + self.fraction = 1.0 + return pdf + + +class PySparkPlotAccessor: + plot_data_map = { + "area": PySparkSampledPlotBase().get_sampled, + "bar": PySparkTopNPlotBase().get_top_n, + "barh": PySparkTopNPlotBase().get_top_n, + "line": PySparkSampledPlotBase().get_sampled, + "pie": PySparkTopNPlotBase().get_top_n, + "scatter": PySparkSampledPlotBase().get_sampled, + } + _backends = {} # type: ignore[var-annotated] + + def __init__(self, data: "DataFrame"): + self.data = data + + def __call__( + self, kind: str = "line", backend: Optional[str] = None, **kwargs: Any + ) -> "Figure": + plot_backend = PySparkPlotAccessor._get_plot_backend(backend) + + return 
plot_backend.plot_pyspark(self.data, kind=kind, **kwargs) + + @staticmethod + def _get_plot_backend(backend: Optional[str] = None) -> ModuleType: + backend = backend or "plotly" + + if backend in PySparkPlotAccessor._backends: + return PySparkPlotAccessor._backends[backend] + + if backend == "plotly": + require_minimum_plotly_version() + else: + raise PySparkValueError( + errorClass="UNSUPPORTED_PLOT_BACKEND", + messageParameters={"backend": backend, "supported_backends": ", ".join(["plotly"])}, + ) + from pyspark.sql.plot import plotly as module + + return module + + def line(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure": + """ + Plot DataFrame as lines. + + Parameters + ---------- + x : str + Name of column to use for the horizontal axis. + y : str or list of str + Name(s) of the column(s) to use for the vertical axis. Multiple columns can be plotted. + **kwargs : optional + Additional keyword arguments. + + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Examples + -------- + >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)] + >>> columns = ["category", "int_val", "float_val"] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.line(x="category", y="int_val") # doctest: +SKIP + >>> df.plot.line(x="category", y=["int_val", "float_val"]) # doctest: +SKIP + """ + return self(kind="line", x=x, y=y, **kwargs) + + def bar(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure": + """ + Vertical bar plot. + + A bar plot is a plot that presents categorical data with rectangular bars with lengths + proportional to the values that they represent. A bar plot shows comparisons among + discrete categories. One axis of the plot shows the specific categories being compared, + and the other axis represents a measured value. + + Parameters + ---------- + x : str + Name of column to use for the horizontal axis. + y : str or list of str + Name(s) of the column(s) to use for the vertical axis. + Multiple columns can be plotted. + **kwargs : optional + Additional keyword arguments. + + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Examples + -------- + >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)] + >>> columns = ["category", "int_val", "float_val"] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.bar(x="category", y="int_val") # doctest: +SKIP + >>> df.plot.bar(x="category", y=["int_val", "float_val"]) # doctest: +SKIP + """ + return self(kind="bar", x=x, y=y, **kwargs) + + def barh(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure": + """ + Make a horizontal bar plot. + + A horizontal bar plot is a plot that presents quantitative data with + rectangular bars with lengths proportional to the values that they + represent. A bar plot shows comparisons among discrete categories. One + axis of the plot shows the specific categories being compared, and the + other axis represents a measured value. + + Parameters + ---------- + x : str or list of str + Name(s) of the column(s) to use for the horizontal axis. + Multiple columns can be plotted. + y : str or list of str + Name(s) of the column(s) to use for the vertical axis. + Multiple columns can be plotted. + **kwargs : optional + Additional keyword arguments. + + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Notes + ----- + In Plotly and Matplotlib, the interpretation of `x` and `y` for `barh` plots differs. + In Plotly, `x` refers to the values and `y` refers to the categories. 
+ In Matplotlib, `x` refers to the categories and `y` refers to the values. + Ensure correct axis labeling based on the backend used. + + Examples + -------- + >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)] + >>> columns = ["category", "int_val", "float_val"] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.barh(x="int_val", y="category") # doctest: +SKIP + >>> df.plot.barh( + ... x=["int_val", "float_val"], y="category" + ... ) # doctest: +SKIP + """ + return self(kind="barh", x=x, y=y, **kwargs) + + def scatter(self, x: str, y: str, **kwargs: Any) -> "Figure": + """ + Create a scatter plot with varying marker point size and color. + + The coordinates of each point are defined by two dataframe columns and + filled circles are used to represent each point. This kind of plot is + useful to see complex correlations between two variables. Points could + be for instance natural 2D coordinates like longitude and latitude in + a map or, in general, any pair of metrics that can be plotted against + each other. + + Parameters + ---------- + x : str + Name of column to use as horizontal coordinates for each point. + y : str or list of str + Name of column to use as vertical coordinates for each point. + **kwargs: Optional + Additional keyword arguments. + + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Examples + -------- + >>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)] + >>> columns = ['length', 'width', 'species'] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.scatter(x='length', y='width') # doctest: +SKIP + """ + return self(kind="scatter", x=x, y=y, **kwargs) + + def area(self, x: str, y: str, **kwargs: Any) -> "Figure": + """ + Draw a stacked area plot. + + An area plot displays quantitative data visually. + + Parameters + ---------- + x : str + Name of column to use for the horizontal axis. + y : str or list of str + Name(s) of the column(s) to plot. + **kwargs: Optional + Additional keyword arguments. + + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Examples + -------- + >>> from datetime import datetime + >>> data = [ + ... (3, 5, 20, datetime(2018, 1, 31)), + ... (2, 5, 42, datetime(2018, 2, 28)), + ... (3, 6, 28, datetime(2018, 3, 31)), + ... (9, 12, 62, datetime(2018, 4, 30)) + ... ] + >>> columns = ["sales", "signups", "visits", "date"] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.area(x='date', y=['sales', 'signups', 'visits']) # doctest: +SKIP + """ + return self(kind="area", x=x, y=y, **kwargs) + + def pie(self, x: str, y: str, **kwargs: Any) -> "Figure": + """ + Generate a pie plot. + + A pie plot is a proportional representation of the numerical data in a + column. + + Parameters + ---------- + x : str + Name of column to be used as the category labels for the pie plot. + y : str + Name of the column to plot. + **kwargs + Additional keyword arguments. 
+ + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Examples + -------- + """ + schema = self.data.schema + + # Check if 'y' is a numerical column + y_field = schema[y] if y in schema.names else None + if y_field is None or not isinstance(y_field.dataType, NumericType): + raise PySparkTypeError( + errorClass="PLOT_NOT_NUMERIC_COLUMN", + messageParameters={ + "arg_name": "y", + "arg_type": str(y_field.dataType) if y_field else "None", + }, + ) + return self(kind="pie", x=x, y=y, **kwargs) + + def box( + self, column: Union[str, List[str]], precision: float = 0.01, **kwargs: Any + ) -> "Figure": + """ + Make a box plot of the DataFrame columns. + + Make a box-and-whisker plot from DataFrame columns, optionally grouped by some + other columns. A box plot is a method for graphically depicting groups of numerical + data through their quartiles. The box extends from the Q1 to Q3 quartile values of + the data, with a line at the median (Q2). The whiskers extend from the edges of box + to show the range of the data. By default, they extend no more than + 1.5 * IQR (IQR = Q3 - Q1) from the edges of the box, ending at the farthest data point + within that interval. Outliers are plotted as separate dots. + + Parameters + ---------- + column: str or list of str + Column name or list of names to be used for creating the boxplot. + precision: float, default = 0.01 + This argument is used by pyspark to compute approximate statistics + for building a boxplot. + **kwargs + Additional keyword arguments. + + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Examples + -------- + >>> data = [ + ... ("A", 50, 55), + ... ("B", 55, 60), + ... ("C", 60, 65), + ... ("D", 65, 70), + ... ("E", 70, 75), + ... ("F", 10, 15), + ... ("G", 85, 90), + ... ("H", 5, 150), + ... 
] + >>> columns = ["student", "math_score", "english_score"] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.box(column="math_score") # doctest: +SKIP + >>> df.plot.box(column=["math_score", "english_score"]) # doctest: +SKIP + """ + return self(kind="box", column=column, precision=precision, **kwargs) + + +class PySparkBoxPlotBase: + @staticmethod + def compute_box( + sdf: "DataFrame", colnames: List[str], whis: float, precision: float, showfliers: bool + ) -> Optional["Row"]: + assert len(colnames) > 0 + formatted_colnames = ["`{}`".format(colname) for colname in colnames] + + stats_scols = [] + for i, colname in enumerate(formatted_colnames): + percentiles = F.percentile_approx(colname, [0.25, 0.50, 0.75], int(1.0 / precision)) + q1 = F.get(percentiles, 0) + med = F.get(percentiles, 1) + q3 = F.get(percentiles, 2) + iqr = q3 - q1 + lfence = q1 - F.lit(whis) * iqr + ufence = q3 + F.lit(whis) * iqr + + stats_scols.append( + F.struct( + F.mean(colname).alias("mean"), + med.alias("med"), + q1.alias("q1"), + q3.alias("q3"), + lfence.alias("lfence"), + ufence.alias("ufence"), + ).alias(f"_box_plot_stats_{i}") + ) + + sdf_stats = sdf.select(*stats_scols) + + result_scols = [] + for i, colname in enumerate(formatted_colnames): + value = F.col(colname) + + lfence = F.col(f"_box_plot_stats_{i}.lfence") + ufence = F.col(f"_box_plot_stats_{i}.ufence") + mean = F.col(f"_box_plot_stats_{i}.mean") + med = F.col(f"_box_plot_stats_{i}.med") + q1 = F.col(f"_box_plot_stats_{i}.q1") + q3 = F.col(f"_box_plot_stats_{i}.q3") + + outlier = ~value.between(lfence, ufence) + + # Computes min and max values of non-outliers - the whiskers + upper_whisker = F.max(F.when(~outlier, value).otherwise(F.lit(None))) + lower_whisker = F.min(F.when(~outlier, value).otherwise(F.lit(None))) + + # If it shows fliers, take the top 1k with the highest absolute values + # Here we normalize the values by subtracting the median. 
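+            # Concretely, each outlier is wrapped in a struct of (|value - median|, value),
+            # collect_top_k keeps up to 1001 of the outliers farthest from the median, and
+            # the original values are then read back from the struct's "val" field.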
+ if showfliers: + pair = F.when( + outlier, + F.struct(F.abs(value - med), value.alias("val")), + ).otherwise(F.lit(None)) + topk = collect_top_k(pair, 1001, False) + fliers = F.when(F.size(topk) > 0, topk["val"]).otherwise(F.lit(None)) + else: + fliers = F.lit(None) + + result_scols.append( + F.struct( + F.first(mean).alias("mean"), + F.first(med).alias("med"), + F.first(q1).alias("q1"), + F.first(q3).alias("q3"), + upper_whisker.alias("upper_whisker"), + lower_whisker.alias("lower_whisker"), + fliers.alias("fliers"), + ).alias(f"_box_plot_results_{i}") + ) + + sdf_result = sdf.join(sdf_stats.hint("broadcast")).select(*result_scols) + return sdf_result.first() + + +def _invoke_internal_function_over_columns(name: str, *cols: "ColumnOrName") -> Column: + if is_remote(): + from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns + + return _invoke_function_over_columns(name, *cols) + + else: + from pyspark.sql.classic.column import _to_seq, _to_java_column + from pyspark import SparkContext + + sc = SparkContext._active_spark_context + return Column( + sc._jvm.PythonSQLUtils.internalFn( # type: ignore + name, _to_seq(sc, cols, _to_java_column) # type: ignore + ) + ) + + +def collect_top_k(col: Column, num: int, reverse: bool) -> Column: + return _invoke_internal_function_over_columns("collect_top_k", col, F.lit(num), F.lit(reverse)) diff --git a/python/pyspark/sql/plot/plotly.py b/python/pyspark/sql/plot/plotly.py new file mode 100644 index 0000000000000..71d40720e874d --- /dev/null +++ b/python/pyspark/sql/plot/plotly.py @@ -0,0 +1,120 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import TYPE_CHECKING, Any + +from pyspark.errors import PySparkValueError +from pyspark.sql.plot import PySparkPlotAccessor, PySparkBoxPlotBase + +if TYPE_CHECKING: + from pyspark.sql import DataFrame + from plotly.graph_objs import Figure + + +def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure": + import plotly + + if kind == "pie": + return plot_pie(data, **kwargs) + if kind == "box": + return plot_box(data, **kwargs) + + return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs) + + +def plot_pie(data: "DataFrame", **kwargs: Any) -> "Figure": + # TODO(SPARK-49530): Support pie subplots with plotly backend + from plotly import express + + pdf = PySparkPlotAccessor.plot_data_map["pie"](data) + x = kwargs.pop("x", None) + y = kwargs.pop("y", None) + fig = express.pie(pdf, values=y, names=x, **kwargs) + + return fig + + +def plot_box(data: "DataFrame", **kwargs: Any) -> "Figure": + import plotly.graph_objs as go + + # 'whis' isn't actually an argument in plotly (but in matplotlib). 
But seems like + # plotly doesn't expose the reach of the whiskers to the beyond the first and + # third quartiles (?). Looks they use default 1.5. + whis = kwargs.pop("whis", 1.5) + # 'precision' is pyspark specific to control precision for approx_percentile + precision = kwargs.pop("precision", 0.01) + colnames = kwargs.pop("column", None) + if isinstance(colnames, str): + colnames = [colnames] + + # Plotly options + boxpoints = kwargs.pop("boxpoints", "suspectedoutliers") + notched = kwargs.pop("notched", False) + if boxpoints not in ["suspectedoutliers", False]: + raise PySparkValueError( + errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM", + messageParameters={ + "backend": "plotly", + "param": "boxpoints", + "value": str(boxpoints), + "supported_values": ", ".join(["suspectedoutliers", "False"]), + }, + ) + if notched: + raise PySparkValueError( + errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM", + messageParameters={ + "backend": "plotly", + "param": "notched", + "value": str(notched), + "supported_values": ", ".join(["False"]), + }, + ) + + fig = go.Figure() + + results = PySparkBoxPlotBase.compute_box( + data, + colnames, + whis, + precision, + boxpoints is not None, + ) + assert len(results) == len(colnames) # type: ignore + + for i, colname in enumerate(colnames): + result = results[i] # type: ignore + + fig.add_trace( + go.Box( + x=[i], + name=colname, + q1=[result["q1"]], + median=[result["med"]], + q3=[result["q3"]], + mean=[result["mean"]], + lowerfence=[result["lower_whisker"]], + upperfence=[result["upper_whisker"]], + y=[result["fliers"]] if result["fliers"] else None, + boxpoints=boxpoints, + notched=notched, + **kwargs, + ) + ) + + fig["layout"]["yaxis"]["title"] = "value" + return fig diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index b513d8d4111b9..96344efba2d2a 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -77,6 +77,7 @@ from pyspark.sql.udf import UDFRegistration from pyspark.sql.udtf import UDTFRegistration from pyspark.sql.datasource import DataSourceRegistration + from pyspark.sql.dataframe import DataFrame as ParentDataFrame # Running MyPy type checks will always require pandas and # other dependencies so importing here is fine. @@ -1641,7 +1642,7 @@ def prepare(obj: Any) -> Any: def sql( self, sqlQuery: str, args: Optional[Union[Dict[str, Any], List]] = None, **kwargs: Any - ) -> DataFrame: + ) -> "ParentDataFrame": """Returns a :class:`DataFrame` representing the result of the given query. When ``kwargs`` is specified, this method formats the given string by using the Python standard formatter. The method binds named parameters to SQL literals or diff --git a/python/pyspark/sql/streaming/StateMessage_pb2.py b/python/pyspark/sql/streaming/StateMessage_pb2.py index a22f004fd3048..e75d0394ea0f5 100644 --- a/python/pyspark/sql/streaming/StateMessage_pb2.py +++ b/python/pyspark/sql/streaming/StateMessage_pb2.py @@ -16,12 +16,14 @@ # # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
+# NO CHECKED-IN PROTOBUF GENCODE # source: StateMessage.proto +# Protobuf Python Version: 5.27.3 """Generated protocol buffer code.""" -from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) @@ -29,45 +31,54 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xe9\x02\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00\x42\x08\n\x06method"H\n\rStateResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x0c"\x89\x03\n\x15StatefulProcessorCall\x12X\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00\x12Y\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12X\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12W\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x42\x08\n\x06method"z\n\x14StateVariableRequest\x12X\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00\x42\x08\n\x06method"\xe0\x01\n\x1aImplicitGroupingKeyRequest\x12X\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00\x12^\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00\x42\x08\n\x06method"}\n\x10StateCallCommand\x12\x11\n\tstateName\x18\x01 \x01(\t\x12\x0e\n\x06schema\x18\x02 \x01(\t\x12\x46\n\x03ttl\x18\x03 \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.TTLConfig"\xe1\x02\n\x0eValueStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12\x42\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00\x12\\\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00\x12\x46\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x1d\n\x0eSetImplicitKey\x12\x0b\n\x03key\x18\x01 \x01(\x0c"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"!\n\x10ValueStateUpdate\x12\r\n\x05value\x18\x01 \x01(\x0c"\x07\n\x05\x43lear"\\\n\x0eSetHandleState\x12J\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleState"\x1f\n\tTTLConfig\x12\x12\n\ndurationMs\x18\x01 \x01(\x05*K\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\n\n\x06\x43LOSED\x10\x03\x62\x06proto3' # noqa: E501 + 
b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xe9\x02\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00\x42\x08\n\x06method"H\n\rStateResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x0c"\x89\x03\n\x15StatefulProcessorCall\x12X\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00\x12Y\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12X\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12W\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x42\x08\n\x06method"\xd2\x01\n\x14StateVariableRequest\x12X\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00\x12V\n\rlistStateCall\x18\x02 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.ListStateCallH\x00\x42\x08\n\x06method"\xe0\x01\n\x1aImplicitGroupingKeyRequest\x12X\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00\x12^\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00\x42\x08\n\x06method"}\n\x10StateCallCommand\x12\x11\n\tstateName\x18\x01 \x01(\t\x12\x0e\n\x06schema\x18\x02 \x01(\t\x12\x46\n\x03ttl\x18\x03 \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.TTLConfig"\xe1\x02\n\x0eValueStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12\x42\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00\x12\\\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00\x12\x46\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x90\x04\n\rListStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12T\n\x0clistStateGet\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStateGetH\x00\x12T\n\x0clistStatePut\x18\x04 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStatePutH\x00\x12R\n\x0b\x61ppendValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.AppendValueH\x00\x12P\n\nappendList\x18\x06 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.AppendListH\x00\x12\x46\n\x05\x63lear\x18\x07 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x1d\n\x0eSetImplicitKey\x12\x0b\n\x03key\x18\x01 \x01(\x0c"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"!\n\x10ValueStateUpdate\x12\r\n\x05value\x18\x01 \x01(\x0c"\x07\n\x05\x43lear""\n\x0cListStateGet\x12\x12\n\niteratorId\x18\x01 
\x01(\t"\x0e\n\x0cListStatePut"\x1c\n\x0b\x41ppendValue\x12\r\n\x05value\x18\x01 \x01(\x0c"\x0c\n\nAppendList"\\\n\x0eSetHandleState\x12J\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleState"\x1f\n\tTTLConfig\x12\x12\n\ndurationMs\x18\x01 \x01(\x05*K\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\n\n\x06\x43LOSED\x10\x03\x62\x06proto3' # noqa: E501 ) _globals = globals() - _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "StateMessage_pb2", _globals) if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._options = None - _globals["_HANDLESTATE"]._serialized_start = 1978 - _globals["_HANDLESTATE"]._serialized_end = 2053 + DESCRIPTOR._loaded_options = None + _globals["_HANDLESTATE"]._serialized_start = 2694 + _globals["_HANDLESTATE"]._serialized_end = 2769 _globals["_STATEREQUEST"]._serialized_start = 71 _globals["_STATEREQUEST"]._serialized_end = 432 _globals["_STATERESPONSE"]._serialized_start = 434 _globals["_STATERESPONSE"]._serialized_end = 506 _globals["_STATEFULPROCESSORCALL"]._serialized_start = 509 _globals["_STATEFULPROCESSORCALL"]._serialized_end = 902 - _globals["_STATEVARIABLEREQUEST"]._serialized_start = 904 - _globals["_STATEVARIABLEREQUEST"]._serialized_end = 1026 - _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1029 - _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 1253 - _globals["_STATECALLCOMMAND"]._serialized_start = 1255 - _globals["_STATECALLCOMMAND"]._serialized_end = 1380 - _globals["_VALUESTATECALL"]._serialized_start = 1383 - _globals["_VALUESTATECALL"]._serialized_end = 1736 - _globals["_SETIMPLICITKEY"]._serialized_start = 1738 - _globals["_SETIMPLICITKEY"]._serialized_end = 1767 - _globals["_REMOVEIMPLICITKEY"]._serialized_start = 1769 - _globals["_REMOVEIMPLICITKEY"]._serialized_end = 1788 - _globals["_EXISTS"]._serialized_start = 1790 - _globals["_EXISTS"]._serialized_end = 1798 - _globals["_GET"]._serialized_start = 1800 - _globals["_GET"]._serialized_end = 1805 - _globals["_VALUESTATEUPDATE"]._serialized_start = 1807 - _globals["_VALUESTATEUPDATE"]._serialized_end = 1840 - _globals["_CLEAR"]._serialized_start = 1842 - _globals["_CLEAR"]._serialized_end = 1849 - _globals["_SETHANDLESTATE"]._serialized_start = 1851 - _globals["_SETHANDLESTATE"]._serialized_end = 1943 - _globals["_TTLCONFIG"]._serialized_start = 1945 - _globals["_TTLCONFIG"]._serialized_end = 1976 + _globals["_STATEVARIABLEREQUEST"]._serialized_start = 905 + _globals["_STATEVARIABLEREQUEST"]._serialized_end = 1115 + _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1118 + _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 1342 + _globals["_STATECALLCOMMAND"]._serialized_start = 1344 + _globals["_STATECALLCOMMAND"]._serialized_end = 1469 + _globals["_VALUESTATECALL"]._serialized_start = 1472 + _globals["_VALUESTATECALL"]._serialized_end = 1825 + _globals["_LISTSTATECALL"]._serialized_start = 1828 + _globals["_LISTSTATECALL"]._serialized_end = 2356 + _globals["_SETIMPLICITKEY"]._serialized_start = 2358 + _globals["_SETIMPLICITKEY"]._serialized_end = 2387 + _globals["_REMOVEIMPLICITKEY"]._serialized_start = 2389 + _globals["_REMOVEIMPLICITKEY"]._serialized_end = 2408 + _globals["_EXISTS"]._serialized_start = 2410 + _globals["_EXISTS"]._serialized_end = 2418 + _globals["_GET"]._serialized_start = 2420 + _globals["_GET"]._serialized_end = 2425 + 
_globals["_VALUESTATEUPDATE"]._serialized_start = 2427 + _globals["_VALUESTATEUPDATE"]._serialized_end = 2460 + _globals["_CLEAR"]._serialized_start = 2462 + _globals["_CLEAR"]._serialized_end = 2469 + _globals["_LISTSTATEGET"]._serialized_start = 2471 + _globals["_LISTSTATEGET"]._serialized_end = 2505 + _globals["_LISTSTATEPUT"]._serialized_start = 2507 + _globals["_LISTSTATEPUT"]._serialized_end = 2521 + _globals["_APPENDVALUE"]._serialized_start = 2523 + _globals["_APPENDVALUE"]._serialized_end = 2551 + _globals["_APPENDLIST"]._serialized_start = 2553 + _globals["_APPENDLIST"]._serialized_end = 2565 + _globals["_SETHANDLESTATE"]._serialized_start = 2567 + _globals["_SETHANDLESTATE"]._serialized_end = 2659 + _globals["_TTLCONFIG"]._serialized_start = 2661 + _globals["_TTLCONFIG"]._serialized_end = 2692 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/streaming/StateMessage_pb2.pyi b/python/pyspark/sql/streaming/StateMessage_pb2.pyi index 1ab48a27c8f87..b1f5f0f7d2a1e 100644 --- a/python/pyspark/sql/streaming/StateMessage_pb2.pyi +++ b/python/pyspark/sql/streaming/StateMessage_pb2.pyi @@ -13,167 +13,238 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message -from typing import ClassVar, Mapping, Optional, Union +from typing import ( + ClassVar as _ClassVar, + Mapping as _Mapping, + Optional as _Optional, + Union as _Union, +) -CLOSED: HandleState -CREATED: HandleState -DATA_PROCESSED: HandleState DESCRIPTOR: _descriptor.FileDescriptor -INITIALIZED: HandleState -class Clear(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class Exists(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class Get(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class ImplicitGroupingKeyRequest(_message.Message): - __slots__ = ["removeImplicitKey", "setImplicitKey"] - REMOVEIMPLICITKEY_FIELD_NUMBER: ClassVar[int] - SETIMPLICITKEY_FIELD_NUMBER: ClassVar[int] - removeImplicitKey: RemoveImplicitKey - setImplicitKey: SetImplicitKey - def __init__( - self, - setImplicitKey: Optional[Union[SetImplicitKey, Mapping]] = ..., - removeImplicitKey: Optional[Union[RemoveImplicitKey, Mapping]] = ..., - ) -> None: ... - -class RemoveImplicitKey(_message.Message): +class HandleState(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): __slots__ = () - def __init__(self) -> None: ... - -class SetHandleState(_message.Message): - __slots__ = ["state"] - STATE_FIELD_NUMBER: ClassVar[int] - state: HandleState - def __init__(self, state: Optional[Union[HandleState, str]] = ...) -> None: ... - -class SetImplicitKey(_message.Message): - __slots__ = ["key"] - KEY_FIELD_NUMBER: ClassVar[int] - key: bytes - def __init__(self, key: Optional[bytes] = ...) -> None: ... 
+ CREATED: _ClassVar[HandleState] + INITIALIZED: _ClassVar[HandleState] + DATA_PROCESSED: _ClassVar[HandleState] + CLOSED: _ClassVar[HandleState] -class StateCallCommand(_message.Message): - __slots__ = ["schema", "stateName", "ttl"] - SCHEMA_FIELD_NUMBER: ClassVar[int] - STATENAME_FIELD_NUMBER: ClassVar[int] - TTL_FIELD_NUMBER: ClassVar[int] - schema: str - stateName: str - ttl: TTLConfig - def __init__( - self, - stateName: Optional[str] = ..., - schema: Optional[str] = ..., - ttl: Optional[Union[TTLConfig, Mapping]] = ..., - ) -> None: ... +CREATED: HandleState +INITIALIZED: HandleState +DATA_PROCESSED: HandleState +CLOSED: HandleState class StateRequest(_message.Message): - __slots__ = [ - "implicitGroupingKeyRequest", - "stateVariableRequest", - "statefulProcessorCall", + __slots__ = ( "version", - ] - IMPLICITGROUPINGKEYREQUEST_FIELD_NUMBER: ClassVar[int] - STATEFULPROCESSORCALL_FIELD_NUMBER: ClassVar[int] - STATEVARIABLEREQUEST_FIELD_NUMBER: ClassVar[int] - VERSION_FIELD_NUMBER: ClassVar[int] - implicitGroupingKeyRequest: ImplicitGroupingKeyRequest - stateVariableRequest: StateVariableRequest - statefulProcessorCall: StatefulProcessorCall + "statefulProcessorCall", + "stateVariableRequest", + "implicitGroupingKeyRequest", + ) + VERSION_FIELD_NUMBER: _ClassVar[int] + STATEFULPROCESSORCALL_FIELD_NUMBER: _ClassVar[int] + STATEVARIABLEREQUEST_FIELD_NUMBER: _ClassVar[int] + IMPLICITGROUPINGKEYREQUEST_FIELD_NUMBER: _ClassVar[int] version: int + statefulProcessorCall: StatefulProcessorCall + stateVariableRequest: StateVariableRequest + implicitGroupingKeyRequest: ImplicitGroupingKeyRequest def __init__( self, - version: Optional[int] = ..., - statefulProcessorCall: Optional[Union[StatefulProcessorCall, Mapping]] = ..., - stateVariableRequest: Optional[Union[StateVariableRequest, Mapping]] = ..., - implicitGroupingKeyRequest: Optional[Union[ImplicitGroupingKeyRequest, Mapping]] = ..., + version: _Optional[int] = ..., + statefulProcessorCall: _Optional[_Union[StatefulProcessorCall, _Mapping]] = ..., + stateVariableRequest: _Optional[_Union[StateVariableRequest, _Mapping]] = ..., + implicitGroupingKeyRequest: _Optional[_Union[ImplicitGroupingKeyRequest, _Mapping]] = ..., ) -> None: ... class StateResponse(_message.Message): - __slots__ = ["errorMessage", "statusCode", "value"] - ERRORMESSAGE_FIELD_NUMBER: ClassVar[int] - STATUSCODE_FIELD_NUMBER: ClassVar[int] - VALUE_FIELD_NUMBER: ClassVar[int] - errorMessage: str + __slots__ = ("statusCode", "errorMessage", "value") + STATUSCODE_FIELD_NUMBER: _ClassVar[int] + ERRORMESSAGE_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] statusCode: int + errorMessage: str value: bytes def __init__( self, - statusCode: Optional[int] = ..., - errorMessage: Optional[str] = ..., - value: Optional[bytes] = ..., + statusCode: _Optional[int] = ..., + errorMessage: _Optional[str] = ..., + value: _Optional[bytes] = ..., ) -> None: ... -class StateVariableRequest(_message.Message): - __slots__ = ["valueStateCall"] - VALUESTATECALL_FIELD_NUMBER: ClassVar[int] - valueStateCall: ValueStateCall - def __init__(self, valueStateCall: Optional[Union[ValueStateCall, Mapping]] = ...) -> None: ... 
- class StatefulProcessorCall(_message.Message): - __slots__ = ["getListState", "getMapState", "getValueState", "setHandleState"] - GETLISTSTATE_FIELD_NUMBER: ClassVar[int] - GETMAPSTATE_FIELD_NUMBER: ClassVar[int] - GETVALUESTATE_FIELD_NUMBER: ClassVar[int] - SETHANDLESTATE_FIELD_NUMBER: ClassVar[int] + __slots__ = ("setHandleState", "getValueState", "getListState", "getMapState") + SETHANDLESTATE_FIELD_NUMBER: _ClassVar[int] + GETVALUESTATE_FIELD_NUMBER: _ClassVar[int] + GETLISTSTATE_FIELD_NUMBER: _ClassVar[int] + GETMAPSTATE_FIELD_NUMBER: _ClassVar[int] + setHandleState: SetHandleState + getValueState: StateCallCommand getListState: StateCallCommand getMapState: StateCallCommand - getValueState: StateCallCommand - setHandleState: SetHandleState def __init__( self, - setHandleState: Optional[Union[SetHandleState, Mapping]] = ..., - getValueState: Optional[Union[StateCallCommand, Mapping]] = ..., - getListState: Optional[Union[StateCallCommand, Mapping]] = ..., - getMapState: Optional[Union[StateCallCommand, Mapping]] = ..., + setHandleState: _Optional[_Union[SetHandleState, _Mapping]] = ..., + getValueState: _Optional[_Union[StateCallCommand, _Mapping]] = ..., + getListState: _Optional[_Union[StateCallCommand, _Mapping]] = ..., + getMapState: _Optional[_Union[StateCallCommand, _Mapping]] = ..., ) -> None: ... -class TTLConfig(_message.Message): - __slots__ = ["durationMs"] - DURATIONMS_FIELD_NUMBER: ClassVar[int] - durationMs: int - def __init__(self, durationMs: Optional[int] = ...) -> None: ... +class StateVariableRequest(_message.Message): + __slots__ = ("valueStateCall", "listStateCall") + VALUESTATECALL_FIELD_NUMBER: _ClassVar[int] + LISTSTATECALL_FIELD_NUMBER: _ClassVar[int] + valueStateCall: ValueStateCall + listStateCall: ListStateCall + def __init__( + self, + valueStateCall: _Optional[_Union[ValueStateCall, _Mapping]] = ..., + listStateCall: _Optional[_Union[ListStateCall, _Mapping]] = ..., + ) -> None: ... + +class ImplicitGroupingKeyRequest(_message.Message): + __slots__ = ("setImplicitKey", "removeImplicitKey") + SETIMPLICITKEY_FIELD_NUMBER: _ClassVar[int] + REMOVEIMPLICITKEY_FIELD_NUMBER: _ClassVar[int] + setImplicitKey: SetImplicitKey + removeImplicitKey: RemoveImplicitKey + def __init__( + self, + setImplicitKey: _Optional[_Union[SetImplicitKey, _Mapping]] = ..., + removeImplicitKey: _Optional[_Union[RemoveImplicitKey, _Mapping]] = ..., + ) -> None: ... + +class StateCallCommand(_message.Message): + __slots__ = ("stateName", "schema", "ttl") + STATENAME_FIELD_NUMBER: _ClassVar[int] + SCHEMA_FIELD_NUMBER: _ClassVar[int] + TTL_FIELD_NUMBER: _ClassVar[int] + stateName: str + schema: str + ttl: TTLConfig + def __init__( + self, + stateName: _Optional[str] = ..., + schema: _Optional[str] = ..., + ttl: _Optional[_Union[TTLConfig, _Mapping]] = ..., + ) -> None: ... 
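For orientation, a minimal sketch of how a list state variable could be declared with the messages stubbed here; the state name, the DDL schema string, and the TTL value are illustrative assumptions, not taken from the patch:

    import pyspark.sql.streaming.StateMessage_pb2 as stateMessage

    # Hypothetical: declare a list state named "tokens" with a string schema and a 1-hour TTL,
    # using the StateCallCommand and TTLConfig messages defined in this stub file.
    command = stateMessage.StateCallCommand(
        stateName="tokens",
        schema="value STRING",
        ttl=stateMessage.TTLConfig(durationMs=3600 * 1000),
    )
    request = stateMessage.StateRequest(
        statefulProcessorCall=stateMessage.StatefulProcessorCall(getListState=command)
    )
    request_bytes = request.SerializeToString()  # bytes a client would send to the state server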
class ValueStateCall(_message.Message): - __slots__ = ["clear", "exists", "get", "stateName", "valueStateUpdate"] - CLEAR_FIELD_NUMBER: ClassVar[int] - EXISTS_FIELD_NUMBER: ClassVar[int] - GET_FIELD_NUMBER: ClassVar[int] - STATENAME_FIELD_NUMBER: ClassVar[int] - VALUESTATEUPDATE_FIELD_NUMBER: ClassVar[int] - clear: Clear + __slots__ = ("stateName", "exists", "get", "valueStateUpdate", "clear") + STATENAME_FIELD_NUMBER: _ClassVar[int] + EXISTS_FIELD_NUMBER: _ClassVar[int] + GET_FIELD_NUMBER: _ClassVar[int] + VALUESTATEUPDATE_FIELD_NUMBER: _ClassVar[int] + CLEAR_FIELD_NUMBER: _ClassVar[int] + stateName: str exists: Exists get: Get - stateName: str valueStateUpdate: ValueStateUpdate + clear: Clear + def __init__( + self, + stateName: _Optional[str] = ..., + exists: _Optional[_Union[Exists, _Mapping]] = ..., + get: _Optional[_Union[Get, _Mapping]] = ..., + valueStateUpdate: _Optional[_Union[ValueStateUpdate, _Mapping]] = ..., + clear: _Optional[_Union[Clear, _Mapping]] = ..., + ) -> None: ... + +class ListStateCall(_message.Message): + __slots__ = ( + "stateName", + "exists", + "listStateGet", + "listStatePut", + "appendValue", + "appendList", + "clear", + ) + STATENAME_FIELD_NUMBER: _ClassVar[int] + EXISTS_FIELD_NUMBER: _ClassVar[int] + LISTSTATEGET_FIELD_NUMBER: _ClassVar[int] + LISTSTATEPUT_FIELD_NUMBER: _ClassVar[int] + APPENDVALUE_FIELD_NUMBER: _ClassVar[int] + APPENDLIST_FIELD_NUMBER: _ClassVar[int] + CLEAR_FIELD_NUMBER: _ClassVar[int] + stateName: str + exists: Exists + listStateGet: ListStateGet + listStatePut: ListStatePut + appendValue: AppendValue + appendList: AppendList + clear: Clear def __init__( self, - stateName: Optional[str] = ..., - exists: Optional[Union[Exists, Mapping]] = ..., - get: Optional[Union[Get, Mapping]] = ..., - valueStateUpdate: Optional[Union[ValueStateUpdate, Mapping]] = ..., - clear: Optional[Union[Clear, Mapping]] = ..., + stateName: _Optional[str] = ..., + exists: _Optional[_Union[Exists, _Mapping]] = ..., + listStateGet: _Optional[_Union[ListStateGet, _Mapping]] = ..., + listStatePut: _Optional[_Union[ListStatePut, _Mapping]] = ..., + appendValue: _Optional[_Union[AppendValue, _Mapping]] = ..., + appendList: _Optional[_Union[AppendList, _Mapping]] = ..., + clear: _Optional[_Union[Clear, _Mapping]] = ..., ) -> None: ... +class SetImplicitKey(_message.Message): + __slots__ = ("key",) + KEY_FIELD_NUMBER: _ClassVar[int] + key: bytes + def __init__(self, key: _Optional[bytes] = ...) -> None: ... + +class RemoveImplicitKey(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class Exists(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class Get(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + class ValueStateUpdate(_message.Message): - __slots__ = ["value"] - VALUE_FIELD_NUMBER: ClassVar[int] + __slots__ = ("value",) + VALUE_FIELD_NUMBER: _ClassVar[int] value: bytes - def __init__(self, value: Optional[bytes] = ...) -> None: ... + def __init__(self, value: _Optional[bytes] = ...) -> None: ... -class HandleState(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): +class Clear(_message.Message): __slots__ = () + def __init__(self) -> None: ... + +class ListStateGet(_message.Message): + __slots__ = ("iteratorId",) + ITERATORID_FIELD_NUMBER: _ClassVar[int] + iteratorId: str + def __init__(self, iteratorId: _Optional[str] = ...) -> None: ... + +class ListStatePut(_message.Message): + __slots__ = () + def __init__(self) -> None: ... 
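Likewise, a minimal sketch of a list-state round trip built from these messages, mirroring the request pattern used by list_state_client.py later in this patch; the state name and the fabricated reply are placeholders:

    import pyspark.sql.streaming.StateMessage_pb2 as stateMessage

    # Ask whether a (hypothetical) list state named "tokens" currently holds any value.
    list_state_call = stateMessage.ListStateCall(stateName="tokens", exists=stateMessage.Exists())
    request = stateMessage.StateRequest(
        stateVariableRequest=stateMessage.StateVariableRequest(listStateCall=list_state_call)
    )
    request_bytes = request.SerializeToString()

    # Decode a reply into the StateResponse message above; statusCode 0 means the state has a
    # value, while the client below treats 2 as "no value yet" and anything else as an error.
    reply_bytes = stateMessage.StateResponse(statusCode=0).SerializeToString()  # stand-in reply
    response = stateMessage.StateResponse()
    response.ParseFromString(reply_bytes)
    if response.statusCode != 0:
        print(response.errorMessage)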
+ +class AppendValue(_message.Message): + __slots__ = ("value",) + VALUE_FIELD_NUMBER: _ClassVar[int] + value: bytes + def __init__(self, value: _Optional[bytes] = ...) -> None: ... + +class AppendList(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class SetHandleState(_message.Message): + __slots__ = ("state",) + STATE_FIELD_NUMBER: _ClassVar[int] + state: HandleState + def __init__(self, state: _Optional[_Union[HandleState, str]] = ...) -> None: ... + +class TTLConfig(_message.Message): + __slots__ = ("durationMs",) + DURATIONMS_FIELD_NUMBER: _ClassVar[int] + durationMs: int + def __init__(self, durationMs: _Optional[int] = ...) -> None: ... diff --git a/python/pyspark/sql/streaming/list_state_client.py b/python/pyspark/sql/streaming/list_state_client.py new file mode 100644 index 0000000000000..93306eca425eb --- /dev/null +++ b/python/pyspark/sql/streaming/list_state_client.py @@ -0,0 +1,187 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Dict, Iterator, List, Union, cast, Tuple + +from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient +from pyspark.sql.types import StructType, TYPE_CHECKING, _parse_datatype_string +from pyspark.errors import PySparkRuntimeError +import uuid + +if TYPE_CHECKING: + from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike + +__all__ = ["ListStateClient"] + + +class ListStateClient: + def __init__(self, stateful_processor_api_client: StatefulProcessorApiClient) -> None: + self._stateful_processor_api_client = stateful_processor_api_client + # A dictionary to store the mapping between list state name and a tuple of pandas DataFrame + # and the index of the last row that was read. + self.pandas_df_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {} + + def exists(self, state_name: str) -> bool: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + exists_call = stateMessage.Exists() + list_state_call = stateMessage.ListStateCall(stateName=state_name, exists=exists_call) + state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call) + message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) + + self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) + response_message = self._stateful_processor_api_client._receive_proto_message() + status = response_message[0] + if status == 0: + return True + elif status == 2: + # Expect status code is 2 when state variable doesn't have a value. + return False + else: + # TODO(SPARK-49233): Classify user facing errors. 
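+            # Any status other than 0 (the state has a value) or 2 (no value yet) indicates a
+            # server-side failure, so surface the error message carried in the response.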
+ raise PySparkRuntimeError( + f"Error checking value state exists: " f"{response_message[1]}" + ) + + def get(self, state_name: str, iterator_id: str) -> Tuple: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + if iterator_id in self.pandas_df_dict: + # If the state is already in the dictionary, return the next row. + pandas_df, index = self.pandas_df_dict[iterator_id] + else: + # If the state is not in the dictionary, fetch the state from the server. + get_call = stateMessage.ListStateGet(iteratorId=iterator_id) + list_state_call = stateMessage.ListStateCall( + stateName=state_name, listStateGet=get_call + ) + state_variable_request = stateMessage.StateVariableRequest( + listStateCall=list_state_call + ) + message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) + + self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) + response_message = self._stateful_processor_api_client._receive_proto_message() + status = response_message[0] + if status == 0: + iterator = self._stateful_processor_api_client._read_arrow_state() + batch = next(iterator) + pandas_df = batch.to_pandas() + index = 0 + else: + raise StopIteration() + + new_index = index + 1 + if new_index < len(pandas_df): + # Update the index in the dictionary. + self.pandas_df_dict[iterator_id] = (pandas_df, new_index) + else: + # If the index is at the end of the DataFrame, remove the state from the dictionary. + self.pandas_df_dict.pop(iterator_id, None) + pandas_row = pandas_df.iloc[index] + return tuple(pandas_row) + + def append_value(self, state_name: str, schema: Union[StructType, str], value: Tuple) -> None: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + if isinstance(schema, str): + schema = cast(StructType, _parse_datatype_string(schema)) + bytes = self._stateful_processor_api_client._serialize_to_bytes(schema, value) + append_value_call = stateMessage.AppendValue(value=bytes) + list_state_call = stateMessage.ListStateCall( + stateName=state_name, appendValue=append_value_call + ) + state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call) + message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) + + self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) + response_message = self._stateful_processor_api_client._receive_proto_message() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError(f"Error updating value state: " f"{response_message[1]}") + + def append_list( + self, state_name: str, schema: Union[StructType, str], values: List[Tuple] + ) -> None: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + if isinstance(schema, str): + schema = cast(StructType, _parse_datatype_string(schema)) + append_list_call = stateMessage.AppendList() + list_state_call = stateMessage.ListStateCall( + stateName=state_name, appendList=append_list_call + ) + state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call) + message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) + + self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) + + self._stateful_processor_api_client._send_arrow_state(schema, values) + response_message = self._stateful_processor_api_client._receive_proto_message() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. 
+ raise PySparkRuntimeError(f"Error updating value state: " f"{response_message[1]}") + + def put(self, state_name: str, schema: Union[StructType, str], values: List[Tuple]) -> None: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + if isinstance(schema, str): + schema = cast(StructType, _parse_datatype_string(schema)) + put_call = stateMessage.ListStatePut() + list_state_call = stateMessage.ListStateCall(stateName=state_name, listStatePut=put_call) + state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call) + message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) + + self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) + + self._stateful_processor_api_client._send_arrow_state(schema, values) + response_message = self._stateful_processor_api_client._receive_proto_message() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError(f"Error updating value state: " f"{response_message[1]}") + + def clear(self, state_name: str) -> None: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + clear_call = stateMessage.Clear() + list_state_call = stateMessage.ListStateCall(stateName=state_name, clear=clear_call) + state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call) + message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) + + self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) + response_message = self._stateful_processor_api_client._receive_proto_message() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError(f"Error clearing value state: " f"{response_message[1]}") + + +class ListStateIterator: + def __init__(self, list_state_client: ListStateClient, state_name: str): + self.list_state_client = list_state_client + self.state_name = state_name + # Generate a unique identifier for the iterator to make sure iterators from the same + # list state do not interfere with each other. 
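
As the comments above describe, `ListStateClient.get` pages through one Arrow batch at a time, caching a pandas DataFrame plus a row index per iterator id and re-fetching from the state server once a batch is exhausted. Below is a minimal, self-contained sketch of that paging pattern; it is illustrative only (not part of this patch), and `fetch_batch` stands in for the state-server round trip.

```python
# Illustrative sketch only (not part of this PR): the same cache-and-advance scheme
# as ListStateClient.get, with `fetch_batch` standing in for the state-server fetch.
from typing import Callable, Dict, Optional, Tuple
import uuid

import pandas as pd


class PagedRows:
    def __init__(self) -> None:
        # iterator id -> (cached batch, index of the next row to hand out)
        self._cache: Dict[str, Tuple[pd.DataFrame, int]] = {}

    def next_row(
        self, iterator_id: str, fetch_batch: Callable[[], Optional[pd.DataFrame]]
    ) -> Tuple:
        if iterator_id in self._cache:
            df, index = self._cache[iterator_id]
        else:
            df = fetch_batch()  # stands in for the Arrow batch returned by the server
            if df is None:
                raise StopIteration()
            index = 0
        if index + 1 < len(df):
            self._cache[iterator_id] = (df, index + 1)
        else:
            # Batch consumed: drop the cache so the next call fetches again.
            self._cache.pop(iterator_id, None)
        return tuple(df.iloc[index])


batches = iter([pd.DataFrame({"temperature": [120, 20]})])
pager = PagedRows()
it_id = str(uuid.uuid4())  # unique id per iterator, as in ListStateIterator
print(pager.next_row(it_id, lambda: next(batches, None)))  # first cached row (120)
print(pager.next_row(it_id, lambda: next(batches, None)))  # second row (20)
# A third call fetches again, receives None, and raises StopIteration.
```
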
+ self.iterator_id = str(uuid.uuid4()) + + def __iter__(self) -> Iterator[Tuple]: + return self + + def __next__(self) -> Tuple: + return self.list_state_client.get(self.state_name, self.iterator_id) diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py index c50bd3915784b..a7349779dc626 100644 --- a/python/pyspark/sql/streaming/python_streaming_source_runner.py +++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py @@ -193,6 +193,8 @@ def main(infile: IO, outfile: IO) -> None: reader.stop() except BaseException as e: handle_worker_exception(e, outfile) + # ensure that the updates to the socket are flushed + outfile.flush() sys.exit(-1) send_accumulator_updates(outfile) diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py index 9045c81e287cd..6b8de0f8ac4ec 100644 --- a/python/pyspark/sql/streaming/stateful_processor.py +++ b/python/pyspark/sql/streaming/stateful_processor.py @@ -16,12 +16,12 @@ # from abc import ABC, abstractmethod -from typing import Any, TYPE_CHECKING, Iterator, Optional, Union, cast +from typing import Any, List, TYPE_CHECKING, Iterator, Optional, Union, Tuple -from pyspark.sql import Row from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient +from pyspark.sql.streaming.list_state_client import ListStateClient, ListStateIterator from pyspark.sql.streaming.value_state_client import ValueStateClient -from pyspark.sql.types import StructType, _create_row, _parse_datatype_string +from pyspark.sql.types import StructType if TYPE_CHECKING: from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike @@ -50,21 +50,13 @@ def exists(self) -> bool: """ return self._value_state_client.exists(self._state_name) - def get(self) -> Optional[Row]: + def get(self) -> Optional[Tuple]: """ Get the state value if it exists. Returns None if the state variable does not have a value. """ - value = self._value_state_client.get(self._state_name) - if value is None: - return None - schema = self.schema - if isinstance(schema, str): - schema = cast(StructType, _parse_datatype_string(schema)) - # Create the Row using the values and schema fields - row = _create_row(schema.fieldNames(), value) - return row + return self._value_state_client.get(self._state_name) - def update(self, new_value: Any) -> None: + def update(self, new_value: Tuple) -> None: """ Update the value of the state. """ @@ -77,6 +69,58 @@ def clear(self) -> None: self._value_state_client.clear(self._state_name) +class ListState: + """ + Class used for arbitrary stateful operations with transformWithState to capture list value + state. + + .. versionadded:: 4.0.0 + """ + + def __init__( + self, list_state_client: ListStateClient, state_name: str, schema: Union[StructType, str] + ) -> None: + self._list_state_client = list_state_client + self._state_name = state_name + self.schema = schema + + def exists(self) -> bool: + """ + Whether list state exists or not. + """ + return self._list_state_client.exists(self._state_name) + + def get(self) -> Iterator[Tuple]: + """ + Get list state with an iterator. + """ + return ListStateIterator(self._list_state_client, self._state_name) + + def put(self, new_state: List[Tuple]) -> None: + """ + Update the values of the list state. 
+ """ + self._list_state_client.put(self._state_name, self.schema, new_state) + + def append_value(self, new_state: Tuple) -> None: + """ + Append a new value to the list state. + """ + self._list_state_client.append_value(self._state_name, self.schema, new_state) + + def append_list(self, new_state: List[Tuple]) -> None: + """ + Append a list of new values to the list state. + """ + self._list_state_client.append_list(self._state_name, self.schema, new_state) + + def clear(self) -> None: + """ + Remove this state. + """ + self._list_state_client.clear(self._state_name) + + class StatefulProcessorHandle: """ Represents the operation handle provided to the stateful processor used in transformWithState @@ -112,6 +156,30 @@ def getValueState( self.stateful_processor_api_client.get_value_state(state_name, schema, ttl_duration_ms) return ValueState(ValueStateClient(self.stateful_processor_api_client), state_name, schema) + def getListState( + self, state_name: str, schema: Union[StructType, str], ttl_duration_ms: Optional[int] = None + ) -> ListState: + """ + Function to create new or return existing single value state variable of given type. + The user must ensure to call this function only within the `init()` method of the + :class:`StatefulProcessor`. + + Parameters + ---------- + state_name : str + name of the state variable + schema : :class:`pyspark.sql.types.DataType` or str + The schema of the state variable. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + ttlDurationMs: int + Time to live duration of the state in milliseconds. State values will not be returned + past ttlDuration and will be eventually removed from the state store. Any state update + resets the expiration time to current processing time plus ttlDuration. + If ttl is not specified the state will never expire. 
+ """ + self.stateful_processor_api_client.get_list_state(state_name, schema, ttl_duration_ms) + return ListState(ListStateClient(self.stateful_processor_api_client), state_name, schema) + class StatefulProcessor(ABC): """ diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py index 9703aa17d3474..449d5a2ad55dc 100644 --- a/python/pyspark/sql/streaming/stateful_processor_api_client.py +++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py @@ -17,10 +17,16 @@ from enum import Enum import os import socket -from typing import Any, Union, Optional, cast, Tuple +from typing import Any, List, Union, Optional, cast, Tuple from pyspark.serializers import write_int, read_int, UTF8Deserializer -from pyspark.sql.types import StructType, _parse_datatype_string, Row +from pyspark.sql.pandas.serializers import ArrowStreamSerializer +from pyspark.sql.types import ( + StructType, + _parse_datatype_string, + Row, +) +from pyspark.sql.pandas.types import convert_pandas_using_numpy_type from pyspark.sql.utils import has_numpy from pyspark.serializers import CPickleSerializer from pyspark.errors import PySparkRuntimeError @@ -46,6 +52,7 @@ def __init__(self, state_server_port: int, key_schema: StructType) -> None: self.handle_state = StatefulProcessorHandleState.CREATED self.utf8_deserializer = UTF8Deserializer() self.pickleSer = CPickleSerializer() + self.serializer = ArrowStreamSerializer() def set_handle_state(self, state: StatefulProcessorHandleState) -> None: import pyspark.sql.streaming.StateMessage_pb2 as stateMessage @@ -124,6 +131,29 @@ def get_value_state( # TODO(SPARK-49233): Classify user facing errors. raise PySparkRuntimeError(f"Error initializing value state: " f"{response_message[1]}") + def get_list_state( + self, state_name: str, schema: Union[StructType, str], ttl_duration_ms: Optional[int] + ) -> None: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + if isinstance(schema, str): + schema = cast(StructType, _parse_datatype_string(schema)) + + state_call_command = stateMessage.StateCallCommand() + state_call_command.stateName = state_name + state_call_command.schema = schema.json() + if ttl_duration_ms is not None: + state_call_command.ttl.durationMs = ttl_duration_ms + call = stateMessage.StatefulProcessorCall(getListState=state_call_command) + message = stateMessage.StateRequest(statefulProcessorCall=call) + + self._send_proto_message(message.SerializeToString()) + response_message = self._receive_proto_message() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError(f"Error initializing value state: " f"{response_message[1]}") + def _send_proto_message(self, message: bytes) -> None: # Writing zero here to indicate message version. This allows us to evolve the message # format or even changing the message protocol in the future. 
@@ -168,3 +198,18 @@ def _serialize_to_bytes(self, schema: StructType, data: Tuple) -> bytes: def _deserialize_from_bytes(self, value: bytes) -> Any: return self.pickleSer.loads(value) + + def _send_arrow_state(self, schema: StructType, state: List[Tuple]) -> None: + import pyarrow as pa + import pandas as pd + + column_names = [field.name for field in schema.fields] + pandas_df = convert_pandas_using_numpy_type( + pd.DataFrame(state, columns=column_names), schema + ) + batch = pa.RecordBatch.from_pandas(pandas_df) + self.serializer.dump_stream(iter([batch]), self.sockfile) + self.sockfile.flush() + + def _read_arrow_state(self) -> Any: + return self.serializer.load_stream(self.sockfile) diff --git a/python/pyspark/sql/streaming/value_state_client.py b/python/pyspark/sql/streaming/value_state_client.py index e902f70cb40a5..3fe32bcc5235c 100644 --- a/python/pyspark/sql/streaming/value_state_client.py +++ b/python/pyspark/sql/streaming/value_state_client.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Any, Union, cast, Tuple +from typing import Union, cast, Tuple, Optional from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient from pyspark.sql.types import StructType, _parse_datatype_string @@ -49,7 +49,7 @@ def exists(self, state_name: str) -> bool: f"Error checking value state exists: " f"{response_message[1]}" ) - def get(self, state_name: str) -> Any: + def get(self, state_name: str) -> Optional[Tuple]: import pyspark.sql.streaming.StateMessage_pb2 as stateMessage get_call = stateMessage.Get() @@ -63,8 +63,8 @@ def get(self, state_name: str) -> Any: if status == 0: if len(response_message[2]) == 0: return None - row = self._stateful_processor_api_client._deserialize_from_bytes(response_message[2]) - return row + data = self._stateful_processor_api_client._deserialize_from_bytes(response_message[2]) + return tuple(data) else: # TODO(SPARK-49233): Classify user facing errors. 
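
The `_send_arrow_state` / `_read_arrow_state` helpers added above move list-state rows between Python and the JVM as Arrow record batches over the worker socket; the `ArrowStreamSerializer` they use speaks the Arrow IPC streaming format. The snippet below is a standalone sketch of that round trip with an in-memory buffer in place of the socket file; it assumes `pyarrow` and `pandas` are installed and is not code from this patch.

```python
# Standalone sketch (not from this PR): the Arrow IPC streaming round trip behind
# _send_arrow_state/_read_arrow_state, with a BytesIO buffer in place of the socket.
import io

import pandas as pd
import pyarrow as pa

rows = [(120,), (20,)]  # list-state values, one tuple per row
pdf = pd.DataFrame(rows, columns=["temperature"])
batch = pa.RecordBatch.from_pandas(pdf)

sink = io.BytesIO()
with pa.ipc.new_stream(sink, batch.schema) as writer:  # Arrow IPC stream writer
    writer.write_batch(batch)

source = io.BytesIO(sink.getvalue())
with pa.ipc.open_stream(source) as reader:
    for received in reader:  # _read_arrow_state likewise yields record batches lazily
        print(received.to_pandas())
```
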
raise PySparkRuntimeError(f"Error getting value state: " f"{response_message[1]}") diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py index c886ff36d776f..0857591c306ae 100644 --- a/python/pyspark/sql/tests/connect/client/test_artifact.py +++ b/python/pyspark/sql/tests/connect/client/test_artifact.py @@ -21,7 +21,6 @@ import os from pyspark.util import is_remote_only -from pyspark.errors.exceptions.connect import SparkConnectGrpcException from pyspark.sql import SparkSession from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect from pyspark.testing.utils import SPARK_HOME @@ -30,6 +29,7 @@ if should_test_connect: from pyspark.sql.connect.client.artifact import ArtifactManager from pyspark.sql.connect.client import DefaultChannelBuilder + from pyspark.errors.exceptions.connect import SparkConnectGrpcException class ArtifactTestsMixin: diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 5deb73a0ccf90..741d6b9c1104e 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -408,8 +408,8 @@ def not_found(): def checks(): self.assertEqual(1, stub.execute_calls) self.assertEqual(1, stub.attach_calls) - self.assertEqual(0, stub.release_calls) - self.assertEqual(0, stub.release_until_calls) + self.assertEqual(1, stub.release_calls) + self.assertEqual(1, stub.release_until_calls) eventually(timeout=1, catch_assertions=True)(checks)() diff --git a/python/pyspark/sql/tests/connect/test_parity_frame_plot.py b/python/pyspark/sql/tests/connect/test_parity_frame_plot.py new file mode 100644 index 0000000000000..c69e438bf7eb0 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_frame_plot.py @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from pyspark.testing.connectutils import ReusedConnectTestCase +from pyspark.sql.tests.plot.test_frame_plot import DataFramePlotTestsMixin + + +class FramePlotParityTests(DataFramePlotTestsMixin, ReusedConnectTestCase): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.test_parity_frame_plot import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py b/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py new file mode 100644 index 0000000000000..78508fe533379 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.testing.connectutils import ReusedConnectTestCase +from pyspark.sql.tests.plot.test_frame_plot_plotly import DataFramePlotPlotlyTestsMixin + + +class FramePlotPlotlyParityTests(DataFramePlotPlotlyTestsMixin, ReusedConnectTestCase): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.test_parity_frame_plot_plotly import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/connect/test_parity_readwriter.py b/python/pyspark/sql/tests/connect/test_parity_readwriter.py index 46333b555c351..f83f3edbfa787 100644 --- a/python/pyspark/sql/tests/connect/test_parity_readwriter.py +++ b/python/pyspark/sql/tests/connect/test_parity_readwriter.py @@ -33,6 +33,7 @@ def test_api(self): def test_partitioning_functions(self): self.check_partitioning_functions(DataFrameWriterV2) + self.partitioning_functions_user_error() if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/connect/test_parity_sql.py b/python/pyspark/sql/tests/connect/test_parity_sql.py new file mode 100644 index 0000000000000..4c6b11c60cbe9 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_sql.py @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.sql.tests.test_sql import SQLTestsMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class SQLParityTests(SQLTestsMixin, ReusedConnectTestCase): + pass + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.test_parity_sql import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 8ad24704de3a4..01cd441941d93 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -59,6 +59,7 @@ def conf(cls): "spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider", ) + cfg.set("spark.sql.execution.arrow.transformWithStateInPandas.maxRecordsPerBatch", "2") return cfg def _prepare_input_data(self, input_path, col1, col2): @@ -211,6 +212,27 @@ def test_transform_with_state_in_pandas_query_restarts(self): Row(id="1", countAsString="2"), } + def test_transform_with_state_in_pandas_list_state(self): + def check_results(batch_df, _): + assert set(batch_df.sort("id").collect()) == { + Row(id="0", countAsString="2"), + Row(id="1", countAsString="2"), + } + + self._test_transform_with_state_in_pandas_basic(ListStateProcessor(), check_results, True) + + # test list state with ttl has the same behavior as list state when state doesn't expire. + def test_transform_with_state_in_pandas_list_state_large_ttl(self): + def check_results(batch_df, _): + assert set(batch_df.sort("id").collect()) == { + Row(id="0", countAsString="2"), + Row(id="1", countAsString="2"), + } + + self._test_transform_with_state_in_pandas_basic( + ListStateLargeTTLProcessor(), check_results, True, "processingTime" + ) + # test value state with ttl has the same behavior as value state when # state doesn't expire. def test_value_state_ttl_basic(self): @@ -238,8 +260,10 @@ def check_results(batch_df, batch_id): [ Row(id="ttl-count-0", count=1), Row(id="count-0", count=1), + Row(id="ttl-list-state-count-0", count=1), Row(id="ttl-count-1", count=1), Row(id="count-1", count=1), + Row(id="ttl-list-state-count-1", count=1), ], ) elif batch_id == 1: @@ -248,21 +272,29 @@ def check_results(batch_df, batch_id): [ Row(id="ttl-count-0", count=2), Row(id="count-0", count=2), + Row(id="ttl-list-state-count-0", count=3), Row(id="ttl-count-1", count=2), Row(id="count-1", count=2), + Row(id="ttl-list-state-count-1", count=3), ], ) elif batch_id == 2: # ttl-count-0 expire and restart from count 0. - # ttl-count-1 get reset in batch 1 and keep the state + # The TTL for value state ttl_count_state gets reset in batch 1 because of the + # update operation and ttl-count-1 keeps the state. + # ttl-list-state-count-0 expire and restart from count 0. 
+ # The TTL for list state ttl_list_state gets reset in batch 1 because of the + # put operation and ttl-list-state-count-1 keeps the state. # non-ttl state never expires assertDataFrameEqual( batch_df, [ Row(id="ttl-count-0", count=1), Row(id="count-0", count=3), + Row(id="ttl-list-state-count-0", count=1), Row(id="ttl-count-1", count=3), Row(id="count-1", count=3), + Row(id="ttl-list-state-count-1", count=7), ], ) if batch_id == 0 or batch_id == 1: @@ -352,25 +384,38 @@ def init(self, handle: StatefulProcessorHandle) -> None: state_schema = StructType([StructField("value", IntegerType(), True)]) self.ttl_count_state = handle.getValueState("ttl-state", state_schema, 10000) self.count_state = handle.getValueState("state", state_schema) + self.ttl_list_state = handle.getListState("ttl-list-state", state_schema, 10000) def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]: count = 0 ttl_count = 0 + ttl_list_state_count = 0 id = key[0] if self.count_state.exists(): count = self.count_state.get()[0] if self.ttl_count_state.exists(): ttl_count = self.ttl_count_state.get()[0] + if self.ttl_list_state.exists(): + iter = self.ttl_list_state.get() + for s in iter: + ttl_list_state_count += s[0] for pdf in rows: pdf_count = pdf.count().get("temperature") count += pdf_count ttl_count += pdf_count + ttl_list_state_count += pdf_count self.count_state.update((count,)) # skip updating state for the 2nd batch so that ttl state expire if not (ttl_count == 2 and id == "0"): self.ttl_count_state.update((ttl_count,)) - yield pd.DataFrame({"id": [f"ttl-count-{id}", f"count-{id}"], "count": [ttl_count, count]}) + self.ttl_list_state.put([(ttl_list_state_count,), (ttl_list_state_count,)]) + yield pd.DataFrame( + { + "id": [f"ttl-count-{id}", f"count-{id}", f"ttl-list-state-count-{id}"], + "count": [ttl_count, count, ttl_list_state_count], + } + ) def close(self) -> None: pass @@ -394,6 +439,68 @@ def close(self) -> None: pass +class ListStateProcessor(StatefulProcessor): + # Dict to store the expected results. The key represents the grouping key string, and the value + # is a dictionary of pandas dataframe index -> expected temperature value. Since we set + # maxRecordsPerBatch to 2, we expect the pandas dataframe dictionary to have 2 entries. + dict = {0: 120, 1: 20} + + def init(self, handle: StatefulProcessorHandle) -> None: + state_schema = StructType([StructField("temperature", IntegerType(), True)]) + self.list_state1 = handle.getListState("listState1", state_schema) + self.list_state2 = handle.getListState("listState2", state_schema) + + def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]: + count = 0 + for pdf in rows: + list_state_rows = [(120,), (20,)] + self.list_state1.put(list_state_rows) + self.list_state2.put(list_state_rows) + self.list_state1.append_value((111,)) + self.list_state2.append_value((222,)) + self.list_state1.append_list(list_state_rows) + self.list_state2.append_list(list_state_rows) + pdf_count = pdf.count() + count += pdf_count.get("temperature") + iter1 = self.list_state1.get() + iter2 = self.list_state2.get() + # Mixing the iterator to test it we can resume from the correct point + assert next(iter1)[0] == self.dict[0] + assert next(iter2)[0] == self.dict[0] + assert next(iter1)[0] == self.dict[1] + assert next(iter2)[0] == self.dict[1] + # Get another iterator for list_state1 to test if the 2 iterators (iter1 and iter3) don't + # interfere with each other. 
+ iter3 = self.list_state1.get() + assert next(iter3)[0] == self.dict[0] + assert next(iter3)[0] == self.dict[1] + # the second arrow batch should contain the appended value 111 for list_state1 and + # 222 for list_state2 + assert next(iter1)[0] == 111 + assert next(iter2)[0] == 222 + assert next(iter3)[0] == 111 + # since we put another 2 rows after 111/222, check them here + assert next(iter1)[0] == self.dict[0] + assert next(iter2)[0] == self.dict[0] + assert next(iter3)[0] == self.dict[0] + assert next(iter1)[0] == self.dict[1] + assert next(iter2)[0] == self.dict[1] + assert next(iter3)[0] == self.dict[1] + yield pd.DataFrame({"id": key, "countAsString": str(count)}) + + def close(self) -> None: + pass + + +# A stateful processor that inherit all behavior of ListStateProcessor except that it use +# ttl state with a large timeout. +class ListStateLargeTTLProcessor(ListStateProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + state_schema = StructType([StructField("temperature", IntegerType(), True)]) + self.list_state1 = handle.getListState("listState1", state_schema, 30000) + self.list_state2 = handle.getListState("listState2", state_schema, 30000) + + class TransformWithStateInPandasTests(TransformWithStateInPandasTestsMixin, ReusedSQLTestCase): pass diff --git a/python/pyspark/sql/tests/plot/__init__.py b/python/pyspark/sql/tests/plot/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/sql/tests/plot/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/sql/tests/plot/test_frame_plot.py b/python/pyspark/sql/tests/plot/test_frame_plot.py new file mode 100644 index 0000000000000..2a6971e896292 --- /dev/null +++ b/python/pyspark/sql/tests/plot/test_frame_plot.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest +from pyspark.errors import PySparkValueError +from pyspark.sql import Row +from pyspark.sql.plot import PySparkSampledPlotBase, PySparkTopNPlotBase +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message + + +@unittest.skipIf(not have_plotly, plotly_requirement_message) +class DataFramePlotTestsMixin: + def test_backend(self): + accessor = self.spark.range(2).plot + backend = accessor._get_plot_backend() + self.assertEqual(backend.__name__, "pyspark.sql.plot.plotly") + + with self.assertRaises(PySparkValueError) as pe: + accessor._get_plot_backend("matplotlib") + + self.check_error( + exception=pe.exception, + errorClass="UNSUPPORTED_PLOT_BACKEND", + messageParameters={"backend": "matplotlib", "supported_backends": "plotly"}, + ) + + def test_topn_max_rows(self): + with self.sql_conf({"spark.sql.pyspark.plotting.max_rows": "1000"}): + self.spark.conf.set("spark.sql.pyspark.plotting.max_rows", "1000") + sdf = self.spark.range(2500) + pdf = PySparkTopNPlotBase().get_top_n(sdf) + self.assertEqual(len(pdf), 1000) + + def test_sampled_plot_with_max_rows(self): + data = [Row(a=i, b=i + 1, c=i + 2, d=i + 3) for i in range(2000)] + sdf = self.spark.createDataFrame(data) + pdf = PySparkSampledPlotBase().get_sampled(sdf) + self.assertEqual(round(len(pdf) / 2000, 1), 0.5) + + +class DataFramePlotTests(DataFramePlotTestsMixin, ReusedSQLTestCase): + pass + + +if __name__ == "__main__": + from pyspark.sql.tests.plot.test_frame_plot import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py new file mode 100644 index 0000000000000..d870cdbf9959b --- /dev/null +++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py @@ -0,0 +1,392 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
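
The plot tests above and below exercise the new `DataFrame.plot` accessor with the plotly backend. A hedged usage sketch follows; it requires a running SparkSession plus the optional `plotly` dependency, and the sample data mirrors the fixtures used in these tests.

```python
# Usage sketch of the DataFrame.plot accessor exercised by the tests; requires plotly.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame(
    [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)],
    ["category", "int_val", "float_val"],
)

# Both call styles used in the tests: plot(kind=...) and the named accessor.
fig1 = sdf.plot(kind="line", x="category", y="int_val")
fig2 = sdf.plot.bar(x="category", y=["int_val", "float_val"])
fig2.show()  # renders the plotly figure

# Plotting is bounded server-side, e.g. via the max_rows conf checked above.
spark.conf.set("spark.sql.pyspark.plotting.max_rows", "1000")
```
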
+# + +import unittest +from datetime import datetime + +import pyspark.sql.plot # noqa: F401 +from pyspark.errors import PySparkTypeError, PySparkValueError +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message + + +@unittest.skipIf(not have_plotly, plotly_requirement_message) +class DataFramePlotPlotlyTestsMixin: + @property + def sdf(self): + data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)] + columns = ["category", "int_val", "float_val"] + return self.spark.createDataFrame(data, columns) + + @property + def sdf2(self): + data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)] + columns = ["length", "width", "species"] + return self.spark.createDataFrame(data, columns) + + @property + def sdf3(self): + data = [ + (3, 5, 20, datetime(2018, 1, 31)), + (2, 5, 42, datetime(2018, 2, 28)), + (3, 6, 28, datetime(2018, 3, 31)), + (9, 12, 62, datetime(2018, 4, 30)), + ] + columns = ["sales", "signups", "visits", "date"] + return self.spark.createDataFrame(data, columns) + + @property + def sdf4(self): + data = [ + ("A", 50, 55), + ("B", 55, 60), + ("C", 60, 65), + ("D", 65, 70), + ("E", 70, 75), + # outliers + ("F", 10, 15), + ("G", 85, 90), + ("H", 5, 150), + ] + columns = ["student", "math_score", "english_score"] + return self.spark.createDataFrame(data, columns) + + def _check_fig_data(self, fig_data, **kwargs): + for key, expected_value in kwargs.items(): + if key in ["x", "y", "labels", "values"]: + converted_values = [v.item() if hasattr(v, "item") else v for v in fig_data[key]] + self.assertEqual(converted_values, expected_value) + else: + self.assertEqual(fig_data[key], expected_value) + + def test_line_plot(self): + # single column as vertical axis + fig = self.sdf.plot(kind="line", x="category", y="int_val") + expected_fig_data = { + "mode": "lines", + "name": "", + "orientation": "v", + "x": ["A", "B", "C"], + "xaxis": "x", + "y": [10, 30, 20], + "yaxis": "y", + "type": "scatter", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + + # multiple columns as vertical axis + fig = self.sdf.plot.line(x="category", y=["int_val", "float_val"]) + expected_fig_data = { + "mode": "lines", + "name": "int_val", + "orientation": "v", + "x": ["A", "B", "C"], + "xaxis": "x", + "y": [10, 30, 20], + "yaxis": "y", + "type": "scatter", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + expected_fig_data = { + "mode": "lines", + "name": "float_val", + "orientation": "v", + "x": ["A", "B", "C"], + "xaxis": "x", + "y": [1.5, 2.5, 3.5], + "yaxis": "y", + "type": "scatter", + } + self._check_fig_data(fig["data"][1], **expected_fig_data) + + def test_bar_plot(self): + # single column as vertical axis + fig = self.sdf.plot(kind="bar", x="category", y="int_val") + expected_fig_data = { + "name": "", + "orientation": "v", + "x": ["A", "B", "C"], + "xaxis": "x", + "y": [10, 30, 20], + "yaxis": "y", + "type": "bar", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + + # multiple columns as vertical axis + fig = self.sdf.plot.bar(x="category", y=["int_val", "float_val"]) + expected_fig_data = { + "name": "int_val", + "orientation": "v", + "x": ["A", "B", "C"], + "xaxis": "x", + "y": [10, 30, 20], + "yaxis": "y", + "type": "bar", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + expected_fig_data = { + "name": "float_val", + "orientation": "v", + "x": ["A", "B", "C"], + "xaxis": "x", + "y": [1.5, 2.5, 3.5], + "yaxis": "y", + "type": "bar", + } + self._check_fig_data(fig["data"][1], 
**expected_fig_data) + + def test_barh_plot(self): + # single column as vertical axis + fig = self.sdf.plot(kind="barh", x="category", y="int_val") + expected_fig_data = { + "name": "", + "orientation": "h", + "x": ["A", "B", "C"], + "xaxis": "x", + "y": [10, 30, 20], + "yaxis": "y", + "type": "bar", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + + # multiple columns as vertical axis + fig = self.sdf.plot.barh(x="category", y=["int_val", "float_val"]) + expected_fig_data = { + "name": "int_val", + "orientation": "h", + "x": ["A", "B", "C"], + "xaxis": "x", + "y": [10, 30, 20], + "yaxis": "y", + "type": "bar", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + expected_fig_data = { + "name": "float_val", + "orientation": "h", + "x": ["A", "B", "C"], + "xaxis": "x", + "y": [1.5, 2.5, 3.5], + "yaxis": "y", + "type": "bar", + } + self._check_fig_data(fig["data"][1], **expected_fig_data) + + # multiple columns as horizontal axis + fig = self.sdf.plot.barh(x=["int_val", "float_val"], y="category") + expected_fig_data = { + "name": "int_val", + "orientation": "h", + "y": ["A", "B", "C"], + "xaxis": "x", + "x": [10, 30, 20], + "yaxis": "y", + "type": "bar", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + expected_fig_data = { + "name": "float_val", + "orientation": "h", + "y": ["A", "B", "C"], + "xaxis": "x", + "x": [1.5, 2.5, 3.5], + "yaxis": "y", + "type": "bar", + } + self._check_fig_data(fig["data"][1], **expected_fig_data) + + def test_scatter_plot(self): + fig = self.sdf2.plot(kind="scatter", x="length", y="width") + expected_fig_data = { + "name": "", + "orientation": "v", + "x": [5.1, 4.9, 7.0, 6.4, 5.9], + "xaxis": "x", + "y": [3.5, 3.0, 3.2, 3.2, 3.0], + "yaxis": "y", + "type": "scatter", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + expected_fig_data = { + "name": "", + "orientation": "v", + "y": [5.1, 4.9, 7.0, 6.4, 5.9], + "xaxis": "x", + "x": [3.5, 3.0, 3.2, 3.2, 3.0], + "yaxis": "y", + "type": "scatter", + } + fig = self.sdf2.plot.scatter(x="width", y="length") + self._check_fig_data(fig["data"][0], **expected_fig_data) + + def test_area_plot(self): + # single column as vertical axis + fig = self.sdf3.plot(kind="area", x="date", y="sales") + expected_x = [ + datetime(2018, 1, 31, 0, 0), + datetime(2018, 2, 28, 0, 0), + datetime(2018, 3, 31, 0, 0), + datetime(2018, 4, 30, 0, 0), + ] + expected_fig_data = { + "name": "", + "orientation": "v", + "x": expected_x, + "xaxis": "x", + "y": [3, 2, 3, 9], + "yaxis": "y", + "mode": "lines", + "type": "scatter", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + + # multiple columns as vertical axis + fig = self.sdf3.plot.area(x="date", y=["sales", "signups", "visits"]) + expected_fig_data = { + "name": "sales", + "orientation": "v", + "x": expected_x, + "xaxis": "x", + "y": [3, 2, 3, 9], + "yaxis": "y", + "mode": "lines", + "type": "scatter", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + expected_fig_data = { + "name": "signups", + "orientation": "v", + "x": expected_x, + "xaxis": "x", + "y": [5, 5, 6, 12], + "yaxis": "y", + "mode": "lines", + "type": "scatter", + } + self._check_fig_data(fig["data"][1], **expected_fig_data) + expected_fig_data = { + "name": "visits", + "orientation": "v", + "x": expected_x, + "xaxis": "x", + "y": [20, 42, 28, 62], + "yaxis": "y", + "mode": "lines", + "type": "scatter", + } + self._check_fig_data(fig["data"][2], **expected_fig_data) + + def test_pie_plot(self): + fig = self.sdf3.plot(kind="pie", x="date", 
y="sales") + expected_x = [ + datetime(2018, 1, 31, 0, 0), + datetime(2018, 2, 28, 0, 0), + datetime(2018, 3, 31, 0, 0), + datetime(2018, 4, 30, 0, 0), + ] + expected_fig_data = { + "name": "", + "labels": expected_x, + "values": [3, 2, 3, 9], + "type": "pie", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + + # y is not a numerical column + with self.assertRaises(PySparkTypeError) as pe: + self.sdf.plot.pie(x="int_val", y="category") + + self.check_error( + exception=pe.exception, + errorClass="PLOT_NOT_NUMERIC_COLUMN", + messageParameters={"arg_name": "y", "arg_type": "StringType()"}, + ) + + def test_box_plot(self): + fig = self.sdf4.plot.box(column="math_score") + expected_fig_data = { + "boxpoints": "suspectedoutliers", + "lowerfence": (5,), + "mean": (50.0,), + "median": (55,), + "name": "math_score", + "notched": False, + "q1": (10,), + "q3": (65,), + "upperfence": (85,), + "x": [0], + "type": "box", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + + fig = self.sdf4.plot(kind="box", column=["math_score", "english_score"]) + self._check_fig_data(fig["data"][0], **expected_fig_data) + expected_fig_data = { + "boxpoints": "suspectedoutliers", + "lowerfence": (55,), + "mean": (72.5,), + "median": (65,), + "name": "english_score", + "notched": False, + "q1": (55,), + "q3": (75,), + "upperfence": (90,), + "x": [1], + "y": [[150, 15]], + "type": "box", + } + self._check_fig_data(fig["data"][1], **expected_fig_data) + with self.assertRaises(PySparkValueError) as pe: + self.sdf4.plot.box(column="math_score", boxpoints=True) + self.check_error( + exception=pe.exception, + errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM", + messageParameters={ + "backend": "plotly", + "param": "boxpoints", + "value": "True", + "supported_values": ", ".join(["suspectedoutliers", "False"]), + }, + ) + with self.assertRaises(PySparkValueError) as pe: + self.sdf4.plot.box(column="math_score", notched=True) + self.check_error( + exception=pe.exception, + errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM", + messageParameters={ + "backend": "plotly", + "param": "notched", + "value": "True", + "supported_values": ", ".join(["False"]), + }, + ) + + +class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase): + pass + + +if __name__ == "__main__": + from pyspark.sql.tests.plot.test_frame_plot_plotly import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index c3ae62e64cc30..51f62f56a7c54 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -381,7 +381,12 @@ def verify(test_listener): .start() ) self.assertTrue(q.isActive) - q.awaitTermination(10) + wait_count = 0 + while progress_event is None or progress_event.progress.batchId == 0: + q.awaitTermination(0.5) + wait_count = wait_count + 1 + if wait_count > 100: + self.fail("Not getting progress event after 50 seconds") q.stop() # Make sure all events are empty diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 2bd66baaa2bfe..5f1991973d27d 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -18,11 +18,14 @@ from enum import Enum 
from itertools import chain +import datetime +import unittest + from pyspark.sql import Column, Row from pyspark.sql import functions as sf from pyspark.sql.types import StructType, StructField, IntegerType, LongType from pyspark.errors import AnalysisException, PySparkTypeError, PySparkValueError -from pyspark.testing.sqlutils import ReusedSQLTestCase +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, pandas_requirement_message class ColumnTestsMixin: @@ -280,6 +283,104 @@ def test_expr_str_representation(self): when_cond = sf.when(expression, sf.lit(None)) self.assertEqual(str(when_cond), "Column<'CASE WHEN foo THEN NULL END'>") + def test_col_field_ops_representation(self): + # SPARK-49894: Test string representation of columns + c = sf.col("c") + + # getField + self.assertEqual(str(c.x), "Column<'c['x']'>") + self.assertEqual(str(c.x.y), "Column<'c['x']['y']'>") + self.assertEqual(str(c.x.y.z), "Column<'c['x']['y']['z']'>") + + self.assertEqual(str(c["x"]), "Column<'c['x']'>") + self.assertEqual(str(c["x"]["y"]), "Column<'c['x']['y']'>") + self.assertEqual(str(c["x"]["y"]["z"]), "Column<'c['x']['y']['z']'>") + + self.assertEqual(str(c.getField("x")), "Column<'c['x']'>") + self.assertEqual( + str(c.getField("x").getField("y")), + "Column<'c['x']['y']'>", + ) + self.assertEqual( + str(c.getField("x").getField("y").getField("z")), + "Column<'c['x']['y']['z']'>", + ) + + self.assertEqual(str(c.getItem("x")), "Column<'c['x']'>") + self.assertEqual( + str(c.getItem("x").getItem("y")), + "Column<'c['x']['y']'>", + ) + self.assertEqual( + str(c.getItem("x").getItem("y").getItem("z")), + "Column<'c['x']['y']['z']'>", + ) + + self.assertEqual( + str(c.x["y"].getItem("z")), + "Column<'c['x']['y']['z']'>", + ) + self.assertEqual( + str(c["x"].getField("y").getItem("z")), + "Column<'c['x']['y']['z']'>", + ) + self.assertEqual( + str(c.getField("x").getItem("y").z), + "Column<'c['x']['y']['z']'>", + ) + self.assertEqual( + str(c["x"].y.getField("z")), + "Column<'c['x']['y']['z']'>", + ) + + # WithField + self.assertEqual( + str(c.withField("x", sf.col("y"))), + "Column<'update_field(c, x, y)'>", + ) + self.assertEqual( + str(c.withField("x", sf.col("y")).withField("x", sf.col("z"))), + "Column<'update_field(update_field(c, x, y), x, z)'>", + ) + + # DropFields + self.assertEqual(str(c.dropFields("x")), "Column<'drop_field(c, x)'>") + self.assertEqual( + str(c.dropFields("x", "y")), + "Column<'drop_field(drop_field(c, x), y)'>", + ) + self.assertEqual( + str(c.dropFields("x", "y", "z")), + "Column<'drop_field(drop_field(drop_field(c, x), y), z)'>", + ) + + def test_lit_time_representation(self): + dt = datetime.date(2021, 3, 4) + self.assertEqual(str(sf.lit(dt)), "Column<'2021-03-04'>") + + ts = datetime.datetime(2021, 3, 4, 12, 34, 56, 1234) + self.assertEqual(str(sf.lit(ts)), "Column<'2021-03-04 12:34:56.001234'>") + + @unittest.skipIf(not have_pandas, pandas_requirement_message) + def test_lit_delta_representation(self): + for delta in [ + datetime.timedelta(days=1), + datetime.timedelta(hours=2), + datetime.timedelta(minutes=3), + datetime.timedelta(seconds=4), + datetime.timedelta(microseconds=5), + datetime.timedelta(days=2, hours=21, microseconds=908), + datetime.timedelta(days=1, minutes=-3, microseconds=-1001), + datetime.timedelta(days=1, hours=2, minutes=3, seconds=4, microseconds=5), + ]: + import pandas as pd + + # Column<'PT69H0.000908S'> or Column<'P2DT21H0M0.000908S'> + s = str(sf.lit(delta)) + + # Parse the ISO string representation and compare + 
self.assertTrue(pd.Timedelta(s[8:-2]).to_pytimedelta() == delta) + def test_enum_literals(self): class IntEnum(Enum): X = 1 diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py new file mode 100644 index 0000000000000..dfa0fa63b2dd5 --- /dev/null +++ b/python/pyspark/sql/tests/test_connect_compatibility.py @@ -0,0 +1,192 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import inspect + +from pyspark.testing.connectutils import should_test_connect, connect_requirement_message +from pyspark.testing.sqlutils import ReusedSQLTestCase +from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame +from pyspark.sql.classic.column import Column as ClassicColumn +from pyspark.sql.session import SparkSession as ClassicSparkSession + +if should_test_connect: + from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame + from pyspark.sql.connect.column import Column as ConnectColumn + from pyspark.sql.connect.session import SparkSession as ConnectSparkSession + + +class ConnectCompatibilityTestsMixin: + def get_public_methods(self, cls): + """Get public methods of a class.""" + return { + name: method + for name, method in inspect.getmembers(cls, predicate=inspect.isfunction) + if not name.startswith("_") + } + + def get_public_properties(self, cls): + """Get public properties of a class.""" + return { + name: member + for name, member in inspect.getmembers(cls) + if isinstance(member, property) and not name.startswith("_") + } + + def test_signature_comparison_between_classic_and_connect(self): + def compare_method_signatures(classic_cls, connect_cls, cls_name): + """Compare method signatures between classic and connect classes.""" + classic_methods = self.get_public_methods(classic_cls) + connect_methods = self.get_public_methods(connect_cls) + + common_methods = set(classic_methods.keys()) & set(connect_methods.keys()) + + for method in common_methods: + classic_signature = inspect.signature(classic_methods[method]) + connect_signature = inspect.signature(connect_methods[method]) + + # createDataFrame cannot be the same since RDD is not supported from Spark Connect + if not method == "createDataFrame": + self.assertEqual( + classic_signature, + connect_signature, + f"Signature mismatch in {cls_name} method '{method}'\n" + f"Classic: {classic_signature}\n" + f"Connect: {connect_signature}", + ) + + # DataFrame API signature comparison + compare_method_signatures(ClassicDataFrame, ConnectDataFrame, "DataFrame") + + # Column API signature comparison + compare_method_signatures(ClassicColumn, ConnectColumn, "Column") + + # SparkSession API signature comparison + compare_method_signatures(ClassicSparkSession, ConnectSparkSession, "SparkSession") + + def 
test_property_comparison_between_classic_and_connect(self): + def compare_property_lists(classic_cls, connect_cls, cls_name, expected_missing_properties): + """Compare properties between classic and connect classes.""" + classic_properties = self.get_public_properties(classic_cls) + connect_properties = self.get_public_properties(connect_cls) + + # Identify missing properties + classic_only_properties = set(classic_properties.keys()) - set( + connect_properties.keys() + ) + + # Compare the actual missing properties with the expected ones + self.assertEqual( + classic_only_properties, + expected_missing_properties, + f"{cls_name}: Unexpected missing properties in Connect: {classic_only_properties}", + ) + + # Expected missing properties for DataFrame + expected_missing_properties_for_dataframe = {"sql_ctx", "isStreaming"} + + # DataFrame properties comparison + compare_property_lists( + ClassicDataFrame, + ConnectDataFrame, + "DataFrame", + expected_missing_properties_for_dataframe, + ) + + # Expected missing properties for Column (if any, replace with actual values) + expected_missing_properties_for_column = set() + + # Column properties comparison + compare_property_lists( + ClassicColumn, ConnectColumn, "Column", expected_missing_properties_for_column + ) + + # Expected missing properties for SparkSession + expected_missing_properties_for_spark_session = {"sparkContext", "version"} + + # SparkSession properties comparison + compare_property_lists( + ClassicSparkSession, + ConnectSparkSession, + "SparkSession", + expected_missing_properties_for_spark_session, + ) + + def test_missing_methods(self): + def check_missing_methods(classic_cls, connect_cls, cls_name, expected_missing_methods): + """Check for expected missing methods between classic and connect classes.""" + classic_methods = self.get_public_methods(classic_cls) + connect_methods = self.get_public_methods(connect_cls) + + # Identify missing methods + classic_only_methods = set(classic_methods.keys()) - set(connect_methods.keys()) + + # Compare the actual missing methods with the expected ones + self.assertEqual( + classic_only_methods, + expected_missing_methods, + f"{cls_name}: Unexpected missing methods in Connect: {classic_only_methods}", + ) + + # Expected missing methods for DataFrame + expected_missing_methods_for_dataframe = { + "inputFiles", + "isLocal", + "semanticHash", + "isEmpty", + } + + # DataFrame missing method check + check_missing_methods( + ClassicDataFrame, ConnectDataFrame, "DataFrame", expected_missing_methods_for_dataframe + ) + + # Expected missing methods for Column (if any, replace with actual values) + expected_missing_methods_for_column = set() + + # Column missing method check + check_missing_methods( + ClassicColumn, ConnectColumn, "Column", expected_missing_methods_for_column + ) + + # Expected missing methods for SparkSession (if any, replace with actual values) + expected_missing_methods_for_spark_session = {"newSession"} + + # SparkSession missing method check + check_missing_methods( + ClassicSparkSession, + ConnectSparkSession, + "SparkSession", + expected_missing_methods_for_spark_session, + ) + + +@unittest.skipIf(not should_test_connect, connect_requirement_message) +class ConnectCompatibilityTests(ConnectCompatibilityTestsMixin, ReusedSQLTestCase): + pass + + +if __name__ == "__main__": + from pyspark.sql.tests.test_connect_compatibility import * # noqa: F401 + + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + 
except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 8ec0839ec1fe4..2f53ca38743c1 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -15,6 +15,8 @@ # limitations under the License. # +import glob +import os import pydoc import shutil import tempfile @@ -47,6 +49,7 @@ pandas_requirement_message, pyarrow_requirement_message, ) +from pyspark.testing.utils import SPARK_HOME class DataFrameTestsMixin: @@ -506,14 +509,16 @@ def test_toDF_with_schema_string(self): # number of fields must match. self.assertRaisesRegex( - Exception, "FIELD_STRUCT_LENGTH_MISMATCH", lambda: rdd.toDF("key: int").collect() + Exception, + "FIELD_STRUCT_LENGTH_MISMATCH", + lambda: rdd.coalesce(1).toDF("key: int").collect(), ) # field types mismatch will cause exception at runtime. self.assertRaisesRegex( Exception, "FIELD_DATA_TYPE_UNACCEPTABLE", - lambda: rdd.toDF("key: float, value: string").collect(), + lambda: rdd.coalesce(1).toDF("key: float, value: string").collect(), ) # flat schema values will be wrapped into row. @@ -671,8 +676,7 @@ def test_repr_behaviors(self): |+---+-----+ || 1| 1| |+---+-----+ - |only showing top 1 row - |""" + |only showing top 1 row""" self.assertEqual(re.sub(pattern, "", expected3), df.__repr__()) # test when eager evaluation is enabled and _repr_html_ will be called @@ -778,6 +782,16 @@ def test_df_show(self): ) def test_df_merge_into(self): + filename_pattern = ( + "sql/catalyst/target/scala-*/test-classes/org/apache/spark/sql/connector/catalog/" + "InMemoryRowLevelOperationTableCatalog.class" + ) + if not bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern))): + raise unittest.SkipTest( + "org.apache.spark.sql.connector.catalog.InMemoryRowLevelOperationTableCatalog' " + "is not available. Will skip the related tests" + ) + try: # InMemoryRowLevelOperationTableCatalog is a test catalog that is included in the # catalyst-test package. If Spark complains that it can't find this class, make sure @@ -951,11 +965,17 @@ def test_union_classmethod_usage(self): def test_isinstance_dataframe(self): self.assertIsInstance(self.spark.range(1), DataFrame) - def test_checkpoint_dataframe(self): + def test_local_checkpoint_dataframe(self): with io.StringIO() as buf, redirect_stdout(buf): self.spark.range(1).localCheckpoint().explain() self.assertIn("ExistingRDD", buf.getvalue()) + def test_local_checkpoint_dataframe_with_storage_level(self): + # We don't have a way to reach into the server and assert the storage level server side, but + # this test should cover for unexpected errors in the API. 
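
Stepping back to the connect-compatibility tests above: they enforce parity by diffing the public members of the classic and Connect classes with `inspect`. Below is a minimal standalone sketch of that technique on two made-up toy classes; nothing in it is Spark code.

```python
# Toy illustration (not Spark code) of the inspect-based parity check used in
# test_connect_compatibility.py: compare shared method signatures, report extras.
import inspect


class ApiV1:
    def select(self, *cols):
        return self


class ApiV2:
    def select(self, *cols):
        return self

    def explain(self, extended=False):
        return None


def public_methods(cls):
    return {
        name: fn
        for name, fn in inspect.getmembers(cls, predicate=inspect.isfunction)
        if not name.startswith("_")
    }


v1, v2 = public_methods(ApiV1), public_methods(ApiV2)
for name in sorted(set(v1) & set(v2)):
    assert inspect.signature(v1[name]) == inspect.signature(v2[name]), f"mismatch in {name}"
print("missing from ApiV1:", set(v2) - set(v1))  # {'explain'}
```
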
+ df = self.spark.range(10).localCheckpoint(eager=True, storageLevel=StorageLevel.DISK_ONLY) + df.collect() + def test_transpose(self): df = self.spark.createDataFrame([{"a": "x", "b": "y", "c": "z"}]) diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index a0ab9bc9c7d40..a51156e895c62 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -29,7 +29,7 @@ from pyspark.sql import Row, Window, functions as F, types from pyspark.sql.avro.functions import from_avro, to_avro from pyspark.sql.column import Column -from pyspark.sql.functions.builtin import nullifzero, zeroifnull +from pyspark.sql.functions.builtin import nullifzero, randstr, uniform, zeroifnull from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils from pyspark.testing.utils import have_numpy @@ -1610,6 +1610,25 @@ def test_nullifzero_zeroifnull(self): result = df.select(zeroifnull(df.a).alias("r")).collect() self.assertEqual([Row(r=0), Row(r=1)], result) + def test_randstr_uniform(self): + df = self.spark.createDataFrame([(0,)], ["a"]) + result = df.select(randstr(F.lit(5), F.lit(0)).alias("x")).selectExpr("length(x)").collect() + self.assertEqual([Row(5)], result) + # The random seed is optional. + result = df.select(randstr(F.lit(5)).alias("x")).selectExpr("length(x)").collect() + self.assertEqual([Row(5)], result) + + df = self.spark.createDataFrame([(0,)], ["a"]) + result = ( + df.select(uniform(F.lit(10), F.lit(20), F.lit(0)).alias("x")) + .selectExpr("x > 5") + .collect() + ) + self.assertEqual([Row(True)], result) + # The random seed is optional. + result = df.select(uniform(F.lit(10), F.lit(20)).alias("x")).selectExpr("x > 5").collect() + self.assertEqual([Row(True)], result) + class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin): pass diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py index f4f32dea9060a..2fca6b57decf9 100644 --- a/python/pyspark/sql/tests/test_readwriter.py +++ b/python/pyspark/sql/tests/test_readwriter.py @@ -255,6 +255,7 @@ def check_api(self, tpe): def test_partitioning_functions(self): self.check_partitioning_functions(DataFrameWriterV2) + self.partitioning_functions_user_error() def check_partitioning_functions(self, tpe): import datetime @@ -274,6 +275,35 @@ def check_partitioning_functions(self, tpe): self.assertIsInstance(writer.partitionedBy(bucket(11, col("id"))), tpe) self.assertIsInstance(writer.partitionedBy(bucket(3, "id"), hours(col("ts"))), tpe) + def partitioning_functions_user_error(self): + import datetime + from pyspark.sql.functions.partitioning import years, months, days, hours, bucket + + df = self.spark.createDataFrame( + [(1, datetime.datetime(2000, 1, 1), "foo")], ("id", "ts", "value") + ) + + with self.assertRaisesRegex( + Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY" + ): + df.select(years("ts")).collect() + with self.assertRaisesRegex( + Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY" + ): + df.select(months("ts")).collect() + with self.assertRaisesRegex( + Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY" + ): + df.select(days("ts")).collect() + with self.assertRaisesRegex( + Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY" + ): + df.select(hours("ts")).collect() + with self.assertRaisesRegex( + Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY" + ): + df.select(bucket(2, "ts")).collect() + def test_create(self): 
df = self.df with self.table("test_table"): diff --git a/python/pyspark/sql/tests/test_sql.py b/python/pyspark/sql/tests/test_sql.py new file mode 100644 index 0000000000000..bf50bbc11ac33 --- /dev/null +++ b/python/pyspark/sql/tests/test_sql.py @@ -0,0 +1,185 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.sql import Row +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class SQLTestsMixin: + def test_simple(self): + res = self.spark.sql("SELECT 1 + 1").collect() + self.assertEqual(len(res), 1) + self.assertEqual(res[0][0], 2) + + def test_args_dict(self): + with self.tempView("test"): + self.spark.range(10).createOrReplaceTempView("test") + df = self.spark.sql( + "SELECT * FROM IDENTIFIER(:table_name)", + args={"table_name": "test"}, + ) + + self.assertEqual(df.count(), 10) + self.assertEqual(df.limit(5).count(), 5) + self.assertEqual(df.offset(5).count(), 5) + + self.assertEqual(df.take(1), [Row(id=0)]) + self.assertEqual(df.tail(1), [Row(id=9)]) + + def test_args_list(self): + with self.tempView("test"): + self.spark.range(10).createOrReplaceTempView("test") + df = self.spark.sql( + "SELECT * FROM test WHERE ? 
< id AND id < ?", + args=[1, 6], + ) + + self.assertEqual(df.count(), 4) + self.assertEqual(df.limit(3).count(), 3) + self.assertEqual(df.offset(3).count(), 1) + + self.assertEqual(df.take(1), [Row(id=2)]) + self.assertEqual(df.tail(1), [Row(id=5)]) + + def test_kwargs_literal(self): + with self.tempView("test"): + self.spark.range(10).createOrReplaceTempView("test") + + df = self.spark.sql( + "SELECT * FROM IDENTIFIER(:table_name) WHERE {m1} < id AND id < {m2} OR id = {m3}", + args={"table_name": "test"}, + m1=3, + m2=7, + m3=9, + ) + + self.assertEqual(df.count(), 4) + self.assertEqual(df.collect(), [Row(id=4), Row(id=5), Row(id=6), Row(id=9)]) + self.assertEqual(df.take(1), [Row(id=4)]) + self.assertEqual(df.tail(1), [Row(id=9)]) + + def test_kwargs_literal_multiple_ref(self): + with self.tempView("test"): + self.spark.range(10).createOrReplaceTempView("test") + + df = self.spark.sql( + "SELECT * FROM IDENTIFIER(:table_name) WHERE {m} = id OR id > {m} OR {m} < 0", + args={"table_name": "test"}, + m=6, + ) + + self.assertEqual(df.count(), 4) + self.assertEqual(df.collect(), [Row(id=6), Row(id=7), Row(id=8), Row(id=9)]) + self.assertEqual(df.take(1), [Row(id=6)]) + self.assertEqual(df.tail(1), [Row(id=9)]) + + def test_kwargs_dataframe(self): + df0 = self.spark.range(10) + df1 = self.spark.sql( + "SELECT * FROM {df} WHERE id > 4", + df=df0, + ) + + self.assertEqual(df0.schema, df1.schema) + self.assertEqual(df1.count(), 5) + self.assertEqual(df1.take(1), [Row(id=5)]) + self.assertEqual(df1.tail(1), [Row(id=9)]) + + def test_kwargs_dataframe_with_column(self): + df0 = self.spark.range(10) + df1 = self.spark.sql( + "SELECT * FROM {df} WHERE {df.id} > :m1 AND {df[id]} < :m2", + {"m1": 4, "m2": 9}, + df=df0, + ) + + self.assertEqual(df0.schema, df1.schema) + self.assertEqual(df1.count(), 4) + self.assertEqual(df1.take(1), [Row(id=5)]) + self.assertEqual(df1.tail(1), [Row(id=8)]) + + def test_nested_view(self): + with self.tempView("v1", "v2", "v3", "v4"): + self.spark.range(10).createOrReplaceTempView("v1") + self.spark.sql( + "SELECT * FROM IDENTIFIER(:view) WHERE id > :m", + args={"view": "v1", "m": 1}, + ).createOrReplaceTempView("v2") + self.spark.sql( + "SELECT * FROM IDENTIFIER(:view) WHERE id > :m", + args={"view": "v2", "m": 2}, + ).createOrReplaceTempView("v3") + self.spark.sql( + "SELECT * FROM IDENTIFIER(:view) WHERE id > :m", + args={"view": "v3", "m": 3}, + ).createOrReplaceTempView("v4") + + df = self.spark.sql("select * from v4") + self.assertEqual(df.count(), 6) + self.assertEqual(df.take(1), [Row(id=4)]) + self.assertEqual(df.tail(1), [Row(id=9)]) + + def test_nested_dataframe(self): + df0 = self.spark.range(10) + df1 = self.spark.sql( + "SELECT * FROM {df} WHERE id > ?", + args=[1], + df=df0, + ) + df2 = self.spark.sql( + "SELECT * FROM {df} WHERE id > ?", + args=[2], + df=df1, + ) + df3 = self.spark.sql( + "SELECT * FROM {df} WHERE id > ?", + args=[3], + df=df2, + ) + + self.assertEqual(df0.schema, df1.schema) + self.assertEqual(df1.count(), 8) + self.assertEqual(df1.take(1), [Row(id=2)]) + self.assertEqual(df1.tail(1), [Row(id=9)]) + + self.assertEqual(df0.schema, df2.schema) + self.assertEqual(df2.count(), 7) + self.assertEqual(df2.take(1), [Row(id=3)]) + self.assertEqual(df2.tail(1), [Row(id=9)]) + + self.assertEqual(df0.schema, df3.schema) + self.assertEqual(df3.count(), 6) + self.assertEqual(df3.take(1), [Row(id=4)]) + self.assertEqual(df3.tail(1), [Row(id=9)]) + + +class SQLTests(SQLTestsMixin, ReusedSQLTestCase): + pass + + +if __name__ == "__main__": + from 
pyspark.sql.tests.test_sql import * # noqa: F401 + + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 8610ace52d86a..c240a84d1edb9 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -28,7 +28,6 @@ from pyspark.sql import Row from pyspark.sql import functions as F from pyspark.errors import ( - AnalysisException, ParseException, PySparkTypeError, PySparkValueError, @@ -1130,10 +1129,17 @@ def test_cast_to_string_with_udt(self): def test_cast_to_udt_with_udt(self): row = Row(point=ExamplePoint(1.0, 2.0), python_only_point=PythonOnlyPoint(1.0, 2.0)) df = self.spark.createDataFrame([row]) - with self.assertRaises(AnalysisException): - df.select(F.col("point").cast(PythonOnlyUDT())).collect() - with self.assertRaises(AnalysisException): - df.select(F.col("python_only_point").cast(ExamplePointUDT())).collect() + result = df.select(F.col("point").cast(PythonOnlyUDT())).collect() + self.assertEqual( + result, + [Row(point=PythonOnlyPoint(1.0, 2.0))], + ) + + result = df.select(F.col("python_only_point").cast(ExamplePointUDT())).collect() + self.assertEqual( + result, + [Row(python_only_point=ExamplePoint(1.0, 2.0))], + ) def test_struct_type(self): struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index 6f672b0ae5fb3..879329bd80c0b 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -237,11 +237,12 @@ def test_udf_in_join_condition(self): f = udf(lambda a, b: a == b, BooleanType()) # The udf uses attributes from both sides of join, so it is pulled out as Filter + # Cross join. - df = left.join(right, f("a", "b")) with self.sql_conf({"spark.sql.crossJoin.enabled": False}): + df = left.join(right, f("a", "b")) with self.assertRaisesRegex(AnalysisException, "Detected implicit cartesian product"): df.collect() with self.sql_conf({"spark.sql.crossJoin.enabled": True}): + df = left.join(right, f("a", "b")) self.assertEqual(df.collect(), [Row(a=1, b=1)]) def test_udf_in_left_outer_join_condition(self): diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 11b91612419a3..5d9ec92cbc830 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -41,6 +41,7 @@ PythonException, UnknownException, SparkUpgradeException, + PySparkImportError, PySparkNotImplementedError, PySparkRuntimeError, ) @@ -115,6 +116,22 @@ def require_test_compiled() -> None: ) +def require_minimum_plotly_version() -> None: + """Raise ImportError if plotly is not installed""" + minimum_plotly_version = "4.8" + + try: + import plotly # noqa: F401 + except ImportError as error: + raise PySparkImportError( + errorClass="PACKAGE_NOT_INSTALLED", + messageParameters={ + "package_name": "plotly", + "minimum_version": str(minimum_plotly_version), + }, + ) from error + + class ForeachBatchFunction: """ This is the Python implementation of Java interface 'ForeachBatchFunction'. 
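The new test_sql.py above walks through the parameterized spark.sql API; a compact sketch of the same calls, assuming a running SparkSession and an illustrative temp view named "events":

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.range(10).createOrReplaceTempView("events")

# Named parameters are bound via the args dict; IDENTIFIER(...) binds the table name.
df1 = spark.sql(
    "SELECT * FROM IDENTIFIER(:table_name) WHERE :lo < id AND id < :hi",
    args={"table_name": "events", "lo": 1, "hi": 6},
)

# Positional parameters use '?' markers and a list of values.
df2 = spark.sql("SELECT * FROM events WHERE ? < id AND id < ?", args=[1, 6])

# DataFrames (and literals) can be referenced with {name} placeholders passed as keyword arguments.
src = spark.range(10)
df3 = spark.sql("SELECT * FROM {src} WHERE id > 4", src=src)
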
This wraps diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py index 9f07c44c084cf..00ad40e68bd7c 100644 --- a/python/pyspark/testing/sqlutils.py +++ b/python/pyspark/testing/sqlutils.py @@ -48,6 +48,13 @@ except Exception as e: test_not_compiled_message = str(e) +plotly_requirement_message = None +try: + import plotly +except ImportError as e: + plotly_requirement_message = str(e) +have_plotly = plotly_requirement_message is None + from pyspark.sql import SparkSession from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row from pyspark.testing.utils import ReusedPySparkTestCase, PySparkErrorTestUtils diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index b8263769c28a9..eedf5d1fd5996 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -1565,14 +1565,15 @@ def map_batch(batch): num_output_rows = 0 for result_batch, result_type in result_iter: num_output_rows += len(result_batch) - # This assert is for Scalar Iterator UDF to fail fast. + # This check is for Scalar Iterator UDF to fail fast. # The length of the entire input can only be explicitly known # by consuming the input iterator in user side. Therefore, # it's very unlikely the output length is higher than # input length. - assert ( - is_map_pandas_iter or is_map_arrow_iter or num_output_rows <= num_input_rows - ), "Pandas SCALAR_ITER UDF outputted more rows than input rows." + if is_scalar_iter and num_output_rows > num_input_rows: + raise PySparkRuntimeError( + errorClass="PANDAS_UDF_OUTPUT_EXCEEDS_INPUT_ROWS", messageParameters={} + ) yield (result_batch, result_type) if is_scalar_iter: diff --git a/repl/pom.xml b/repl/pom.xml index 831379467a29e..1a1c6b92c9222 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -82,6 +82,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.apache.spark spark-tags_${scala.binary.version} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 3a4d68c19014d..db7fc85976c2a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -769,14 +769,17 @@ private[spark] object Config extends Logging { val KUBERNETES_VOLUMES_NFS_TYPE = "nfs" val KUBERNETES_VOLUMES_MOUNT_PATH_KEY = "mount.path" val KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY = "mount.subPath" + val KUBERNETES_VOLUMES_MOUNT_SUBPATHEXPR_KEY = "mount.subPathExpr" val KUBERNETES_VOLUMES_MOUNT_READONLY_KEY = "mount.readOnly" val KUBERNETES_VOLUMES_OPTIONS_PATH_KEY = "options.path" + val KUBERNETES_VOLUMES_OPTIONS_TYPE_KEY = "options.type" val KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY = "options.claimName" val KUBERNETES_VOLUMES_OPTIONS_CLAIM_STORAGE_CLASS_KEY = "options.storageClass" val KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY = "options.medium" val KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY = "options.sizeLimit" val KUBERNETES_VOLUMES_OPTIONS_SERVER_KEY = "options.server" val KUBERNETES_VOLUMES_LABEL_KEY = "label." + val KUBERNETES_VOLUMES_ANNOTATION_KEY = "annotation." val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." 
val KUBERNETES_DNS_SUBDOMAIN_NAME_MAX_LENGTH = 253 diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala index 9dfd40a773eb1..b7113a562fa06 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala @@ -18,14 +18,15 @@ package org.apache.spark.deploy.k8s private[spark] sealed trait KubernetesVolumeSpecificConf -private[spark] case class KubernetesHostPathVolumeConf(hostPath: String) +private[spark] case class KubernetesHostPathVolumeConf(hostPath: String, volumeType: String) extends KubernetesVolumeSpecificConf private[spark] case class KubernetesPVCVolumeConf( claimName: String, storageClass: Option[String] = None, size: Option[String] = None, - labels: Option[Map[String, String]] = None) + labels: Option[Map[String, String]] = None, + annotations: Option[Map[String, String]] = None) extends KubernetesVolumeSpecificConf private[spark] case class KubernetesEmptyDirVolumeConf( @@ -42,5 +43,6 @@ private[spark] case class KubernetesVolumeSpec( volumeName: String, mountPath: String, mountSubPath: String, + mountSubPathExpr: String, mountReadOnly: Boolean, volumeConf: KubernetesVolumeSpecificConf) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala index 6463512c0114b..95821a909f351 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala @@ -45,21 +45,30 @@ object KubernetesVolumeUtils { val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_PATH_KEY" val readOnlyKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_READONLY_KEY" val subPathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY" + val subPathExprKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_SUBPATHEXPR_KEY" val labelKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_LABEL_KEY" + val annotationKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_ANNOTATION_KEY" + verifyMutuallyExclusiveOptionKeys(properties, subPathKey, subPathExprKey) val volumeLabelsMap = properties .filter(_._1.startsWith(labelKey)) .map { case (k, v) => k.replaceAll(labelKey, "") -> v } + val volumeAnnotationsMap = properties + .filter(_._1.startsWith(annotationKey)) + .map { + case (k, v) => k.replaceAll(annotationKey, "") -> v + } KubernetesVolumeSpec( volumeName = volumeName, mountPath = properties(pathKey), mountSubPath = properties.getOrElse(subPathKey, ""), + mountSubPathExpr = properties.getOrElse(subPathExprKey, ""), mountReadOnly = properties.get(readOnlyKey).exists(_.toBoolean), volumeConf = parseVolumeSpecificConf(properties, - volumeType, volumeName, Option(volumeLabelsMap))) + volumeType, volumeName, Option(volumeLabelsMap), Option(volumeAnnotationsMap))) }.toSeq } @@ -83,12 +92,16 @@ object KubernetesVolumeUtils { options: Map[String, String], volumeType: String, volumeName: String, - labels: Option[Map[String, String]]): KubernetesVolumeSpecificConf = { + labels: Option[Map[String, String]], + annotations: Option[Map[String, 
String]]): KubernetesVolumeSpecificConf = { volumeType match { case KUBERNETES_VOLUMES_HOSTPATH_TYPE => val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_PATH_KEY" + val typeKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_TYPE_KEY" verifyOptionKey(options, pathKey, KUBERNETES_VOLUMES_HOSTPATH_TYPE) - KubernetesHostPathVolumeConf(options(pathKey)) + // "" means that no checks will be performed before mounting the hostPath volume + // backward compatibility default + KubernetesHostPathVolumeConf(options(pathKey), options.getOrElse(typeKey, "")) case KUBERNETES_VOLUMES_PVC_TYPE => val claimNameKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY" @@ -101,7 +114,8 @@ object KubernetesVolumeUtils { options(claimNameKey), options.get(storageClassKey), options.get(sizeLimitKey), - labels) + labels, + annotations) case KUBERNETES_VOLUMES_EMPTYDIR_TYPE => val mediumKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY" @@ -129,6 +143,16 @@ object KubernetesVolumeUtils { } } + private def verifyMutuallyExclusiveOptionKeys( + options: Map[String, String], + keys: String*): Unit = { + val givenKeys = keys.filter(options.contains) + if (givenKeys.length > 1) { + throw new IllegalArgumentException("These config options are mutually exclusive: " + + s"${givenKeys.mkString(", ")}") + } + } + private def verifySize(size: Option[String]): Unit = { size.foreach { v => if (v.forall(_.isDigit) && parseLong(v) < 1024) { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala index 79f76e96474e3..2c28dc380046c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala @@ -107,7 +107,7 @@ object SparkKubernetesClientFactory extends Logging { (token, configBuilder) => configBuilder.withOauthToken(token) }.withOption(oauthTokenFile) { (file, configBuilder) => - configBuilder.withOauthToken(Files.toString(file, Charsets.UTF_8)) + configBuilder.withOauthToken(Files.asCharSource(file, Charsets.UTF_8).read()) }.withOption(caCertFile) { (file, configBuilder) => configBuilder.withCaCertFile(file) }.withOption(clientKeyFile) { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala index e266d0f904e46..d64378a65d66f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala @@ -116,7 +116,7 @@ private[spark] class HadoopConfDriverFeatureStep(conf: KubernetesConf) override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { if (confDir.isDefined) { val fileMap = confFiles.map { file => - (file.getName(), Files.toString(file, StandardCharsets.UTF_8)) + (file.getName(), Files.asCharSource(file, StandardCharsets.UTF_8).read()) }.toMap.asJava Seq(new ConfigMapBuilder() diff --git 
a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala index 82bda88892d04..89aefe47e46d1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala @@ -229,7 +229,7 @@ private[spark] class KerberosConfDriverFeatureStep(kubernetesConf: KubernetesDri .endMetadata() .withImmutable(true) .addToData( - Map(file.getName() -> Files.toString(file, StandardCharsets.UTF_8)).asJava) + Map(file.getName() -> Files.asCharSource(file, StandardCharsets.UTF_8).read()).asJava) .build() } } ++ { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala index 5cc61c746b0e0..3d89696f19fcc 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala @@ -65,16 +65,16 @@ private[spark] class MountVolumesFeatureStep(conf: KubernetesConf) .withMountPath(spec.mountPath) .withReadOnly(spec.mountReadOnly) .withSubPath(spec.mountSubPath) + .withSubPathExpr(spec.mountSubPathExpr) .withName(spec.volumeName) .build() val volumeBuilder = spec.volumeConf match { - case KubernetesHostPathVolumeConf(hostPath) => - /* "" means that no checks will be performed before mounting the hostPath volume */ + case KubernetesHostPathVolumeConf(hostPath, volumeType) => new VolumeBuilder() - .withHostPath(new HostPathVolumeSource(hostPath, "")) + .withHostPath(new HostPathVolumeSource(hostPath, volumeType)) - case KubernetesPVCVolumeConf(claimNameTemplate, storageClass, size, labels) => + case KubernetesPVCVolumeConf(claimNameTemplate, storageClass, size, labels, annotations) => val claimName = conf match { case c: KubernetesExecutorConf => claimNameTemplate @@ -91,12 +91,17 @@ private[spark] class MountVolumesFeatureStep(conf: KubernetesConf) case Some(customLabelsMap) => (customLabelsMap ++ defaultVolumeLabels).asJava case None => defaultVolumeLabels.asJava } + val volumeAnnotations = annotations match { + case Some(value) => value.asJava + case None => Map[String, String]().asJava + } additionalResources.append(new PersistentVolumeClaimBuilder() .withKind(PVC) .withApiVersion("v1") .withNewMetadata() .withName(claimName) .addToLabels(volumeLabels) + .addToAnnotations(volumeAnnotations) .endMetadata() .withNewSpec() .withStorageClassName(storageClass.get) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala index cdc0112294113..f94dad2d15dc1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala @@ -81,7 +81,7 @@ private[spark] class PodTemplateConfigMapStep(conf: KubernetesConf) val 
hadoopConf = SparkHadoopUtil.get.newConfiguration(conf.sparkConf) val uri = downloadFile(podTemplateFile, Utils.createTempDir(), conf.sparkConf, hadoopConf) val file = new java.net.URI(uri).getPath - val podTemplateString = Files.toString(new File(file), StandardCharsets.UTF_8) + val podTemplateString = Files.asCharSource(new File(file), StandardCharsets.UTF_8).read() Seq(new ConfigMapBuilder() .withNewMetadata() .withName(configmapName) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala index 8bc6e9340871f..6021a4fb953e5 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala @@ -28,7 +28,6 @@ import io.fabric8.kubernetes.api.model.{HasMetadata, PersistentVolumeClaim, Pod, import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException} import org.apache.spark.{SecurityManager, SparkConf, SparkException} -import org.apache.spark.deploy.ExecutorFailureTracker import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.KubernetesConf @@ -38,7 +37,6 @@ import org.apache.spark.internal.config._ import org.apache.spark.resource.ResourceProfile import org.apache.spark.scheduler.cluster.SchedulerBackendUtils.DEFAULT_NUMBER_EXECUTORS import org.apache.spark.util.{Clock, Utils} -import org.apache.spark.util.SparkExitCode.EXCEED_MAX_EXECUTOR_FAILURES class ExecutorPodsAllocator( conf: SparkConf, @@ -73,8 +71,6 @@ class ExecutorPodsAllocator( protected val maxPendingPods = conf.get(KUBERNETES_MAX_PENDING_PODS) - protected val maxNumExecutorFailures = ExecutorFailureTracker.maxNumExecutorFailures(conf) - protected val podCreationTimeout = math.max( podAllocationDelay * 5, conf.get(KUBERNETES_ALLOCATION_EXECUTOR_TIMEOUT)) @@ -121,12 +117,6 @@ class ExecutorPodsAllocator( // if they happen to come up before the deletion takes effect. 
@volatile protected var deletedExecutorIds = Set.empty[Long] - @volatile private var failedExecutorIds = Set.empty[Long] - - protected val failureTracker = new ExecutorFailureTracker(conf, clock) - - protected[spark] def getNumExecutorsFailed: Int = failureTracker.numFailedExecutors - def start(applicationId: String, schedulerBackend: KubernetesClusterSchedulerBackend): Unit = { appId = applicationId driverPod.foreach { pod => @@ -142,11 +132,6 @@ class ExecutorPodsAllocator( } snapshotsStore.addSubscriber(podAllocationDelay) { executorPodsSnapshot => onNewSnapshots(applicationId, schedulerBackend, executorPodsSnapshot) - if (failureTracker.numFailedExecutors > maxNumExecutorFailures) { - logError(log"Max number of executor failures " + - log"(${MDC(LogKeys.MAX_EXECUTOR_FAILURES, maxNumExecutorFailures)}) reached") - stopApplication(EXCEED_MAX_EXECUTOR_FAILURES) - } } } @@ -163,10 +148,6 @@ class ExecutorPodsAllocator( def isDeleted(executorId: String): Boolean = deletedExecutorIds.contains(executorId.toLong) - private[k8s] def stopApplication(exitCode: Int): Unit = { - sys.exit(exitCode) - } - protected def onNewSnapshots( applicationId: String, schedulerBackend: KubernetesClusterSchedulerBackend, @@ -276,18 +257,6 @@ class ExecutorPodsAllocator( case _ => false } - val currentFailedExecutorIds = podsForRpId.filter { - case (_, PodFailed(_)) => true - case _ => false - }.keySet - - val newFailedExecutorIds = currentFailedExecutorIds.diff(failedExecutorIds) - if (newFailedExecutorIds.nonEmpty) { - logWarning(log"${MDC(LogKeys.COUNT, newFailedExecutorIds.size)} new failed executors.") - newFailedExecutorIds.foreach { _ => failureTracker.registerExecutorFailure() } - } - failedExecutorIds = failedExecutorIds ++ currentFailedExecutorIds - val (schedulerKnownPendingExecsForRpId, currentPendingExecutorsForRpId) = podsForRpId.filter { case (_, PodPending(_)) => true case _ => false diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala index 0d79efa06e497..fe2707a7f65b1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala @@ -27,18 +27,20 @@ import io.fabric8.kubernetes.api.model.{Pod, PodBuilder} import io.fabric8.kubernetes.client.KubernetesClient import org.apache.spark.SparkConf +import org.apache.spark.deploy.ExecutorFailureTracker import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.KubernetesUtils._ -import org.apache.spark.internal.{Logging, MDC} -import org.apache.spark.internal.LogKeys.EXECUTOR_ID +import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.scheduler.ExecutorExited -import org.apache.spark.util.Utils +import org.apache.spark.util.{Clock, SystemClock, Utils} +import org.apache.spark.util.SparkExitCode.EXCEED_MAX_EXECUTOR_FAILURES private[spark] class ExecutorPodsLifecycleManager( val conf: SparkConf, kubernetesClient: KubernetesClient, - snapshotsStore: ExecutorPodsSnapshotsStore) extends Logging { + snapshotsStore: ExecutorPodsSnapshotsStore, + clock: Clock = new SystemClock()) extends Logging { import ExecutorPodsLifecycleManager._ @@ -62,18 +64,49 @@ private[spark] 
class ExecutorPodsLifecycleManager( private val namespace = conf.get(KUBERNETES_NAMESPACE) + private val sparkContainerName = conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_CONTAINER_NAME) + .getOrElse(DEFAULT_EXECUTOR_CONTAINER_NAME) + + protected val maxNumExecutorFailures = ExecutorFailureTracker.maxNumExecutorFailures(conf) + + @volatile private var failedExecutorIds = Set.empty[Long] + + protected val failureTracker = new ExecutorFailureTracker(conf, clock) + + protected[spark] def getNumExecutorsFailed: Int = failureTracker.numFailedExecutors + def start(schedulerBackend: KubernetesClusterSchedulerBackend): Unit = { val eventProcessingInterval = conf.get(KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL) - snapshotsStore.addSubscriber(eventProcessingInterval) { - onNewSnapshots(schedulerBackend, _) + snapshotsStore.addSubscriber(eventProcessingInterval) { executorPodsSnapshot => + onNewSnapshots(schedulerBackend, executorPodsSnapshot) + if (failureTracker.numFailedExecutors > maxNumExecutorFailures) { + logError(log"Max number of executor failures " + + log"(${MDC(LogKeys.MAX_EXECUTOR_FAILURES, maxNumExecutorFailures)}) reached") + stopApplication(EXCEED_MAX_EXECUTOR_FAILURES) + } } } + private[k8s] def stopApplication(exitCode: Int): Unit = { + sys.exit(exitCode) + } + private def onNewSnapshots( schedulerBackend: KubernetesClusterSchedulerBackend, snapshots: Seq[ExecutorPodsSnapshot]): Unit = { val execIdsRemovedInThisRound = mutable.HashSet.empty[Long] snapshots.foreach { snapshot => + val currentFailedExecutorIds = snapshot.executorPods.filter { + case (_, PodFailed(_)) => true + case _ => false + }.keySet + + val newFailedExecutorIds = currentFailedExecutorIds -- failedExecutorIds + if (newFailedExecutorIds.nonEmpty) { + logWarning(log"${MDC(LogKeys.COUNT, newFailedExecutorIds.size)} new failed executors.") + newFailedExecutorIds.foreach { _ => failureTracker.registerExecutorFailure() } + } + failedExecutorIds = failedExecutorIds ++ currentFailedExecutorIds snapshot.executorPods.foreach { case (execId, state) => state match { case _state if isPodInactive(_state.pod) => @@ -101,7 +134,7 @@ private[spark] class ExecutorPodsLifecycleManager( execIdsRemovedInThisRound += execId if (schedulerBackend.isExecutorActive(execId.toString)) { logInfo(log"Snapshot reported succeeded executor with id " + - log"${MDC(EXECUTOR_ID, execId)}, even though the application has not " + + log"${MDC(LogKeys.EXECUTOR_ID, execId)}, even though the application has not " + log"requested for it to be removed.") } else { logDebug(s"Snapshot reported succeeded executor with id $execId," + @@ -246,7 +279,8 @@ private[spark] class ExecutorPodsLifecycleManager( private def findExitCode(podState: FinalPodState): Int = { podState.pod.getStatus.getContainerStatuses.asScala.find { containerStatus => - containerStatus.getState.getTerminated != null + containerStatus.getName == sparkContainerName && + containerStatus.getState.getTerminated != null }.map { terminatedContainer => terminatedContainer.getState.getTerminated.getExitCode.toInt }.getOrElse(UNKNOWN_EXIT_CODE) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 4e4634504a0f3..09faa2a7fb1b3 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ 
b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -32,7 +32,7 @@ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit.KubernetesClientUtils import org.apache.spark.deploy.security.HadoopDelegationTokenManager -import org.apache.spark.internal.LogKeys.{COUNT, HOST_PORT, TOTAL} +import org.apache.spark.internal.LogKeys.{COUNT, TOTAL} import org.apache.spark.internal.MDC import org.apache.spark.internal.config.SCHEDULER_MIN_REGISTERED_RESOURCES_RATIO import org.apache.spark.resource.ResourceProfile @@ -356,7 +356,7 @@ private[spark] class KubernetesClusterSchedulerBackend( execIDRequester -= rpcAddress // Expected, executors re-establish a connection with an ID case _ => - logInfo(log"No executor found for ${MDC(HOST_PORT, rpcAddress)}") + logDebug(s"No executor found for ${rpcAddress}") } } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBackend.scala index c515ae5e3a246..e44d7e29ef606 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBackend.scala @@ -116,6 +116,10 @@ private[spark] object KubernetesExecutorBackend extends Logging { } } + // Initialize logging system again after `spark.log.structuredLogging.enabled` takes effect + Utils.resetStructuredLogging(driverConf) + Logging.uninitialize() + cfg.hadoopDelegationCreds.foreach { tokens => SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala index 7e0a65bcdda90..e5ed79718d733 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala @@ -113,11 +113,12 @@ object KubernetesTestConf { volumes.foreach { case spec => val (vtype, configs) = spec.volumeConf match { - case KubernetesHostPathVolumeConf(path) => - (KUBERNETES_VOLUMES_HOSTPATH_TYPE, - Map(KUBERNETES_VOLUMES_OPTIONS_PATH_KEY -> path)) + case KubernetesHostPathVolumeConf(hostPath, volumeType) => + (KUBERNETES_VOLUMES_HOSTPATH_TYPE, Map( + KUBERNETES_VOLUMES_OPTIONS_PATH_KEY -> hostPath, + KUBERNETES_VOLUMES_OPTIONS_TYPE_KEY -> volumeType)) - case KubernetesPVCVolumeConf(claimName, storageClass, sizeLimit, labels) => + case KubernetesPVCVolumeConf(claimName, storageClass, sizeLimit, labels, annotations) => val sconf = storageClass .map { s => (KUBERNETES_VOLUMES_OPTIONS_CLAIM_STORAGE_CLASS_KEY, s) }.toMap val lconf = sizeLimit.map { l => (KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY, l) }.toMap @@ -125,9 +126,13 @@ object KubernetesTestConf { case Some(value) => value.map { case(k, v) => s"label.$k" -> v } case None => Map() } + val aannotations = annotations match { + case Some(value) => value.map { case (k, v) => s"annotation.$k" -> v } + case None => Map() + } (KUBERNETES_VOLUMES_PVC_TYPE, Map(KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY -> claimName) ++ - 
sconf ++ lconf ++ llabels) + sconf ++ lconf ++ llabels ++ aannotations) case KubernetesEmptyDirVolumeConf(medium, sizeLimit) => val mconf = medium.map { m => (KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY, m) }.toMap @@ -145,6 +150,10 @@ object KubernetesTestConf { conf.set(key(vtype, spec.volumeName, KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY), spec.mountSubPath) } + if (spec.mountSubPathExpr.nonEmpty) { + conf.set(key(vtype, spec.volumeName, KUBERNETES_VOLUMES_MOUNT_SUBPATHEXPR_KEY), + spec.mountSubPathExpr) + } conf.set(key(vtype, spec.volumeName, KUBERNETES_VOLUMES_MOUNT_READONLY_KEY), spec.mountReadOnly.toString) configs.foreach { case (k, v) => diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala index 5c103739d3082..3c57cba9a7ff0 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala @@ -30,7 +30,20 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { assert(volumeSpec.mountPath === "/path") assert(volumeSpec.mountReadOnly) assert(volumeSpec.volumeConf.asInstanceOf[KubernetesHostPathVolumeConf] === - KubernetesHostPathVolumeConf("/hostPath")) + KubernetesHostPathVolumeConf("/hostPath", "")) + } + + test("Parses hostPath volume type correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.options.path", "/hostPath") + sparkConf.set("test.hostPath.volumeName.options.type", "Type") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesHostPathVolumeConf] === + KubernetesHostPathVolumeConf("/hostPath", "Type")) } test("Parses subPath correctly") { @@ -43,6 +56,33 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { assert(volumeSpec.volumeName === "volumeName") assert(volumeSpec.mountPath === "/path") assert(volumeSpec.mountSubPath === "subPath") + assert(volumeSpec.mountSubPathExpr === "") + } + + test("Parses subPathExpr correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") + sparkConf.set("test.emptyDir.volumeName.mount.subPathExpr", "subPathExpr") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountSubPath === "") + assert(volumeSpec.mountSubPathExpr === "subPathExpr") + } + + test("Rejects mutually exclusive subPath and subPathExpr") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.subPath", "subPath") + sparkConf.set("test.emptyDir.volumeName.mount.subPathExpr", "subPathExpr") + + val msg = intercept[IllegalArgumentException] { + KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + }.getMessage + assert(msg === "These config options are mutually exclusive: " + + "emptyDir.volumeName.mount.subPath, emptyDir.volumeName.mount.subPathExpr") } 
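The subPathExpr, hostPath type, and PVC annotation options parsed above hang off the existing spark.kubernetes.{driver,executor}.volumes.* configuration prefix; a rough sketch of how they might be set from Python at session build time (volume names, claim names, and annotation values here are illustrative):

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    # hostPath volume; options.type maps to the Kubernetes hostPath 'type' field
    # ("" keeps the previous no-check behavior).
    .config("spark.kubernetes.executor.volumes.hostPath.data.mount.path", "/data")
    .config("spark.kubernetes.executor.volumes.hostPath.data.options.path", "/mnt/data")
    .config("spark.kubernetes.executor.volumes.hostPath.data.options.type", "Directory")
    # PVC volume; mount.subPathExpr is mutually exclusive with mount.subPath,
    # and annotation.* keys are copied onto the generated PVC metadata.
    .config("spark.kubernetes.executor.volumes.persistentVolumeClaim.cache.mount.path", "/cache")
    .config("spark.kubernetes.executor.volumes.persistentVolumeClaim.cache.mount.subPathExpr", "$(SPARK_EXECUTOR_ID)")
    .config("spark.kubernetes.executor.volumes.persistentVolumeClaim.cache.options.claimName", "OnDemand")
    .config("spark.kubernetes.executor.volumes.persistentVolumeClaim.cache.annotation.team", "data-eng")
    .getOrCreate()
)
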
test("Parses persistentVolumeClaim volumes correctly") { @@ -56,7 +96,7 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { assert(volumeSpec.mountPath === "/path") assert(volumeSpec.mountReadOnly) assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === - KubernetesPVCVolumeConf("claimName", labels = Some(Map()))) + KubernetesPVCVolumeConf("claimName", labels = Some(Map()), annotations = Some(Map()))) } test("SPARK-49598: Parses persistentVolumeClaim volumes correctly with labels") { @@ -73,7 +113,8 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { assert(volumeSpec.mountReadOnly) assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === KubernetesPVCVolumeConf(claimName = "claimName", - labels = Some(Map("env" -> "test", "foo" -> "bar")))) + labels = Some(Map("env" -> "test", "foo" -> "bar")), + annotations = Some(Map()))) } test("SPARK-49598: Parses persistentVolumeClaim volumes & puts " + @@ -88,7 +129,8 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { assert(volumeSpec.mountPath === "/path") assert(volumeSpec.mountReadOnly) assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === - KubernetesPVCVolumeConf(claimName = "claimName", labels = Some(Map()))) + KubernetesPVCVolumeConf(claimName = "claimName", labels = Some(Map()), + annotations = Some(Map()))) } test("Parses emptyDir volumes correctly") { @@ -240,4 +282,38 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { }.getMessage assert(m.contains("smaller than 1KiB. Missing units?")) } + + test("SPARK-49833: Parses persistentVolumeClaim volumes correctly with annotations") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path") + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.readOnly", "true") + sparkConf.set("test.persistentVolumeClaim.volumeName.options.claimName", "claimName") + sparkConf.set("test.persistentVolumeClaim.volumeName.annotation.key1", "value1") + sparkConf.set("test.persistentVolumeClaim.volumeName.annotation.key2", "value2") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === + KubernetesPVCVolumeConf(claimName = "claimName", + labels = Some(Map()), + annotations = Some(Map("key1" -> "value1", "key2" -> "value2")))) + } + + test("SPARK-49833: Parses persistentVolumeClaim volumes & puts " + + "annotations as empty Map if not provided") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path") + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.readOnly", "true") + sparkConf.set("test.persistentVolumeClaim.volumeName.options.claimName", "claimName") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === + KubernetesPVCVolumeConf(claimName = "claimName", labels = Some(Map()), + annotations = Some(Map()))) + } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index f1dd8b94f17ff..a72152a851c4f 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -128,7 +128,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite { private def writeCredentials(credentialsFileName: String, credentialsContents: String): File = { val credentialsFile = new File(credentialsTempDirectory, credentialsFileName) - Files.write(credentialsContents, credentialsFile, Charsets.UTF_8) + Files.asCharSink(credentialsFile, Charsets.UTF_8).write(credentialsContents) credentialsFile } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala index 8f21b95236a9c..4310ac0220e5e 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala @@ -48,7 +48,7 @@ class HadoopConfDriverFeatureStepSuite extends SparkFunSuite { val confFiles = Set("core-site.xml", "hdfs-site.xml") confFiles.foreach { f => - Files.write("some data", new File(confDir, f), UTF_8) + Files.asCharSink(new File(confDir, f), UTF_8).write("some data") } val sparkConf = new SparkConfWithEnv(Map(ENV_HADOOP_CONF_DIR -> confDir.getAbsolutePath())) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStepSuite.scala index a60227814eb13..04e20258d068f 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStepSuite.scala @@ -36,7 +36,7 @@ class HadoopConfExecutorFeatureStepSuite extends SparkFunSuite { val confFiles = Set("core-site.xml", "hdfs-site.xml") confFiles.foreach { f => - Files.write("some data", new File(confDir, f), UTF_8) + Files.asCharSink(new File(confDir, f), UTF_8).write("some data") } Seq( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala index 163d87643abd3..b172bdc06ddca 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala @@ -55,7 +55,7 @@ class KerberosConfDriverFeatureStepSuite extends SparkFunSuite { test("create krb5.conf config map if local config provided") { val krbConf = File.createTempFile("krb5", ".conf", tmpDir) - Files.write("some data", krbConf, UTF_8) + 
Files.asCharSink(krbConf, UTF_8).write("some data") val sparkConf = new SparkConf(false) .set(KUBERNETES_KERBEROS_KRB5_FILE, krbConf.getAbsolutePath()) @@ -70,7 +70,7 @@ class KerberosConfDriverFeatureStepSuite extends SparkFunSuite { test("create keytab secret if client keytab file used") { val keytab = File.createTempFile("keytab", ".bin", tmpDir) - Files.write("some data", keytab, UTF_8) + Files.asCharSink(keytab, UTF_8).write("some data") val sparkConf = new SparkConf(false) .set(KEYTAB, keytab.getAbsolutePath()) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala index eaadad163f064..3a9561051a894 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -137,8 +137,9 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite { "spark-local-dir-test", "/tmp", "", + "", false, - KubernetesHostPathVolumeConf("/hostPath/tmp") + KubernetesHostPathVolumeConf("/hostPath/tmp", "") ) val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf)) val mountVolumeStep = new MountVolumesFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala index 6a68898c5f61c..293773ddb9ec5 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala @@ -27,8 +27,9 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", false, - KubernetesHostPathVolumeConf("/hostPath/tmp") + KubernetesHostPathVolumeConf("/hostPath/tmp", "type") ) val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf)) val step = new MountVolumesFeatureStep(kubernetesConf) @@ -36,6 +37,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { assert(configuredPod.pod.getSpec.getVolumes.size() === 1) assert(configuredPod.pod.getSpec.getVolumes.get(0).getHostPath.getPath === "/hostPath/tmp") + assert(configuredPod.pod.getSpec.getVolumes.get(0).getHostPath.getType === "type") assert(configuredPod.container.getVolumeMounts.size() === 1) assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") @@ -47,6 +49,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", true, KubernetesPVCVolumeConf("pvcClaim") ) @@ -69,6 +72,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", true, KubernetesPVCVolumeConf("pvc-spark-SPARK_EXECUTOR_ID") ) @@ -94,6 +98,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", true, KubernetesPVCVolumeConf("pvc-spark-SPARK_EXECUTOR_ID", Some("fast"), Some("512M")) ) @@ -119,6 +124,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", true, 
KubernetesPVCVolumeConf("OnDemand") ) @@ -136,6 +142,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", true, KubernetesPVCVolumeConf(claimName = MountVolumesFeatureStep.PVC_ON_DEMAND, storageClass = Some("gp3"), @@ -156,6 +163,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", true, KubernetesPVCVolumeConf(claimName = MountVolumesFeatureStep.PVC_ON_DEMAND, storageClass = Some("gp3"), @@ -177,6 +185,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "checkpointVolume1", "/checkpoints1", "", + "", true, KubernetesPVCVolumeConf(claimName = "pvcClaim1", storageClass = Some("gp3"), @@ -188,6 +197,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "checkpointVolume2", "/checkpoints2", "", + "", true, KubernetesPVCVolumeConf(claimName = "pvcClaim2", storageClass = Some("gp3"), @@ -209,6 +219,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", true, KubernetesPVCVolumeConf(MountVolumesFeatureStep.PVC_ON_DEMAND) ) @@ -226,6 +237,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", false, KubernetesEmptyDirVolumeConf(Some("Memory"), Some("6G")) ) @@ -249,6 +261,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", false, KubernetesEmptyDirVolumeConf(None, None) ) @@ -271,6 +284,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", false, KubernetesNFSVolumeConf("/share/name", "nfs.example.com") ) @@ -293,6 +307,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", true, KubernetesNFSVolumeConf("/share/name", "nfs.example.com") ) @@ -315,13 +330,15 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "hpVolume", "/tmp", "", + "", false, - KubernetesHostPathVolumeConf("/hostPath/tmp") + KubernetesHostPathVolumeConf("/hostPath/tmp", "") ) val pvcVolumeConf = KubernetesVolumeSpec( "checkpointVolume", "/checkpoints", "", + "", true, KubernetesPVCVolumeConf("pvcClaim") ) @@ -339,13 +356,15 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "hpVolume", "/data", "", + "", false, - KubernetesHostPathVolumeConf("/hostPath/tmp") + KubernetesHostPathVolumeConf("/hostPath/tmp", "") ) val pvcVolumeConf = KubernetesVolumeSpec( "checkpointVolume", "/data", "", + "", true, KubernetesPVCVolumeConf("pvcClaim") ) @@ -364,6 +383,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "foo", + "", false, KubernetesEmptyDirVolumeConf(None, None) ) @@ -378,11 +398,32 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { assert(emptyDirMount.getSubPath === "foo") } + test("Mounts subpathexpr on emptyDir") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + "", + "foo", + false, + KubernetesEmptyDirVolumeConf(None, None) + ) + val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf)) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val emptyDirMount = configuredPod.container.getVolumeMounts.get(0) + assert(emptyDirMount.getMountPath === "/tmp") + assert(emptyDirMount.getName === "testVolume") + assert(emptyDirMount.getSubPathExpr === "foo") + } + test("Mounts subpath on persistentVolumeClaims") { val volumeConf = KubernetesVolumeSpec( "testVolume", "/tmp", "bar", + 
"", true, KubernetesPVCVolumeConf("pvcClaim") ) @@ -400,12 +441,36 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { assert(pvcMount.getSubPath === "bar") } + test("Mounts subpathexpr on persistentVolumeClaims") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + "", + "bar", + true, + KubernetesPVCVolumeConf("pvcClaim") + ) + val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf)) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val pvcClaim = configuredPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim + assert(pvcClaim.getClaimName === "pvcClaim") + assert(configuredPod.container.getVolumeMounts.size() === 1) + val pvcMount = configuredPod.container.getVolumeMounts.get(0) + assert(pvcMount.getMountPath === "/tmp") + assert(pvcMount.getName === "testVolume") + assert(pvcMount.getSubPathExpr === "bar") + } + test("Mounts multiple subpaths") { val volumeConf = KubernetesEmptyDirVolumeConf(None, None) val emptyDirSpec = KubernetesVolumeSpec( "testEmptyDir", "/tmp/foo", "foo", + "", true, KubernetesEmptyDirVolumeConf(None, None) ) @@ -413,6 +478,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testPVC", "/tmp/bar", "bar", + "", true, KubernetesEmptyDirVolumeConf(None, None) ) @@ -430,4 +496,81 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { assert(mounts(1).getMountPath === "/tmp/bar") assert(mounts(1).getSubPath === "bar") } + + test("SPARK-49833: Create and mounts persistentVolumeClaims in driver with annotations") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + "", + "", + true, + KubernetesPVCVolumeConf(claimName = MountVolumesFeatureStep.PVC_ON_DEMAND, + storageClass = Some("gp3"), + size = Some("1Mi"), + annotations = Some(Map("env" -> "test"))) + ) + + val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf)) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val pvcClaim = configuredPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim + assert(pvcClaim.getClaimName.endsWith("-driver-pvc-0")) + } + + test("SPARK-49833: Create and mounts persistentVolumeClaims in executors with annotations") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + "", + "", + true, + KubernetesPVCVolumeConf(claimName = MountVolumesFeatureStep.PVC_ON_DEMAND, + storageClass = Some("gp3"), + size = Some("1Mi"), + annotations = Some(Map("env" -> "exec-test"))) + ) + + val executorConf = KubernetesTestConf.createExecutorConf(volumes = Seq(volumeConf)) + val executorStep = new MountVolumesFeatureStep(executorConf) + val executorPod = executorStep.configurePod(SparkPod.initialPod()) + + assert(executorPod.pod.getSpec.getVolumes.size() === 1) + val executorPVC = executorPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim + assert(executorPVC.getClaimName.endsWith("-exec-1-pvc-0")) + } + + test("SPARK-49833: Mount multiple volumes to executor with annotations") { + val pvcVolumeConf1 = KubernetesVolumeSpec( + "checkpointVolume1", + "/checkpoints1", + "", + "", + true, + KubernetesPVCVolumeConf(claimName = "pvcClaim1", + storageClass = Some("gp3"), + size = Some("1Mi"), + annotations = Some(Map("env1" -> "exec-test-1"))) + ) + + val pvcVolumeConf2 = KubernetesVolumeSpec( + 
"checkpointVolume2", + "/checkpoints2", + "", + "", + true, + KubernetesPVCVolumeConf(claimName = "pvcClaim2", + storageClass = Some("gp3"), + size = Some("1Mi"), + annotations = Some(Map("env2" -> "exec-test-2"))) + ) + + val kubernetesConf = KubernetesTestConf.createExecutorConf( + volumes = Seq(pvcVolumeConf1, pvcVolumeConf2)) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 2) + assert(configuredPod.container.getVolumeMounts.size() === 2) + } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala index 299979071b5d7..fc75414e4a7e0 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala @@ -29,6 +29,7 @@ import org.apache.spark.resource.ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID object ExecutorLifecycleTestUtils { val TEST_SPARK_APP_ID = "spark-app-id" + val TEST_SPARK_EXECUTOR_CONTAINER_NAME = "spark-executor" def failedExecutorWithoutDeletion( executorId: Long, rpId: Int = DEFAULT_RESOURCE_PROFILE_ID): Pod = { @@ -37,7 +38,7 @@ object ExecutorLifecycleTestUtils { .withPhase("failed") .withStartTime(Instant.now.toString) .addNewContainerStatus() - .withName("spark-executor") + .withName(TEST_SPARK_EXECUTOR_CONTAINER_NAME) .withImage("k8s-spark") .withNewState() .withNewTerminated() @@ -49,6 +50,38 @@ object ExecutorLifecycleTestUtils { .addNewContainerStatus() .withName("spark-executor-sidecar") .withImage("k8s-spark-sidecar") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(2) + .endTerminated() + .endState() + .endContainerStatus() + .withMessage("Executor failed.") + .withReason("Executor failed because of a thrown error.") + .endStatus() + .build() + } + + def failedExecutorWithSidecarStatusListedFirst( + executorId: Long, rpId: Int = DEFAULT_RESOURCE_PROFILE_ID): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId, rpId)) + .editOrNewStatus() + .withPhase("failed") + .withStartTime(Instant.now.toString) + .addNewContainerStatus() // sidecar status listed before executor's container status + .withName("spark-executor-sidecar") + .withImage("k8s-spark-sidecar") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(2) + .endTerminated() + .endState() + .endContainerStatus() + .addNewContainerStatus() + .withName(TEST_SPARK_EXECUTOR_CONTAINER_NAME) + .withImage("k8s-spark") .withNewState() .withNewTerminated() .withMessage("Failed") @@ -200,7 +233,7 @@ object ExecutorLifecycleTestUtils { .endSpec() .build() val container = new ContainerBuilder() - .withName("spark-executor") + .withName(TEST_SPARK_EXECUTOR_CONTAINER_NAME) .withImage("k8s-spark") .build() SparkPod(pod, container) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala index 1ad5e0af0bd73..d4f7b9f67fd6f 100644 --- 
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala
@@ -42,7 +42,7 @@ import org.apache.spark.deploy.k8s.Fabric8Aliases._
 import org.apache.spark.internal.config._
 import org.apache.spark.resource._
 import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._
-import org.apache.spark.util.{ManualClock, SparkExitCode}
+import org.apache.spark.util.ManualClock
 
 class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter {
 
@@ -158,46 +158,6 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter {
     assert(m.contains("Allocation batch delay must be greater than 0.1s."))
   }
 
-  test("SPARK-41210: Window based executor failure tracking mechanism") {
-    var _exitCode = -1
-    val _conf = conf.clone
-      .set(MAX_EXECUTOR_FAILURES.key, "2")
-      .set(EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS.key, "2s")
-    podsAllocatorUnderTest = new ExecutorPodsAllocator(_conf, secMgr,
-      executorBuilder, kubernetesClient, snapshotsStore, waitForExecutorPodsClock) {
-      override private[spark] def stopApplication(exitCode: Int): Unit = {
-        _exitCode = exitCode
-      }
-    }
-    podsAllocatorUnderTest.setTotalExpectedExecutors(Map(defaultProfile -> 3))
-    podsAllocatorUnderTest.start(TEST_SPARK_APP_ID, schedulerBackend)
-    assert(podsAllocatorUnderTest.getNumExecutorsFailed === 0)
-
-    waitForExecutorPodsClock.advance(1000)
-    snapshotsStore.updatePod(failedExecutorWithoutDeletion(1))
-    snapshotsStore.updatePod(failedExecutorWithoutDeletion(2))
-    snapshotsStore.notifySubscribers()
-    assert(podsAllocatorUnderTest.getNumExecutorsFailed === 2)
-    assert(_exitCode === -1)
-
-    waitForExecutorPodsClock.advance(1000)
-    snapshotsStore.notifySubscribers()
-    assert(podsAllocatorUnderTest.getNumExecutorsFailed === 2)
-    assert(_exitCode === -1)
-
-    waitForExecutorPodsClock.advance(2000)
-    assert(podsAllocatorUnderTest.getNumExecutorsFailed === 0)
-    assert(_exitCode === -1)
-
-    waitForExecutorPodsClock.advance(1000)
-    snapshotsStore.updatePod(failedExecutorWithoutDeletion(3))
-    snapshotsStore.updatePod(failedExecutorWithoutDeletion(4))
-    snapshotsStore.updatePod(failedExecutorWithoutDeletion(5))
-    snapshotsStore.notifySubscribers()
-    assert(podsAllocatorUnderTest.getNumExecutorsFailed === 3)
-    assert(_exitCode === SparkExitCode.EXCEED_MAX_EXECUTOR_FAILURES)
-  }
-
   test("SPARK-36052: test splitSlots") {
     val seq1 = Seq("a")
     assert(ExecutorPodsAllocator.splitSlots(seq1, 0) === Seq(("a", 0)))
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala
index 96be5dfabd121..4c7ffe692b105 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala
@@ -33,11 +33,14 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.deploy.k8s.Config
+import org.apache.spark.deploy.k8s.Config._
 import org.apache.spark.deploy.k8s.Constants._
 import org.apache.spark.deploy.k8s.Fabric8Aliases._
 import org.apache.spark.deploy.k8s.KubernetesUtils._
+import org.apache.spark.internal.config._
 import org.apache.spark.scheduler.ExecutorExited
 import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._
+import org.apache.spark.util.{ManualClock, SparkExitCode}
 
 class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfter {
 
@@ -60,19 +63,64 @@ class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfte
   before {
     MockitoAnnotations.openMocks(this).close()
+    val sparkConf = new SparkConf()
+      .set(KUBERNETES_EXECUTOR_PODTEMPLATE_CONTAINER_NAME, TEST_SPARK_EXECUTOR_CONTAINER_NAME)
     snapshotsStore = new DeterministicExecutorPodsSnapshotsStore()
     namedExecutorPods = mutable.Map.empty[String, PodResource]
     when(schedulerBackend.getExecutorsWithRegistrationTs()).thenReturn(Map.empty[String, Long])
+    when(schedulerBackend.getExecutorIds()).thenReturn(Seq.empty)
     when(kubernetesClient.pods()).thenReturn(podOperations)
     when(podOperations.inNamespace(anyString())).thenReturn(podsWithNamespace)
     when(podsWithNamespace.withName(any(classOf[String]))).thenAnswer(namedPodsAnswer())
     eventHandlerUnderTest = new ExecutorPodsLifecycleManager(
-      new SparkConf(),
+      sparkConf,
       kubernetesClient,
       snapshotsStore)
     eventHandlerUnderTest.start(schedulerBackend)
   }
 
+  test("SPARK-41210: Window based executor failure tracking mechanism") {
+    var _exitCode = -1
+    val waitForExecutorPodsClock = new ManualClock(0L)
+    val _conf = eventHandlerUnderTest.conf.clone
+      .set(MAX_EXECUTOR_FAILURES.key, "2")
+      .set(EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS.key, "2s")
+    snapshotsStore = new DeterministicExecutorPodsSnapshotsStore()
+    eventHandlerUnderTest = new ExecutorPodsLifecycleManager(_conf,
+      kubernetesClient, snapshotsStore, waitForExecutorPodsClock) {
+      override private[k8s] def stopApplication(exitCode: Int): Unit = {
+        logError(s"stopApplication called with exit code $exitCode")
+        _exitCode = exitCode
+      }
+    }
+    eventHandlerUnderTest.start(schedulerBackend)
+    assert(eventHandlerUnderTest.getNumExecutorsFailed === 0)
+
+    waitForExecutorPodsClock.advance(1000)
+    snapshotsStore.updatePod(failedExecutorWithoutDeletion(1))
+    snapshotsStore.updatePod(failedExecutorWithoutDeletion(2))
+    snapshotsStore.notifySubscribers()
+    assert(eventHandlerUnderTest.getNumExecutorsFailed === 2)
+    assert(_exitCode === -1)
+
+    waitForExecutorPodsClock.advance(1000)
+    snapshotsStore.notifySubscribers()
+    assert(eventHandlerUnderTest.getNumExecutorsFailed === 2)
+    assert(_exitCode === -1)
+
+    waitForExecutorPodsClock.advance(2000)
+    assert(eventHandlerUnderTest.getNumExecutorsFailed === 0)
+    assert(_exitCode === -1)
+
+    waitForExecutorPodsClock.advance(1000)
+    snapshotsStore.updatePod(failedExecutorWithoutDeletion(3))
+    snapshotsStore.updatePod(failedExecutorWithoutDeletion(4))
+    snapshotsStore.updatePod(failedExecutorWithoutDeletion(5))
+    snapshotsStore.notifySubscribers()
+    assert(eventHandlerUnderTest.getNumExecutorsFailed === 3)
+    assert(_exitCode === SparkExitCode.EXCEED_MAX_EXECUTOR_FAILURES)
+  }
+
   test("When an executor reaches error states immediately, remove from the scheduler backend.") {
     val failedPod = failedExecutorWithoutDeletion(1)
     val mockPodResource = mock(classOf[PodResource])
@@ -162,6 +210,15 @@ class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfte
       .edit(any[UnaryOperator[Pod]]())
   }
 
+  test("SPARK-49804: Use the exit code of executor container always") {
+    val failedPod = failedExecutorWithSidecarStatusListedFirst(1)
+    snapshotsStore.updatePod(failedPod)
+    snapshotsStore.notifySubscribers()
+    val msg = exitReasonMessage(1, failedPod, 1)
+    val
expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg) + verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) + } + private def exitReasonMessage(execId: Int, failedPod: Pod, exitCode: Int): String = { val reason = Option(failedPod.getStatus.getReason) val message = Option(failedPod.getStatus.getMessage) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index ae5f037c6b7d4..950079dcb5362 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -40,7 +40,7 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => val logConfFilePath = s"${sparkHomeDir.toFile}/conf/log4j2.properties" try { - Files.write( + Files.asCharSink(new File(logConfFilePath), StandardCharsets.UTF_8).write( """rootLogger.level = info |rootLogger.appenderRef.stdout.ref = console |appender.console.type = Console @@ -51,9 +51,7 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => | |logger.spark.name = org.apache.spark |logger.spark.level = debug - """.stripMargin, - new File(logConfFilePath), - StandardCharsets.UTF_8) + """.stripMargin) f() } finally { diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 0b0b30e5e04fd..cf129677ad9c2 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -129,7 +129,7 @@ class KubernetesSuite extends SparkFunSuite val tagFile = new File(path) require(tagFile.isFile, s"No file found for image tag at ${tagFile.getAbsolutePath}.") - Files.toString(tagFile, Charsets.UTF_8).trim + Files.asCharSource(tagFile, Charsets.UTF_8).read().trim } .orElse(sys.props.get(CONFIG_KEY_IMAGE_TAG)) .getOrElse { diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 694d81b3c25e3..770a550030f51 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -156,6 +156,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + + + + 4.0.0 + + org.apache.spark + spark-parent_2.13 + 4.0.0-SNAPSHOT + ../../../pom.xml + + + spark-connect-shims_2.13 + jar + Spark Project Connect Shims + https://spark.apache.org/ + + connect-shims + + + + + org.scala-lang + scala-library + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/sql/connect/shims/src/main/scala/org/apache/spark/api/java/shims.scala b/sql/connect/shims/src/main/scala/org/apache/spark/api/java/shims.scala new file mode 100644 index 0000000000000..45fae00247485 --- /dev/null +++ b/sql/connect/shims/src/main/scala/org/apache/spark/api/java/shims.scala @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.api.java + +class JavaRDD[T] diff --git a/sql/connect/shims/src/main/scala/org/apache/spark/rdd/shims.scala b/sql/connect/shims/src/main/scala/org/apache/spark/rdd/shims.scala new file mode 100644 index 0000000000000..b23f83fa9185c --- /dev/null +++ b/sql/connect/shims/src/main/scala/org/apache/spark/rdd/shims.scala @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rdd + +class RDD[T] diff --git a/sql/connect/shims/src/main/scala/org/apache/spark/shims.scala b/sql/connect/shims/src/main/scala/org/apache/spark/shims.scala new file mode 100644 index 0000000000000..813b8e4859c28 --- /dev/null +++ b/sql/connect/shims/src/main/scala/org/apache/spark/shims.scala @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark + +class SparkContext diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 9eb5decb3b515..16236940fe072 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -73,6 +73,19 @@ test-jar test + + org.apache.spark + spark-sql-api_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-connect-shims_${scala.binary.version} + + + org.apache.spark spark-tags_${scala.binary.version} @@ -219,6 +232,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.seleniumhq.selenium selenium-java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto index 1ff90f27e173a..63728216ded1e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto @@ -46,6 +46,7 @@ message StatefulProcessorCall { message StateVariableRequest { oneof method { ValueStateCall valueStateCall = 1; + ListStateCall listStateCall = 2; } } @@ -72,6 +73,18 @@ message ValueStateCall { } } +message ListStateCall { + string stateName = 1; + oneof method { + Exists exists = 2; + ListStateGet listStateGet = 3; + ListStatePut listStatePut = 4; + AppendValue appendValue = 5; + AppendList appendList = 6; + Clear clear = 7; + } +} + message SetImplicitKey { bytes key = 1; } @@ -92,6 +105,20 @@ message ValueStateUpdate { message Clear { } +message ListStateGet { + string iteratorId = 1; +} + +message ListStatePut { +} + +message AppendValue { + bytes value = 1; +} + +message AppendList { +} + enum HandleState { CREATED = 0; INITIALIZED = 1; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java index 4fbb20be05b7b..d6d56dd732775 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java @@ -3462,6 +3462,21 @@ public interface StateVariableRequestOrBuilder extends */ org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCallOrBuilder getValueStateCallOrBuilder(); + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + * @return Whether the listStateCall field is set. + */ + boolean hasListStateCall(); + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + * @return The listStateCall. 
+ */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall getListStateCall(); + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCallOrBuilder getListStateCallOrBuilder(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest.MethodCase getMethodCase(); } /** @@ -3510,6 +3525,7 @@ public enum MethodCase implements com.google.protobuf.Internal.EnumLite, com.google.protobuf.AbstractMessage.InternalOneOfEnum { VALUESTATECALL(1), + LISTSTATECALL(2), METHOD_NOT_SET(0); private final int value; private MethodCase(int value) { @@ -3528,6 +3544,7 @@ public static MethodCase valueOf(int value) { public static MethodCase forNumber(int value) { switch (value) { case 1: return VALUESTATECALL; + case 2: return LISTSTATECALL; case 0: return METHOD_NOT_SET; default: return null; } @@ -3574,6 +3591,37 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCal return org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall.getDefaultInstance(); } + public static final int LISTSTATECALL_FIELD_NUMBER = 2; + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + * @return Whether the listStateCall field is set. + */ + @java.lang.Override + public boolean hasListStateCall() { + return methodCase_ == 2; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + * @return The listStateCall. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall getListStateCall() { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCallOrBuilder getListStateCallOrBuilder() { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance(); + } + private byte memoizedIsInitialized = -1; @java.lang.Override public final boolean isInitialized() { @@ -3591,6 +3639,9 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) if (methodCase_ == 1) { output.writeMessage(1, (org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall) method_); } + if (methodCase_ == 2) { + output.writeMessage(2, (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) method_); + } getUnknownFields().writeTo(output); } @@ -3604,6 +3655,10 @@ public int getSerializedSize() { size += com.google.protobuf.CodedOutputStream .computeMessageSize(1, (org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall) method_); } + if (methodCase_ == 2) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(2, (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) method_); + } size += getUnknownFields().getSerializedSize(); memoizedSize = size; return size; @@ -3625,6 +3680,10 @@ public boolean equals(final java.lang.Object obj) { if (!getValueStateCall() .equals(other.getValueStateCall())) return 
false; break; + case 2: + if (!getListStateCall() + .equals(other.getListStateCall())) return false; + break; case 0: default: } @@ -3644,6 +3703,10 @@ public int hashCode() { hash = (37 * hash) + VALUESTATECALL_FIELD_NUMBER; hash = (53 * hash) + getValueStateCall().hashCode(); break; + case 2: + hash = (37 * hash) + LISTSTATECALL_FIELD_NUMBER; + hash = (53 * hash) + getListStateCall().hashCode(); + break; case 0: default: } @@ -3778,6 +3841,9 @@ public Builder clear() { if (valueStateCallBuilder_ != null) { valueStateCallBuilder_.clear(); } + if (listStateCallBuilder_ != null) { + listStateCallBuilder_.clear(); + } methodCase_ = 0; method_ = null; return this; @@ -3813,6 +3879,13 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariable result.method_ = valueStateCallBuilder_.build(); } } + if (methodCase_ == 2) { + if (listStateCallBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = listStateCallBuilder_.build(); + } + } result.methodCase_ = methodCase_; onBuilt(); return result; @@ -3867,6 +3940,10 @@ public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMes mergeValueStateCall(other.getValueStateCall()); break; } + case LISTSTATECALL: { + mergeListStateCall(other.getListStateCall()); + break; + } case METHOD_NOT_SET: { break; } @@ -3904,6 +3981,13 @@ public Builder mergeFrom( methodCase_ = 1; break; } // case 10 + case 18: { + input.readMessage( + getListStateCallFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 2; + break; + } // case 18 default: { if (!super.parseUnknownField(input, extensionRegistry, tag)) { done = true; // was an endgroup tag @@ -4076,6 +4160,148 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCal onChanged();; return valueStateCallBuilder_; } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCallOrBuilder> listStateCallBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + * @return Whether the listStateCall field is set. + */ + @java.lang.Override + public boolean hasListStateCall() { + return methodCase_ == 2; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + * @return The listStateCall. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall getListStateCall() { + if (listStateCallBuilder_ == null) { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance(); + } else { + if (methodCase_ == 2) { + return listStateCallBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + */ + public Builder setListStateCall(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall value) { + if (listStateCallBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + listStateCallBuilder_.setMessage(value); + } + methodCase_ = 2; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + */ + public Builder setListStateCall( + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.Builder builderForValue) { + if (listStateCallBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + listStateCallBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 2; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + */ + public Builder mergeListStateCall(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall value) { + if (listStateCallBuilder_ == null) { + if (methodCase_ == 2 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 2) { + listStateCallBuilder_.mergeFrom(value); + } else { + listStateCallBuilder_.setMessage(value); + } + } + methodCase_ = 2; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + */ + public Builder clearListStateCall() { + if (listStateCallBuilder_ == null) { + if (methodCase_ == 2) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 2) { + methodCase_ = 0; + method_ = null; + } + listStateCallBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.Builder getListStateCallBuilder() { + return getListStateCallFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCallOrBuilder getListStateCallOrBuilder() { + if ((methodCase_ == 2) && (listStateCallBuilder_ != null)) { + return listStateCallBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) method_; + } + return 
org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCallOrBuilder> + getListStateCallFieldBuilder() { + if (listStateCallBuilder_ == null) { + if (!(methodCase_ == 2)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance(); + } + listStateCallBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCallOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 2; + onChanged();; + return listStateCallBuilder_; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -7482,37 +7708,135 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCal } - public interface SetImplicitKeyOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + public interface ListStateCallOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.ListStateCall) com.google.protobuf.MessageOrBuilder { /** - * bytes key = 1; - * @return The key. + * string stateName = 1; + * @return The stateName. */ - com.google.protobuf.ByteString getKey(); + java.lang.String getStateName(); + /** + * string stateName = 1; + * @return The bytes for stateName. + */ + com.google.protobuf.ByteString + getStateNameBytes(); + + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + * @return Whether the exists field is set. + */ + boolean hasExists(); + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + * @return The exists. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getExists(); + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder getExistsOrBuilder(); + + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + * @return Whether the listStateGet field is set. + */ + boolean hasListStateGet(); + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + * @return The listStateGet. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet getListStateGet(); + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGetOrBuilder getListStateGetOrBuilder(); + + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + * @return Whether the listStatePut field is set. 
+ */ + boolean hasListStatePut(); + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + * @return The listStatePut. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut getListStatePut(); + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePutOrBuilder getListStatePutOrBuilder(); + + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + * @return Whether the appendValue field is set. + */ + boolean hasAppendValue(); + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + * @return The appendValue. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue getAppendValue(); + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValueOrBuilder getAppendValueOrBuilder(); + + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + * @return Whether the appendList field is set. + */ + boolean hasAppendList(); + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + * @return The appendList. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList getAppendList(); + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendListOrBuilder getAppendListOrBuilder(); + + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + * @return Whether the clear field is set. + */ + boolean hasClear(); + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + * @return The clear. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getClear(); + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder getClearOrBuilder(); + + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.MethodCase getMethodCase(); } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.SetImplicitKey} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ListStateCall} */ - public static final class SetImplicitKey extends + public static final class ListStateCall extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) - SetImplicitKeyOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.ListStateCall) + ListStateCallOrBuilder { private static final long serialVersionUID = 0L; - // Use SetImplicitKey.newBuilder() to construct. - private SetImplicitKey(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use ListStateCall.newBuilder() to construct. 
+ private ListStateCall(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private SetImplicitKey() { - key_ = com.google.protobuf.ByteString.EMPTY; + private ListStateCall() { + stateName_ = ""; } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new SetImplicitKey(); + return new ListStateCall(); } @java.lang.Override @@ -7522,31 +7846,3583 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.Builder.class); } - public static final int KEY_FIELD_NUMBER = 1; - private com.google.protobuf.ByteString key_; - /** - * bytes key = 1; - * @return The key. - */ - @java.lang.Override - public com.google.protobuf.ByteString getKey() { - return key_; - } + private int methodCase_ = 0; + private java.lang.Object method_; + public enum MethodCase + implements com.google.protobuf.Internal.EnumLite, + com.google.protobuf.AbstractMessage.InternalOneOfEnum { + EXISTS(2), + LISTSTATEGET(3), + LISTSTATEPUT(4), + APPENDVALUE(5), + APPENDLIST(6), + CLEAR(7), + METHOD_NOT_SET(0); + private final int value; + private MethodCase(int value) { + this.value = value; + } + /** + * @param value The number of the enum to look for. + * @return The enum associated with the given number. + * @deprecated Use {@link #forNumber(int)} instead. + */ + @java.lang.Deprecated + public static MethodCase valueOf(int value) { + return forNumber(value); + } - private byte memoizedIsInitialized = -1; - @java.lang.Override - public final boolean isInitialized() { + public static MethodCase forNumber(int value) { + switch (value) { + case 2: return EXISTS; + case 3: return LISTSTATEGET; + case 4: return LISTSTATEPUT; + case 5: return APPENDVALUE; + case 6: return APPENDLIST; + case 7: return CLEAR; + case 0: return METHOD_NOT_SET; + default: return null; + } + } + public int getNumber() { + return this.value; + } + }; + + public MethodCase + getMethodCase() { + return MethodCase.forNumber( + methodCase_); + } + + public static final int STATENAME_FIELD_NUMBER = 1; + private volatile java.lang.Object stateName_; + /** + * string stateName = 1; + * @return The stateName. 
+ */ + @java.lang.Override + public java.lang.String getStateName() { + java.lang.Object ref = stateName_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + stateName_ = s; + return s; + } + } + /** + * string stateName = 1; + * @return The bytes for stateName. + */ + @java.lang.Override + public com.google.protobuf.ByteString + getStateNameBytes() { + java.lang.Object ref = stateName_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + stateName_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + public static final int EXISTS_FIELD_NUMBER = 2; + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + * @return Whether the exists field is set. + */ + @java.lang.Override + public boolean hasExists() { + return methodCase_ == 2; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + * @return The exists. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getExists() { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder getExistsOrBuilder() { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } + + public static final int LISTSTATEGET_FIELD_NUMBER = 3; + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + * @return Whether the listStateGet field is set. + */ + @java.lang.Override + public boolean hasListStateGet() { + return methodCase_ == 3; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + * @return The listStateGet. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet getListStateGet() { + if (methodCase_ == 3) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGetOrBuilder getListStateGetOrBuilder() { + if (methodCase_ == 3) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance(); + } + + public static final int LISTSTATEPUT_FIELD_NUMBER = 4; + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + * @return Whether the listStatePut field is set. 
+ */ + @java.lang.Override + public boolean hasListStatePut() { + return methodCase_ == 4; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + * @return The listStatePut. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut getListStatePut() { + if (methodCase_ == 4) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePutOrBuilder getListStatePutOrBuilder() { + if (methodCase_ == 4) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance(); + } + + public static final int APPENDVALUE_FIELD_NUMBER = 5; + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + * @return Whether the appendValue field is set. + */ + @java.lang.Override + public boolean hasAppendValue() { + return methodCase_ == 5; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + * @return The appendValue. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue getAppendValue() { + if (methodCase_ == 5) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValueOrBuilder getAppendValueOrBuilder() { + if (methodCase_ == 5) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance(); + } + + public static final int APPENDLIST_FIELD_NUMBER = 6; + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + * @return Whether the appendList field is set. + */ + @java.lang.Override + public boolean hasAppendList() { + return methodCase_ == 6; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + * @return The appendList. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList getAppendList() { + if (methodCase_ == 6) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendListOrBuilder getAppendListOrBuilder() { + if (methodCase_ == 6) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance(); + } + + public static final int CLEAR_FIELD_NUMBER = 7; + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + * @return Whether the clear field is set. + */ + @java.lang.Override + public boolean hasClear() { + return methodCase_ == 7; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + * @return The clear. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getClear() { + if (methodCase_ == 7) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder getClearOrBuilder() { + if (methodCase_ == 7) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(stateName_)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 1, stateName_); + } + if (methodCase_ == 2) { + output.writeMessage(2, (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_); + } + if (methodCase_ == 3) { + output.writeMessage(3, (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) method_); + } + if (methodCase_ == 4) { + output.writeMessage(4, (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) method_); + } + if (methodCase_ == 5) { + output.writeMessage(5, (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) method_); + } + if (methodCase_ == 6) { + output.writeMessage(6, (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) method_); + } + if (methodCase_ == 7) { + output.writeMessage(7, (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_); + } + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(stateName_)) { 
+ size += com.google.protobuf.GeneratedMessageV3.computeStringSize(1, stateName_); + } + if (methodCase_ == 2) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(2, (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_); + } + if (methodCase_ == 3) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(3, (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) method_); + } + if (methodCase_ == 4) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(4, (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) method_); + } + if (methodCase_ == 5) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(5, (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) method_); + } + if (methodCase_ == 6) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(6, (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) method_); + } + if (methodCase_ == 7) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(7, (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_); + } + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall other = (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) obj; + + if (!getStateName() + .equals(other.getStateName())) return false; + if (!getMethodCase().equals(other.getMethodCase())) return false; + switch (methodCase_) { + case 2: + if (!getExists() + .equals(other.getExists())) return false; + break; + case 3: + if (!getListStateGet() + .equals(other.getListStateGet())) return false; + break; + case 4: + if (!getListStatePut() + .equals(other.getListStatePut())) return false; + break; + case 5: + if (!getAppendValue() + .equals(other.getAppendValue())) return false; + break; + case 6: + if (!getAppendList() + .equals(other.getAppendList())) return false; + break; + case 7: + if (!getClear() + .equals(other.getClear())) return false; + break; + case 0: + default: + } + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + STATENAME_FIELD_NUMBER; + hash = (53 * hash) + getStateName().hashCode(); + switch (methodCase_) { + case 2: + hash = (37 * hash) + EXISTS_FIELD_NUMBER; + hash = (53 * hash) + getExists().hashCode(); + break; + case 3: + hash = (37 * hash) + LISTSTATEGET_FIELD_NUMBER; + hash = (53 * hash) + getListStateGet().hashCode(); + break; + case 4: + hash = (37 * hash) + LISTSTATEPUT_FIELD_NUMBER; + hash = (53 * hash) + getListStatePut().hashCode(); + break; + case 5: + hash = (37 * hash) + APPENDVALUE_FIELD_NUMBER; + hash = (53 * hash) + getAppendValue().hashCode(); + break; + case 6: + hash = (37 * hash) + APPENDLIST_FIELD_NUMBER; + hash = (53 * hash) + getAppendList().hashCode(); + break; + case 7: + hash = (37 * hash) + CLEAR_FIELD_NUMBER; + hash = (53 * hash) + getClear().hashCode(); + break; + case 0: + default: + } + hash 
= (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override 
+ public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ListStateCall} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.ListStateCall) + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCallOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + stateName_ = ""; + + if (existsBuilder_ != null) { + existsBuilder_.clear(); + } + if (listStateGetBuilder_ != null) { + listStateGetBuilder_.clear(); + } + if (listStatePutBuilder_ != null) { + listStatePutBuilder_.clear(); + } + if (appendValueBuilder_ != null) { + appendValueBuilder_.clear(); + } + if (appendListBuilder_ != null) { + appendListBuilder_.clear(); + } + if (clearBuilder_ != null) { + clearBuilder_.clear(); + } + methodCase_ = 0; + method_ = null; + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public 
org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall(this); + result.stateName_ = stateName_; + if (methodCase_ == 2) { + if (existsBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = existsBuilder_.build(); + } + } + if (methodCase_ == 3) { + if (listStateGetBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = listStateGetBuilder_.build(); + } + } + if (methodCase_ == 4) { + if (listStatePutBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = listStatePutBuilder_.build(); + } + } + if (methodCase_ == 5) { + if (appendValueBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = appendValueBuilder_.build(); + } + } + if (methodCase_ == 6) { + if (appendListBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = appendListBuilder_.build(); + } + } + if (methodCase_ == 7) { + if (clearBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = clearBuilder_.build(); + } + } + result.methodCase_ = methodCase_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance()) return this; + if (!other.getStateName().isEmpty()) { + stateName_ = other.stateName_; + onChanged(); + } + switch (other.getMethodCase()) { + case EXISTS: { + mergeExists(other.getExists()); + break; + } + case LISTSTATEGET: { + mergeListStateGet(other.getListStateGet()); + break; + } + case LISTSTATEPUT: { + mergeListStatePut(other.getListStatePut()); + break; + } + case APPENDVALUE: { + mergeAppendValue(other.getAppendValue()); + break; + } + case APPENDLIST: { + mergeAppendList(other.getAppendList()); + break; + } + case CLEAR: { + mergeClear(other.getClear()); + break; + } + case METHOD_NOT_SET: { + break; + } + } + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + 
@java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: { + stateName_ = input.readStringRequireUtf8(); + + break; + } // case 10 + case 18: { + input.readMessage( + getExistsFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 2; + break; + } // case 18 + case 26: { + input.readMessage( + getListStateGetFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 3; + break; + } // case 26 + case 34: { + input.readMessage( + getListStatePutFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 4; + break; + } // case 34 + case 42: { + input.readMessage( + getAppendValueFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 5; + break; + } // case 42 + case 50: { + input.readMessage( + getAppendListFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 6; + break; + } // case 50 + case 58: { + input.readMessage( + getClearFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 7; + break; + } // case 58 + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + private int methodCase_ = 0; + private java.lang.Object method_; + public MethodCase + getMethodCase() { + return MethodCase.forNumber( + methodCase_); + } + + public Builder clearMethod() { + methodCase_ = 0; + method_ = null; + onChanged(); + return this; + } + + + private java.lang.Object stateName_ = ""; + /** + * string stateName = 1; + * @return The stateName. + */ + public java.lang.String getStateName() { + java.lang.Object ref = stateName_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + stateName_ = s; + return s; + } else { + return (java.lang.String) ref; + } + } + /** + * string stateName = 1; + * @return The bytes for stateName. + */ + public com.google.protobuf.ByteString + getStateNameBytes() { + java.lang.Object ref = stateName_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + stateName_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + * string stateName = 1; + * @param value The stateName to set. + * @return This builder for chaining. + */ + public Builder setStateName( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + + stateName_ = value; + onChanged(); + return this; + } + /** + * string stateName = 1; + * @return This builder for chaining. + */ + public Builder clearStateName() { + + stateName_ = getDefaultInstance().getStateName(); + onChanged(); + return this; + } + /** + * string stateName = 1; + * @param value The bytes for stateName to set. + * @return This builder for chaining. 
+ */ + public Builder setStateNameBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + + stateName_ = value; + onChanged(); + return this; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder> existsBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + * @return Whether the exists field is set. + */ + @java.lang.Override + public boolean hasExists() { + return methodCase_ == 2; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + * @return The exists. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getExists() { + if (existsBuilder_ == null) { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } else { + if (methodCase_ == 2) { + return existsBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + public Builder setExists(org.apache.spark.sql.execution.streaming.state.StateMessage.Exists value) { + if (existsBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + existsBuilder_.setMessage(value); + } + methodCase_ = 2; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + public Builder setExists( + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder builderForValue) { + if (existsBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + existsBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 2; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + public Builder mergeExists(org.apache.spark.sql.execution.streaming.state.StateMessage.Exists value) { + if (existsBuilder_ == null) { + if (methodCase_ == 2 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 2) { + existsBuilder_.mergeFrom(value); + } else { + existsBuilder_.setMessage(value); + } + } + methodCase_ = 2; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + public Builder clearExists() { + if (existsBuilder_ == null) { + if (methodCase_ == 2) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 2) { + methodCase_ = 0; + method_ = null; + } + existsBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder getExistsBuilder() { + return getExistsFieldBuilder().getBuilder(); + } + 
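A second hypothetical sketch (again not part of the generated file), showing that the builder's "method" oneof holds at most one member at a time: a later setter displaces the earlier case, as the setter implementations in this Builder indicate. The wrapper class name OneofSwitchSketch is invented for the example.

// Hypothetical sketch: the builder keeps only the most recently set oneof member,
// so switching from exists to appendValue clears the exists case.
import org.apache.spark.sql.execution.streaming.state.StateMessage;

public final class OneofSwitchSketch {
  public static void main(String[] args) {
    StateMessage.ListStateCall.Builder builder = StateMessage.ListStateCall.newBuilder()
        .setStateName("listState")
        .setExists(StateMessage.Exists.getDefaultInstance());              // methodCase_ becomes EXISTS
    System.out.println(builder.hasExists());                               // true

    builder.setAppendValue(StateMessage.AppendValue.getDefaultInstance()); // switches the oneof
    System.out.println(builder.hasExists());                               // false: exists was displaced
    System.out.println(builder.hasAppendValue());                          // true

    builder.clearMethod();                                                 // reset the oneof
    System.out.println(builder.getMethodCase());                           // METHOD_NOT_SET
  }
}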
/** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder getExistsOrBuilder() { + if ((methodCase_ == 2) && (existsBuilder_ != null)) { + return existsBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder> + getExistsFieldBuilder() { + if (existsBuilder_ == null) { + if (!(methodCase_ == 2)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } + existsBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 2; + onChanged();; + return existsBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGetOrBuilder> listStateGetBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + * @return Whether the listStateGet field is set. + */ + @java.lang.Override + public boolean hasListStateGet() { + return methodCase_ == 3; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + * @return The listStateGet. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet getListStateGet() { + if (listStateGetBuilder_ == null) { + if (methodCase_ == 3) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance(); + } else { + if (methodCase_ == 3) { + return listStateGetBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + */ + public Builder setListStateGet(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet value) { + if (listStateGetBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + listStateGetBuilder_.setMessage(value); + } + methodCase_ = 3; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + */ + public Builder setListStateGet( + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.Builder builderForValue) { + if (listStateGetBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + listStateGetBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 3; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + */ + public Builder mergeListStateGet(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet value) { + if (listStateGetBuilder_ == null) { + if (methodCase_ == 3 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 3) { + listStateGetBuilder_.mergeFrom(value); + } else { + listStateGetBuilder_.setMessage(value); + } + } + methodCase_ = 3; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + */ + public Builder clearListStateGet() { + if (listStateGetBuilder_ == null) { + if (methodCase_ == 3) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 3) { + methodCase_ = 0; + method_ = null; + } + listStateGetBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.Builder getListStateGetBuilder() { + return getListStateGetFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGetOrBuilder getListStateGetOrBuilder() { + if ((methodCase_ == 3) && (listStateGetBuilder_ != null)) { + return listStateGetBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 3) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance(); + } + } + /** + * 
.org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGetOrBuilder> + getListStateGetFieldBuilder() { + if (listStateGetBuilder_ == null) { + if (!(methodCase_ == 3)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance(); + } + listStateGetBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGetOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 3; + onChanged();; + return listStateGetBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePutOrBuilder> listStatePutBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + * @return Whether the listStatePut field is set. + */ + @java.lang.Override + public boolean hasListStatePut() { + return methodCase_ == 4; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + * @return The listStatePut. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut getListStatePut() { + if (listStatePutBuilder_ == null) { + if (methodCase_ == 4) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance(); + } else { + if (methodCase_ == 4) { + return listStatePutBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + */ + public Builder setListStatePut(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut value) { + if (listStatePutBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + listStatePutBuilder_.setMessage(value); + } + methodCase_ = 4; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + */ + public Builder setListStatePut( + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.Builder builderForValue) { + if (listStatePutBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + listStatePutBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 4; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + */ + public Builder mergeListStatePut(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut value) { + if (listStatePutBuilder_ == null) { + if (methodCase_ == 4 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 4) { + listStatePutBuilder_.mergeFrom(value); + } else { + listStatePutBuilder_.setMessage(value); + } + } + methodCase_ = 4; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + */ + public Builder clearListStatePut() { + if (listStatePutBuilder_ == null) { + if (methodCase_ == 4) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 4) { + methodCase_ = 0; + method_ = null; + } + listStatePutBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.Builder getListStatePutBuilder() { + return getListStatePutFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePutOrBuilder getListStatePutOrBuilder() { + if ((methodCase_ == 4) && (listStatePutBuilder_ != null)) { + return listStatePutBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 4) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance(); + } + } + /** + * 
.org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePutOrBuilder> + getListStatePutFieldBuilder() { + if (listStatePutBuilder_ == null) { + if (!(methodCase_ == 4)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance(); + } + listStatePutBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePutOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 4; + onChanged();; + return listStatePutBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValueOrBuilder> appendValueBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + * @return Whether the appendValue field is set. + */ + @java.lang.Override + public boolean hasAppendValue() { + return methodCase_ == 5; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + * @return The appendValue. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue getAppendValue() { + if (appendValueBuilder_ == null) { + if (methodCase_ == 5) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance(); + } else { + if (methodCase_ == 5) { + return appendValueBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + */ + public Builder setAppendValue(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue value) { + if (appendValueBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + appendValueBuilder_.setMessage(value); + } + methodCase_ = 5; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + */ + public Builder setAppendValue( + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.Builder builderForValue) { + if (appendValueBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + appendValueBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 5; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + */ + public Builder mergeAppendValue(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue value) { + if (appendValueBuilder_ == null) { + if (methodCase_ == 5 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 5) { + appendValueBuilder_.mergeFrom(value); + } else { + appendValueBuilder_.setMessage(value); + } + } + methodCase_ = 5; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + */ + public Builder clearAppendValue() { + if (appendValueBuilder_ == null) { + if (methodCase_ == 5) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 5) { + methodCase_ = 0; + method_ = null; + } + appendValueBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.Builder getAppendValueBuilder() { + return getAppendValueFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValueOrBuilder getAppendValueOrBuilder() { + if ((methodCase_ == 5) && (appendValueBuilder_ != null)) { + return appendValueBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 5) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance(); + } + } + /** + * 
.org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValueOrBuilder> + getAppendValueFieldBuilder() { + if (appendValueBuilder_ == null) { + if (!(methodCase_ == 5)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance(); + } + appendValueBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValueOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 5; + onChanged();; + return appendValueBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendListOrBuilder> appendListBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + * @return Whether the appendList field is set. + */ + @java.lang.Override + public boolean hasAppendList() { + return methodCase_ == 6; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + * @return The appendList. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList getAppendList() { + if (appendListBuilder_ == null) { + if (methodCase_ == 6) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance(); + } else { + if (methodCase_ == 6) { + return appendListBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + */ + public Builder setAppendList(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList value) { + if (appendListBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + appendListBuilder_.setMessage(value); + } + methodCase_ = 6; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + */ + public Builder setAppendList( + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.Builder builderForValue) { + if (appendListBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + appendListBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 6; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + */ + public Builder mergeAppendList(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList value) { + if (appendListBuilder_ == null) { + if (methodCase_ == 6 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance()) { + 
method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 6) { + appendListBuilder_.mergeFrom(value); + } else { + appendListBuilder_.setMessage(value); + } + } + methodCase_ = 6; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + */ + public Builder clearAppendList() { + if (appendListBuilder_ == null) { + if (methodCase_ == 6) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 6) { + methodCase_ = 0; + method_ = null; + } + appendListBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.Builder getAppendListBuilder() { + return getAppendListFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendListOrBuilder getAppendListOrBuilder() { + if ((methodCase_ == 6) && (appendListBuilder_ != null)) { + return appendListBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 6) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendListOrBuilder> + getAppendListFieldBuilder() { + if (appendListBuilder_ == null) { + if (!(methodCase_ == 6)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance(); + } + appendListBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendListOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 6; + onChanged();; + return appendListBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder> clearBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + * @return Whether the clear field is set. + */ + @java.lang.Override + public boolean hasClear() { + return methodCase_ == 7; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + * @return The clear. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getClear() { + if (clearBuilder_ == null) { + if (methodCase_ == 7) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } else { + if (methodCase_ == 7) { + return clearBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + */ + public Builder setClear(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear value) { + if (clearBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + clearBuilder_.setMessage(value); + } + methodCase_ = 7; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + */ + public Builder setClear( + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder builderForValue) { + if (clearBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + clearBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 7; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + */ + public Builder mergeClear(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear value) { + if (clearBuilder_ == null) { + if (methodCase_ == 7 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 7) { + clearBuilder_.mergeFrom(value); + } else { + clearBuilder_.setMessage(value); + } + } + methodCase_ = 7; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + */ + public Builder clearClear() { + if (clearBuilder_ == null) { + if (methodCase_ == 7) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 7) { + methodCase_ = 0; + method_ = null; + } + clearBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder getClearBuilder() { + return getClearFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder getClearOrBuilder() { + if ((methodCase_ == 7) && (clearBuilder_ != null)) { + return clearBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 7) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder, 
org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder> + getClearFieldBuilder() { + if (clearBuilder_ == null) { + if (!(methodCase_ == 7)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } + clearBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 7; + onChanged();; + return clearBuilder_; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ListStateCall) + } + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.ListStateCall) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public ListStateCall parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = newBuilder(); + try { + builder.mergeFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface SetImplicitKeyOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + com.google.protobuf.MessageOrBuilder { + + /** + * bytes key = 1; + * @return The key. 
+ */ + com.google.protobuf.ByteString getKey(); + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.SetImplicitKey} + */ + public static final class SetImplicitKey extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + SetImplicitKeyOrBuilder { + private static final long serialVersionUID = 0L; + // Use SetImplicitKey.newBuilder() to construct. + private SetImplicitKey(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private SetImplicitKey() { + key_ = com.google.protobuf.ByteString.EMPTY; + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new SetImplicitKey(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.Builder.class); + } + + public static final int KEY_FIELD_NUMBER = 1; + private com.google.protobuf.ByteString key_; + /** + * bytes key = 1; + * @return The key. 
+ */ + @java.lang.Override + public com.google.protobuf.ByteString getKey() { + return key_; + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (!key_.isEmpty()) { + output.writeBytes(1, key_); + } + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (!key_.isEmpty()) { + size += com.google.protobuf.CodedOutputStream + .computeBytesSize(1, key_); + } + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey other = (org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey) obj; + + if (!getKey() + .equals(other.getKey())) return false; + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + KEY_FIELD_NUMBER; + hash = (53 * hash) + getKey().hashCode(); + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static 
org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? 
new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.SetImplicitKey} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKeyOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + key_ = com.google.protobuf.ByteString.EMPTY; + + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey result = new org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey(this); + result.key_ = key_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override 
+ public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.getDefaultInstance()) return this; + if (other.getKey() != com.google.protobuf.ByteString.EMPTY) { + setKey(other.getKey()); + } + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: { + key_ = input.readBytes(); + + break; + } // case 10 + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + + private com.google.protobuf.ByteString key_ = com.google.protobuf.ByteString.EMPTY; + /** + * bytes key = 1; + * @return The key. + */ + @java.lang.Override + public com.google.protobuf.ByteString getKey() { + return key_; + } + /** + * bytes key = 1; + * @param value The key to set. + * @return This builder for chaining. + */ + public Builder setKey(com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + + key_ = value; + onChanged(); + return this; + } + /** + * bytes key = 1; + * @return This builder for chaining. 
+ */ + public Builder clearKey() { + + key_ = getDefaultInstance().getKey(); + onChanged(); + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + } + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public SetImplicitKey parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = newBuilder(); + try { + builder.mergeFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface RemoveImplicitKeyOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + com.google.protobuf.MessageOrBuilder { + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey} + */ + public static final class RemoveImplicitKey extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + RemoveImplicitKeyOrBuilder { + private static final long serialVersionUID = 0L; + // Use RemoveImplicitKey.newBuilder() to construct. 
+ private RemoveImplicitKey(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private RemoveImplicitKey() { + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new RemoveImplicitKey(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.Builder.class); + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey other = (org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey) obj; + + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static 
org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? 
new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKeyOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey result = new org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey(this); + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder 
setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.getDefaultInstance()) return this; + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + } + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public RemoveImplicitKey parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = newBuilder(); + try { + builder.mergeFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw 
e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface ExistsOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Exists) + com.google.protobuf.MessageOrBuilder { + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Exists} + */ + public static final class Exists extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Exists) + ExistsOrBuilder { + private static final long serialVersionUID = 0L; + // Use Exists.newBuilder() to construct. + private Exists(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private Exists() { + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new Exists(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder.class); + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Exists)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists other = (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) obj; + + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + 
@java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return 
com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Exists prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Exists} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Exists) + org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Exists(this); + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return 
super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Exists)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Exists other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance()) return this; + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Exists) + } + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Exists) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Exists DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Exists(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public Exists parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = newBuilder(); + try { + builder.mergeFrom(input, 
extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface GetOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Get) + com.google.protobuf.MessageOrBuilder { + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Get} + */ + public static final class Get extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Get) + GetOrBuilder { + private static final long serialVersionUID = 0L; + // Use Get.newBuilder() to construct. + private Get(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private Get() { + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new Get(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.Get.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Get.Builder.class); + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Get)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.Get other = 
(org.apache.spark.sql.execution.streaming.state.StateMessage.Get) obj; + + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + com.google.protobuf.CodedInputStream 
input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Get prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Get} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Get) + org.apache.spark.sql.execution.streaming.state.StateMessage.GetOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.Get.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Get.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Get.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Get getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.Get.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Get build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Get result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Get buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Get result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Get(this); + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + 
com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Get) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Get)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Get other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Get.getDefaultInstance()) return this; + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Get) + } + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Get) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Get DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Get(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public Get parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = 
newBuilder(); + try { + builder.mergeFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Get getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface ValueStateUpdateOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + com.google.protobuf.MessageOrBuilder { + + /** + * bytes value = 1; + * @return The value. + */ + com.google.protobuf.ByteString getValue(); + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ValueStateUpdate} + */ + public static final class ValueStateUpdate extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + ValueStateUpdateOrBuilder { + private static final long serialVersionUID = 0L; + // Use ValueStateUpdate.newBuilder() to construct. + private ValueStateUpdate(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private ValueStateUpdate() { + value_ = com.google.protobuf.ByteString.EMPTY; + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new ValueStateUpdate(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.Builder.class); + } + + public static final int VALUE_FIELD_NUMBER = 1; + private com.google.protobuf.ByteString value_; + /** + * bytes value = 1; + * @return The value. 
+ */ + @java.lang.Override + public com.google.protobuf.ByteString getValue() { + return value_; + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { byte isInitialized = memoizedIsInitialized; if (isInitialized == 1) return true; if (isInitialized == 0) return false; @@ -7558,8 +11434,8 @@ public final boolean isInitialized() { @java.lang.Override public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { - if (!key_.isEmpty()) { - output.writeBytes(1, key_); + if (!value_.isEmpty()) { + output.writeBytes(1, value_); } getUnknownFields().writeTo(output); } @@ -7570,9 +11446,9 @@ public int getSerializedSize() { if (size != -1) return size; size = 0; - if (!key_.isEmpty()) { + if (!value_.isEmpty()) { size += com.google.protobuf.CodedOutputStream - .computeBytesSize(1, key_); + .computeBytesSize(1, value_); } size += getUnknownFields().getSerializedSize(); memoizedSize = size; @@ -7584,13 +11460,13 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey other = (org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate other = (org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate) obj; - if (!getKey() - .equals(other.getKey())) return false; + if (!getValue() + .equals(other.getValue())) return false; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; } @@ -7602,76 +11478,76 @@ public int hashCode() { } int hash = 41; hash = (19 * hash) + getDescriptor().hashCode(); - hash = (37 * hash) + KEY_FIELD_NUMBER; - hash = (53 * hash) + getKey().hashCode(); + hash = (37 * hash) + VALUE_FIELD_NUMBER; + hash = (53 * hash) + getValue().hashCode(); hash = (29 * hash) + getUnknownFields().hashCode(); memoizedHashCode = hash; return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( 
com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -7684,7 +11560,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImp public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder 
newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -7700,26 +11576,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.SetImplicitKey} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ValueStateUpdate} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) - org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKeyOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdateOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.newBuilder() private Builder() { } @@ -7732,7 +11608,7 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); - key_ = com.google.protobuf.ByteString.EMPTY; + value_ = com.google.protobuf.ByteString.EMPTY; return this; } @@ -7740,17 +11616,17 @@ public Builder clear() { @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.getDefaultInstance(); + public 
org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -7758,9 +11634,9 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKe } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey result = new org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey(this); - result.key_ = key_; + public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate(this); + result.value_ = value_; onBuilt(); return result; } @@ -7799,18 +11675,18 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.getDefaultInstance()) return this; - if (other.getKey() != com.google.protobuf.ByteString.EMPTY) { - setKey(other.getKey()); + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.getDefaultInstance()) return this; + if (other.getValue() != com.google.protobuf.ByteString.EMPTY) { + setValue(other.getValue()); } this.mergeUnknownFields(other.getUnknownFields()); onChanged(); @@ -7839,7 +11715,7 @@ public Builder mergeFrom( done = true; break; case 10: { - key_ = input.readBytes(); + value_ = input.readBytes(); break; } // case 10 @@ -7859,36 +11735,36 @@ public Builder mergeFrom( return this; } - private com.google.protobuf.ByteString key_ = com.google.protobuf.ByteString.EMPTY; + private com.google.protobuf.ByteString value_ = com.google.protobuf.ByteString.EMPTY; /** - * bytes key = 1; - * @return The key. + * bytes value = 1; + * @return The value. */ @java.lang.Override - public com.google.protobuf.ByteString getKey() { - return key_; + public com.google.protobuf.ByteString getValue() { + return value_; } /** - * bytes key = 1; - * @param value The key to set. 
+ * bytes value = 1; + * @param value The value to set. * @return This builder for chaining. */ - public Builder setKey(com.google.protobuf.ByteString value) { + public Builder setValue(com.google.protobuf.ByteString value) { if (value == null) { throw new NullPointerException(); } - key_ = value; + value_ = value; onChanged(); return this; } /** - * bytes key = 1; + * bytes value = 1; * @return This builder for chaining. */ - public Builder clearKey() { + public Builder clearValue() { - key_ = getDefaultInstance().getKey(); + value_ = getDefaultInstance().getValue(); onChanged(); return this; } @@ -7905,23 +11781,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public SetImplicitKey parsePartialFrom( + public ValueStateUpdate parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -7940,46 +11816,46 @@ public SetImplicitKey parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface RemoveImplicitKeyOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + public interface ClearOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Clear) com.google.protobuf.MessageOrBuilder { } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Clear} */ - public static final class RemoveImplicitKey extends + public static final class Clear extends 
com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) - RemoveImplicitKeyOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Clear) + ClearOrBuilder { private static final long serialVersionUID = 0L; - // Use RemoveImplicitKey.newBuilder() to construct. - private RemoveImplicitKey(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use Clear.newBuilder() to construct. + private Clear(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private RemoveImplicitKey() { + private Clear() { } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new RemoveImplicitKey(); + return new Clear(); } @java.lang.Override @@ -7989,15 +11865,15 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder.class); } private byte memoizedIsInitialized = -1; @@ -8033,10 +11909,10 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Clear)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey other = (org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear other = (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) obj; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; @@ -8054,69 +11930,69 @@ public int hashCode() { return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( 
java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return 
com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -8129,7 +12005,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Remove public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -8145,26 +12021,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Clear} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) - org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKeyOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Clear) + org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.newBuilder() private Builder() { } @@ -8183,17 +12059,17 @@ public Builder clear() { @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; + return 
org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -8201,8 +12077,8 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplici } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey result = new org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey(this); + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Clear(this); onBuilt(); return result; } @@ -8241,16 +12117,16 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Clear)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.getDefaultInstance()) return this; + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance()) return this; this.mergeUnknownFields(other.getUnknownFields()); onChanged(); return this; @@ -8305,23 +12181,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Clear) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Clear) + private static final 
org.apache.spark.sql.execution.streaming.state.StateMessage.Clear DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Clear(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public RemoveImplicitKey parsePartialFrom( + public Clear parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -8340,46 +12216,59 @@ public RemoveImplicitKey parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface ExistsOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Exists) + public interface ListStateGetOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.ListStateGet) com.google.protobuf.MessageOrBuilder { + + /** + * string iteratorId = 1; + * @return The iteratorId. + */ + java.lang.String getIteratorId(); + /** + * string iteratorId = 1; + * @return The bytes for iteratorId. + */ + com.google.protobuf.ByteString + getIteratorIdBytes(); } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Exists} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ListStateGet} */ - public static final class Exists extends + public static final class ListStateGet extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Exists) - ExistsOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.ListStateGet) + ListStateGetOrBuilder { private static final long serialVersionUID = 0L; - // Use Exists.newBuilder() to construct. - private Exists(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use ListStateGet.newBuilder() to construct. 
+ private ListStateGet(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private Exists() { + private ListStateGet() { + iteratorId_ = ""; } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new Exists(); + return new ListStateGet(); } @java.lang.Override @@ -8389,15 +12278,53 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.Builder.class); + } + + public static final int ITERATORID_FIELD_NUMBER = 1; + private volatile java.lang.Object iteratorId_; + /** + * string iteratorId = 1; + * @return The iteratorId. + */ + @java.lang.Override + public java.lang.String getIteratorId() { + java.lang.Object ref = iteratorId_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + iteratorId_ = s; + return s; + } + } + /** + * string iteratorId = 1; + * @return The bytes for iteratorId. 
+ */ + @java.lang.Override + public com.google.protobuf.ByteString + getIteratorIdBytes() { + java.lang.Object ref = iteratorId_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + iteratorId_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } } private byte memoizedIsInitialized = -1; @@ -8414,6 +12341,9 @@ public final boolean isInitialized() { @java.lang.Override public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(iteratorId_)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 1, iteratorId_); + } getUnknownFields().writeTo(output); } @@ -8423,6 +12353,9 @@ public int getSerializedSize() { if (size != -1) return size; size = 0; + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(iteratorId_)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(1, iteratorId_); + } size += getUnknownFields().getSerializedSize(); memoizedSize = size; return size; @@ -8433,11 +12366,13 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Exists)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.Exists other = (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet other = (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) obj; + if (!getIteratorId() + .equals(other.getIteratorId())) return false; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; } @@ -8449,74 +12384,76 @@ public int hashCode() { } int hash = 41; hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + ITERATORID_FIELD_NUMBER; + hash = (53 * hash) + getIteratorId().hashCode(); hash = (29 * hash) + getUnknownFields().hashCode(); memoizedHashCode = hash; return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite 
extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -8529,7 +12466,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Exists prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet prototype) { return 
DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -8545,26 +12482,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Exists} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ListStateGet} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Exists) - org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.ListStateGet) + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGetOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.newBuilder() private Builder() { } @@ -8577,23 +12514,25 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); + iteratorId_ = ""; + return this; } @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.Exists result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet build() { + 
org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -8601,8 +12540,9 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists build( } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.Exists result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Exists(this); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet(this); + result.iteratorId_ = iteratorId_; onBuilt(); return result; } @@ -8641,16 +12581,20 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Exists)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Exists other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance()) return this; + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance()) return this; + if (!other.getIteratorId().isEmpty()) { + iteratorId_ = other.iteratorId_; + onChanged(); + } this.mergeUnknownFields(other.getUnknownFields()); onChanged(); return this; @@ -8677,6 +12621,11 @@ public Builder mergeFrom( case 0: done = true; break; + case 10: { + iteratorId_ = input.readStringRequireUtf8(); + + break; + } // case 10 default: { if (!super.parseUnknownField(input, extensionRegistry, tag)) { done = true; // was an endgroup tag @@ -8692,6 +12641,82 @@ public Builder mergeFrom( } // finally return this; } + + private java.lang.Object iteratorId_ = ""; + /** + * string iteratorId = 1; + * @return The iteratorId. + */ + public java.lang.String getIteratorId() { + java.lang.Object ref = iteratorId_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + iteratorId_ = s; + return s; + } else { + return (java.lang.String) ref; + } + } + /** + * string iteratorId = 1; + * @return The bytes for iteratorId. + */ + public com.google.protobuf.ByteString + getIteratorIdBytes() { + java.lang.Object ref = iteratorId_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + iteratorId_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + * string iteratorId = 1; + * @param value The iteratorId to set. + * @return This builder for chaining. 
+ */ + public Builder setIteratorId( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + + iteratorId_ = value; + onChanged(); + return this; + } + /** + * string iteratorId = 1; + * @return This builder for chaining. + */ + public Builder clearIteratorId() { + + iteratorId_ = getDefaultInstance().getIteratorId(); + onChanged(); + return this; + } + /** + * string iteratorId = 1; + * @param value The bytes for iteratorId to set. + * @return This builder for chaining. + */ + public Builder setIteratorIdBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + + iteratorId_ = value; + onChanged(); + return this; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -8705,23 +12730,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Exists) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ListStateGet) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Exists) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Exists DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.ListStateGet) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Exists(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public Exists parsePartialFrom( + public ListStateGet parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -8740,46 +12765,46 @@ public Exists parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface GetOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Get) + public interface ListStatePutOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.ListStatePut) com.google.protobuf.MessageOrBuilder { } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Get} + * Protobuf type {@code 
org.apache.spark.sql.execution.streaming.state.ListStatePut} */ - public static final class Get extends + public static final class ListStatePut extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Get) - GetOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.ListStatePut) + ListStatePutOrBuilder { private static final long serialVersionUID = 0L; - // Use Get.newBuilder() to construct. - private Get(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use ListStatePut.newBuilder() to construct. + private ListStatePut(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private Get() { + private ListStatePut() { } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new Get(); + return new ListStatePut(); } @java.lang.Override @@ -8789,15 +12814,15 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.Get.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Get.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.Builder.class); } private byte memoizedIsInitialized = -1; @@ -8833,10 +12858,10 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Get)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.Get other = (org.apache.spark.sql.execution.streaming.state.StateMessage.Get) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut other = (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) obj; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; @@ -8854,69 +12879,69 @@ public int hashCode() { return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static 
org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { 
return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -8929,7 +12954,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get pa public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Get prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -8945,26 +12970,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Get} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ListStatePut} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Get) - org.apache.spark.sql.execution.streaming.state.StateMessage.GetOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.ListStatePut) + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePutOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.Get.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Get.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Get.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.newBuilder() private Builder() { } @@ -8983,17 +13008,17 @@ public Builder clear() { @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor; } @java.lang.Override - public 
org.apache.spark.sql.execution.streaming.state.StateMessage.Get getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.Get.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Get build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.Get result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -9001,8 +13026,8 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.Get build() { } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Get buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.Get result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Get(this); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut(this); onBuilt(); return result; } @@ -9041,16 +13066,16 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Get) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Get)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Get other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Get.getDefaultInstance()) return this; + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance()) return this; this.mergeUnknownFields(other.getUnknownFields()); onChanged(); return this; @@ -9105,23 +13130,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Get) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ListStatePut) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Get) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Get DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.ListStatePut) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Get(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut(); } - public static 
org.apache.spark.sql.execution.streaming.state.StateMessage.Get getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public Get parsePartialFrom( + public ListStatePut parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -9140,24 +13165,24 @@ public Get parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Get getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface ValueStateUpdateOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + public interface AppendValueOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.AppendValue) com.google.protobuf.MessageOrBuilder { /** @@ -9167,18 +13192,18 @@ public interface ValueStateUpdateOrBuilder extends com.google.protobuf.ByteString getValue(); } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ValueStateUpdate} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.AppendValue} */ - public static final class ValueStateUpdate extends + public static final class AppendValue extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) - ValueStateUpdateOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.AppendValue) + AppendValueOrBuilder { private static final long serialVersionUID = 0L; - // Use ValueStateUpdate.newBuilder() to construct. - private ValueStateUpdate(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use AppendValue.newBuilder() to construct. 
+ private AppendValue(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private ValueStateUpdate() { + private AppendValue() { value_ = com.google.protobuf.ByteString.EMPTY; } @@ -9186,7 +13211,7 @@ private ValueStateUpdate() { @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new ValueStateUpdate(); + return new AppendValue(); } @java.lang.Override @@ -9196,15 +13221,15 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.class, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.Builder.class); } public static final int VALUE_FIELD_NUMBER = 1; @@ -9258,10 +13283,10 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate other = (org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue other = (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) obj; if (!getValue() .equals(other.getValue())) return false; @@ -9283,69 +13308,69 @@ public int hashCode() { return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( com.google.protobuf.ByteString data) 
throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws 
java.io.IOException { @@ -9358,7 +13383,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueS public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -9374,26 +13399,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ValueStateUpdate} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.AppendValue} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdateOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.AppendValue) + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValueOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.class, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.newBuilder() private Builder() { } @@ -9414,17 +13439,17 @@ public Builder clear() { @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue 
getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -9432,8 +13457,8 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpd } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate(this); + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue result = new org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue(this); result.value_ = value_; onBuilt(); return result; @@ -9473,16 +13498,16 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.getDefaultInstance()) return this; + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance()) return this; if (other.getValue() != com.google.protobuf.ByteString.EMPTY) { setValue(other.getValue()); } @@ -9579,23 +13604,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.AppendValue) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.AppendValue) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate 
getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public ValueStateUpdate parsePartialFrom( + public AppendValue parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -9614,46 +13639,46 @@ public ValueStateUpdate parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface ClearOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Clear) + public interface AppendListOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.AppendList) com.google.protobuf.MessageOrBuilder { } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Clear} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.AppendList} */ - public static final class Clear extends + public static final class AppendList extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Clear) - ClearOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.AppendList) + AppendListOrBuilder { private static final long serialVersionUID = 0L; - // Use Clear.newBuilder() to construct. - private Clear(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use AppendList.newBuilder() to construct. 
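The hunks above add AppendValue (a single bytes value) and AppendList alongside ListStatePut in the generated StateMessage classes. A minimal sketch of building and round-tripping an AppendValue request with these generated APIs; the object name and payload below are illustrative only:

import com.google.protobuf.ByteString
import org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue

object AppendValueRoundTrip {
  def main(args: Array[String]): Unit = {
    // Build an AppendValue carrying serialized row bytes (the payload here is a placeholder).
    val msg = AppendValue.newBuilder()
      .setValue(ByteString.copyFromUtf8("serialized-row-bytes"))
      .build()
    // Parse it back through the generated parser, as the receiving side would.
    val parsed = AppendValue.parseFrom(msg.toByteArray)
    assert(parsed.getValue == msg.getValue)
  }
}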
+ private AppendList(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private Clear() { + private AppendList() { } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new Clear(); + return new AppendList(); } @java.lang.Override @@ -9663,15 +13688,15 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.class, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.Builder.class); } private byte memoizedIsInitialized = -1; @@ -9707,10 +13732,10 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Clear)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.Clear other = (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList other = (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) obj; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; @@ -9728,69 +13753,69 @@ public int hashCode() { return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static 
org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -9803,7 +13828,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear prototype) { + public 
static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -9819,26 +13844,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Clear} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.AppendList} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Clear) - org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.AppendList) + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendListOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.class, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.newBuilder() private Builder() { } @@ -9857,17 +13882,17 @@ public Builder clear() { @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.Clear result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList build() { + 
org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -9875,8 +13900,8 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear build() } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.Clear result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Clear(this); + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList result = new org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList(this); onBuilt(); return result; } @@ -9915,16 +13940,16 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Clear)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance()) return this; + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance()) return this; this.mergeUnknownFields(other.getUnknownFields()); onChanged(); return this; @@ -9979,23 +14004,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Clear) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.AppendList) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Clear) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Clear DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.AppendList) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Clear(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public Clear parsePartialFrom( + public AppendList parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -10014,17 +14039,17 
@@ public Clear parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList getDefaultInstanceForType() { return DEFAULT_INSTANCE; } @@ -11041,6 +15066,11 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig get private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateCall_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_fieldAccessorTable; private static final com.google.protobuf.Descriptors.Descriptor internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; private static final @@ -11071,6 +15101,26 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig get private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internal_static_org_apache_spark_sql_execution_streaming_state_Clear_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_fieldAccessorTable; private static final com.google.protobuf.Descriptors.Descriptor internal_static_org_apache_spark_sql_execution_streaming_state_SetHandleState_descriptor; private static final @@ -11112,36 +15162,54 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig get "xecution.streaming.state.StateCallComman" + "dH\000\022W\n\013getMapState\030\004 \001(\0132@.org.apache.sp" + "ark.sql.execution.streaming.state.StateC" + - "allCommandH\000B\010\n\006method\"z\n\024StateVariableR" + - "equest\022X\n\016valueStateCall\030\001 \001(\0132>.org.apa" 
+ - "che.spark.sql.execution.streaming.state." + - "ValueStateCallH\000B\010\n\006method\"\340\001\n\032ImplicitG" + - "roupingKeyRequest\022X\n\016setImplicitKey\030\001 \001(" + - "\0132>.org.apache.spark.sql.execution.strea" + - "ming.state.SetImplicitKeyH\000\022^\n\021removeImp" + - "licitKey\030\002 \001(\0132A.org.apache.spark.sql.ex" + - "ecution.streaming.state.RemoveImplicitKe" + - "yH\000B\010\n\006method\"}\n\020StateCallCommand\022\021\n\tsta" + - "teName\030\001 \001(\t\022\016\n\006schema\030\002 \001(\t\022F\n\003ttl\030\003 \001(" + - "\01329.org.apache.spark.sql.execution.strea" + - "ming.state.TTLConfig\"\341\002\n\016ValueStateCall\022" + - "\021\n\tstateName\030\001 \001(\t\022H\n\006exists\030\002 \001(\01326.org" + + "allCommandH\000B\010\n\006method\"\322\001\n\024StateVariable" + + "Request\022X\n\016valueStateCall\030\001 \001(\0132>.org.ap" + + "ache.spark.sql.execution.streaming.state" + + ".ValueStateCallH\000\022V\n\rlistStateCall\030\002 \001(\013" + + "2=.org.apache.spark.sql.execution.stream" + + "ing.state.ListStateCallH\000B\010\n\006method\"\340\001\n\032" + + "ImplicitGroupingKeyRequest\022X\n\016setImplici" + + "tKey\030\001 \001(\0132>.org.apache.spark.sql.execut" + + "ion.streaming.state.SetImplicitKeyH\000\022^\n\021" + + "removeImplicitKey\030\002 \001(\0132A.org.apache.spa" + + "rk.sql.execution.streaming.state.RemoveI" + + "mplicitKeyH\000B\010\n\006method\"}\n\020StateCallComma" + + "nd\022\021\n\tstateName\030\001 \001(\t\022\016\n\006schema\030\002 \001(\t\022F\n" + + "\003ttl\030\003 \001(\01329.org.apache.spark.sql.execut" + + "ion.streaming.state.TTLConfig\"\341\002\n\016ValueS" + + "tateCall\022\021\n\tstateName\030\001 \001(\t\022H\n\006exists\030\002 " + + "\001(\01326.org.apache.spark.sql.execution.str" + + "eaming.state.ExistsH\000\022B\n\003get\030\003 \001(\01323.org" + ".apache.spark.sql.execution.streaming.st" + - "ate.ExistsH\000\022B\n\003get\030\003 \001(\01323.org.apache.s" + - "park.sql.execution.streaming.state.GetH\000" + - "\022\\\n\020valueStateUpdate\030\004 \001(\0132@.org.apache." 
+ - "spark.sql.execution.streaming.state.Valu" + - "eStateUpdateH\000\022F\n\005clear\030\005 \001(\01325.org.apac" + - "he.spark.sql.execution.streaming.state.C" + - "learH\000B\010\n\006method\"\035\n\016SetImplicitKey\022\013\n\003ke" + - "y\030\001 \001(\014\"\023\n\021RemoveImplicitKey\"\010\n\006Exists\"\005" + - "\n\003Get\"!\n\020ValueStateUpdate\022\r\n\005value\030\001 \001(\014" + - "\"\007\n\005Clear\"\\\n\016SetHandleState\022J\n\005state\030\001 \001" + - "(\0162;.org.apache.spark.sql.execution.stre" + - "aming.state.HandleState\"\037\n\tTTLConfig\022\022\n\n" + - "durationMs\030\001 \001(\005*K\n\013HandleState\022\013\n\007CREAT" + - "ED\020\000\022\017\n\013INITIALIZED\020\001\022\022\n\016DATA_PROCESSED\020" + - "\002\022\n\n\006CLOSED\020\003b\006proto3" + "ate.GetH\000\022\\\n\020valueStateUpdate\030\004 \001(\0132@.or" + + "g.apache.spark.sql.execution.streaming.s" + + "tate.ValueStateUpdateH\000\022F\n\005clear\030\005 \001(\01325" + + ".org.apache.spark.sql.execution.streamin" + + "g.state.ClearH\000B\010\n\006method\"\220\004\n\rListStateC" + + "all\022\021\n\tstateName\030\001 \001(\t\022H\n\006exists\030\002 \001(\01326" + + ".org.apache.spark.sql.execution.streamin" + + "g.state.ExistsH\000\022T\n\014listStateGet\030\003 \001(\0132<" + + ".org.apache.spark.sql.execution.streamin" + + "g.state.ListStateGetH\000\022T\n\014listStatePut\030\004" + + " \001(\0132<.org.apache.spark.sql.execution.st" + + "reaming.state.ListStatePutH\000\022R\n\013appendVa" + + "lue\030\005 \001(\0132;.org.apache.spark.sql.executi" + + "on.streaming.state.AppendValueH\000\022P\n\nappe" + + "ndList\030\006 \001(\0132:.org.apache.spark.sql.exec" + + "ution.streaming.state.AppendListH\000\022F\n\005cl" + + "ear\030\007 \001(\01325.org.apache.spark.sql.executi" + + "on.streaming.state.ClearH\000B\010\n\006method\"\035\n\016" + + "SetImplicitKey\022\013\n\003key\030\001 \001(\014\"\023\n\021RemoveImp" + + "licitKey\"\010\n\006Exists\"\005\n\003Get\"!\n\020ValueStateU" + + "pdate\022\r\n\005value\030\001 \001(\014\"\007\n\005Clear\"\"\n\014ListSta" + + "teGet\022\022\n\niteratorId\030\001 \001(\t\"\016\n\014ListStatePu" + + "t\"\034\n\013AppendValue\022\r\n\005value\030\001 \001(\014\"\014\n\nAppen" + + "dList\"\\\n\016SetHandleState\022J\n\005state\030\001 \001(\0162;" + + ".org.apache.spark.sql.execution.streamin" + + "g.state.HandleState\"\037\n\tTTLConfig\022\022\n\ndura" + + "tionMs\030\001 \001(\005*K\n\013HandleState\022\013\n\007CREATED\020\000" + + "\022\017\n\013INITIALIZED\020\001\022\022\n\016DATA_PROCESSED\020\002\022\n\n" + + "\006CLOSED\020\003b\006proto3" }; descriptor = com.google.protobuf.Descriptors.FileDescriptor .internalBuildGeneratedFileFrom(descriptorData, @@ -11170,7 +15238,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig get internal_static_org_apache_spark_sql_execution_streaming_state_StateVariableRequest_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_StateVariableRequest_descriptor, - new java.lang.String[] { "ValueStateCall", "Method", }); + new java.lang.String[] { "ValueStateCall", "ListStateCall", "Method", }); internal_static_org_apache_spark_sql_execution_streaming_state_ImplicitGroupingKeyRequest_descriptor = getDescriptor().getMessageTypes().get(4); internal_static_org_apache_spark_sql_execution_streaming_state_ImplicitGroupingKeyRequest_fieldAccessorTable = new @@ -11189,50 +15257,80 @@ public 
org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig get com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateCall_descriptor, new java.lang.String[] { "StateName", "Exists", "Get", "ValueStateUpdate", "Clear", "Method", }); - internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor = + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_descriptor = getDescriptor().getMessageTypes().get(7); + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_descriptor, + new java.lang.String[] { "StateName", "Exists", "ListStateGet", "ListStatePut", "AppendValue", "AppendList", "Clear", "Method", }); + internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor = + getDescriptor().getMessageTypes().get(8); internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor, new java.lang.String[] { "Key", }); internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor = - getDescriptor().getMessageTypes().get(8); + getDescriptor().getMessageTypes().get(9); internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor, new java.lang.String[] { }); internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor = - getDescriptor().getMessageTypes().get(9); + getDescriptor().getMessageTypes().get(10); internal_static_org_apache_spark_sql_execution_streaming_state_Exists_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor, new java.lang.String[] { }); internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor = - getDescriptor().getMessageTypes().get(10); + getDescriptor().getMessageTypes().get(11); internal_static_org_apache_spark_sql_execution_streaming_state_Get_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor, new java.lang.String[] { }); internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor = - getDescriptor().getMessageTypes().get(11); + getDescriptor().getMessageTypes().get(12); internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor, new java.lang.String[] { "Value", }); internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor = - getDescriptor().getMessageTypes().get(12); + getDescriptor().getMessageTypes().get(13); internal_static_org_apache_spark_sql_execution_streaming_state_Clear_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( 
internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor, new java.lang.String[] { }); + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor = + getDescriptor().getMessageTypes().get(14); + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor, + new java.lang.String[] { "IteratorId", }); + internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor = + getDescriptor().getMessageTypes().get(15); + internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor, + new java.lang.String[] { }); + internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor = + getDescriptor().getMessageTypes().get(16); + internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor, + new java.lang.String[] { "Value", }); + internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor = + getDescriptor().getMessageTypes().get(17); + internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor, + new java.lang.String[] { }); internal_static_org_apache_spark_sql_execution_streaming_state_SetHandleState_descriptor = - getDescriptor().getMessageTypes().get(13); + getDescriptor().getMessageTypes().get(18); internal_static_org_apache_spark_sql_execution_streaming_state_SetHandleState_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_SetHandleState_descriptor, new java.lang.String[] { "State", }); internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_descriptor = - getDescriptor().getMessageTypes().get(14); + getDescriptor().getMessageTypes().get(19); internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_descriptor, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 53640f513fc81..53e12f58edd69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -21,9 +21,9 @@ import java.{lang => jl} import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.types._ /** @@ -33,7 +33,7 @@ import org.apache.spark.sql.types._ */ @Stable final class DataFrameNaFunctions 
private[sql](df: DataFrame) - extends api.DataFrameNaFunctions[Dataset] { + extends api.DataFrameNaFunctions { import df.sparkSession.RichColumn protected def drop(minNonNulls: Option[Int]): Dataset[Row] = { @@ -121,7 +121,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) (attr.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType))) { replaceCol(attr, replacementMap) } else { - column(attr) + Column(attr) } } df.select(projections : _*) @@ -130,7 +130,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) protected def fillMap(values: Seq[(String, Any)]): DataFrame = { // Error handling val attrToValue = AttributeMap(values.map { case (colName, replaceValue) => - // Check column name exists + // Check Column name exists val attr = df.resolve(colName) match { case a: Attribute => a case _ => throw QueryExecutionErrors.nestedFieldUnsupportedError(colName) @@ -154,7 +154,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) case v: jl.Integer => fillCol[Integer](attr, v) case v: jl.Boolean => fillCol[Boolean](attr, v.booleanValue()) case v: String => fillCol[String](attr, v) - }.getOrElse(column(attr)) + }.getOrElse(Column(attr)) } df.select(projections : _*) } @@ -164,7 +164,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) * with `replacement`. */ private def fillCol[T](attr: Attribute, replacement: T): Column = { - fillCol(attr.dataType, attr.name, column(attr), replacement) + fillCol(attr.dataType, attr.name, Column(attr), replacement) } /** @@ -191,7 +191,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) val branches = replacementMap.flatMap { case (source, target) => Seq(Literal(source), buildExpr(target)) }.toSeq - column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name) + Column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name) } private def convertToDouble(v: Any): Double = v match { @@ -218,7 +218,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) // Filtering condition: // only keep the row if it has at least `minNonNulls` non-null and non-NaN values. val predicate = AtLeastNNonNulls(minNonNulls.getOrElse(cols.size), cols) - df.filter(column(predicate)) + df.filter(Column(predicate)) } private[sql] def fillValue(value: Any, cols: Option[Seq[String]]): DataFrame = { @@ -254,9 +254,9 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) } // Only fill if the column is part of the cols list. 
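The fillCol / fillMap / fillValue changes above only swap the internal column(...) helper for the Column(...) apply brought in by ClassicConversions; the public na-functions behaviour is unchanged. A small usage sketch of the affected entry points, assuming a local SparkSession (all names illustrative):

import org.apache.spark.sql.SparkSession

object NaFillSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("na-fill").getOrCreate()
    import spark.implicits._
    val df = Seq((Some(1), Option.empty[String]), (None, Some("b"))).toDF("id", "name")
    // fillMap path: per-column replacement values for nulls.
    df.na.fill(Map("id" -> 0, "name" -> "unknown")).show()
    // drop(minNonNulls) path: keep rows with at least one non-null value.
    df.na.drop(1).show()
    spark.stop()
  }
}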
if (typeMatches && cols.exists(_.semanticEquals(col))) { - fillCol(col.dataType, col.name, column(col), value) + fillCol(col.dataType, col.name, Column(col), value) } else { - column(col) + Column(col) } } df.select(projections : _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index f105a77cf253b..ab3e939cee171 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.catalyst.xml.{StaxXmlParser, XmlOptions} +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource @@ -54,7 +55,9 @@ import org.apache.spark.unsafe.types.UTF8String */ @Stable class DataFrameReader private[sql](sparkSession: SparkSession) - extends api.DataFrameReader[Dataset] { + extends api.DataFrameReader { + override type DS[U] = Dataset[U] + format(sparkSession.sessionState.conf.defaultDataSourceName) /** @inheritdoc */ @@ -174,30 +177,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) @scala.annotation.varargs override def json(paths: String*): DataFrame = super.json(paths: _*) - /** - * Loads a `JavaRDD[String]` storing JSON objects (JSON - * Lines text format or newline-delimited JSON) and returns the result as - * a `DataFrame`. - * - * Unless the schema is specified using `schema` function, this function goes through the - * input once to determine the input schema. - * - * @param jsonRDD input RDD with one JSON object per record - * @since 1.4.0 - */ + /** @inheritdoc */ @deprecated("Use json(Dataset[String]) instead.", "2.2.0") def json(jsonRDD: JavaRDD[String]): DataFrame = json(jsonRDD.rdd) - /** - * Loads an `RDD[String]` storing JSON objects (JSON Lines - * text format or newline-delimited JSON) and returns the result as a `DataFrame`. - * - * Unless the schema is specified using `schema` function, this function goes through the - * input once to determine the input schema. 
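Both deprecated json(RDD[String]) overloads above now inherit their documentation and keep delegating to json(Dataset[String]). A sketch of that recommended path, assuming a local SparkSession:

import org.apache.spark.sql.SparkSession

object JsonLinesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("json-lines").getOrCreate()
    import spark.implicits._
    // One JSON object per element, as with the old RDD[String] overloads.
    val jsonLines = Seq("""{"id": 1, "name": "a"}""", """{"id": 2, "name": "b"}""").toDS()
    // Unless .schema(...) is supplied, the input is scanned once to infer the schema.
    spark.read.json(jsonLines).show()
    spark.stop()
  }
}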
- * - * @param jsonRDD input RDD with one JSON object per record - * @since 1.4.0 - */ + /** @inheritdoc */ @deprecated("Use json(Dataset[String]) instead.", "2.2.0") def json(jsonRDD: RDD[String]): DataFrame = { json(sparkSession.createDataset(jsonRDD)(Encoders.STRING)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index a5ab237bb7041..9f7180d8dfd6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -23,6 +23,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.execution.stat._ import org.apache.spark.sql.functions.col import org.apache.spark.util.ArrayImplicits._ @@ -34,7 +35,7 @@ import org.apache.spark.util.ArrayImplicits._ */ @Stable final class DataFrameStatFunctions private[sql](protected val df: DataFrame) - extends api.DataFrameStatFunctions[Dataset] { + extends api.DataFrameStatFunctions { /** @inheritdoc */ def approxQuantile( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala index 8ffdbb952b082..3b64cb97e10b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala @@ -68,7 +68,7 @@ class DataSourceRegistration private[sql] (dataSourceManager: DataSourceManager) DataSource.lookupDataSource(name, SQLConf.get) throw QueryCompilationErrors.dataSourceAlreadyExists(name) } catch { - case e: SparkClassNotFoundException if e.getErrorClass == "DATA_SOURCE_NOT_FOUND" => // OK + case e: SparkClassNotFoundException if e.getCondition == "DATA_SOURCE_NOT_FOUND" => // OK case _: Throwable => // If there are other errors when resolving the data source, it's unclear whether // it's safe to proceed. 
To prevent potential lookup errors, treat it as an existing diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c147b6a56e024..b7b96f0c98274 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -52,6 +52,7 @@ import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression @@ -61,8 +62,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, FileTable} import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.execution.stat.StatFunctions -import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf, ToScalaUDF} -import org.apache.spark.sql.internal.ExpressionUtils.column +import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf} import org.apache.spark.sql.internal.TypedAggUtils.withInputType import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ @@ -215,7 +215,8 @@ private[sql] object Dataset { class Dataset[T] private[sql]( @DeveloperApi @Unstable @transient val queryExecution: QueryExecution, @DeveloperApi @Unstable @transient val encoder: Encoder[T]) - extends api.Dataset[T, Dataset] { + extends api.Dataset[T] { + type DS[U] = Dataset[U] type RGD = RelationalGroupedDataset @transient lazy val sparkSession: SparkSession = { @@ -301,7 +302,7 @@ class Dataset[T] private[sql]( truncate: Int): Seq[Seq[String]] = { val newDf = commandResultOptimized.toDF() val castCols = newDf.logicalPlan.output.map { col => - column(ToPrettyString(col)) + Column(ToPrettyString(col)) } val data = newDf.select(castCols: _*).take(numRows + 1) @@ -411,11 +412,11 @@ class Dataset[T] private[sql]( // Print a footer if (vertical && rows.tail.isEmpty) { // In a vertical mode, print an empty row set explicitly - sb.append("(0 rows)\n") + sb.append("(0 rows)") } else if (hasMoreData) { // For Data that has more than "numRows" records val rowsString = if (numRows == 1) "row" else "rows" - sb.append(s"only showing top $numRows $rowsString\n") + sb.append(s"only showing top $numRows $rowsString") } sb.toString() @@ -503,7 +504,7 @@ class Dataset[T] private[sql]( s"New column names (${colNames.size}): " + colNames.mkString(", ")) val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => - column(oldAttribute).as(newName) + Column(oldAttribute).as(newName) } select(newCols : _*) } @@ -539,13 +540,18 @@ class Dataset[T] private[sql]( def isStreaming: Boolean = logicalPlan.isStreaming /** @inheritdoc */ - protected[sql] def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = { + protected[sql] def checkpoint( + eager: Boolean, + reliableCheckpoint: Boolean, + storageLevel: Option[StorageLevel]): Dataset[T] = { val actionName = if (reliableCheckpoint) "checkpoint" else 
"localCheckpoint" withAction(actionName, queryExecution) { physicalPlan => val internalRdd = physicalPlan.execute().map(_.copy()) if (reliableCheckpoint) { + assert(storageLevel.isEmpty, "StorageLevel should not be defined for reliableCheckpoint") internalRdd.checkpoint() } else { + storageLevel.foreach(storageLevel => internalRdd.persist(storageLevel)) internalRdd.localCheckpoint() } @@ -758,18 +764,18 @@ class Dataset[T] private[sql]( /** @inheritdoc */ def col(colName: String): Column = colName match { case "*" => - column(ResolvedStar(queryExecution.analyzed.output)) + Column(ResolvedStar(queryExecution.analyzed.output)) case _ => if (sparkSession.sessionState.conf.supportQuotedRegexColumnName) { colRegex(colName) } else { - column(addDataFrameIdToCol(resolve(colName))) + Column(addDataFrameIdToCol(resolve(colName))) } } /** @inheritdoc */ def metadataColumn(colName: String): Column = - column(queryExecution.analyzed.getMetadataAttributeByName(colName)) + Column(queryExecution.analyzed.getMetadataAttributeByName(colName)) // Attach the dataset id and column position to the column reference, so that we can detect // ambiguous self-join correctly. See the rule `DetectAmbiguousSelfJoin`. @@ -795,11 +801,11 @@ class Dataset[T] private[sql]( val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis colName match { case ParserUtils.escapedIdentifier(columnNameRegex) => - column(UnresolvedRegex(columnNameRegex, None, caseSensitive)) + Column(UnresolvedRegex(columnNameRegex, None, caseSensitive)) case ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex) => - column(UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive)) + Column(UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive)) case _ => - column(addDataFrameIdToCol(resolve(colName))) + Column(addDataFrameIdToCol(resolve(colName))) } } @@ -863,24 +869,7 @@ class Dataset[T] private[sql]( Filter(condition.expr, logicalPlan) } - /** - * Groups the Dataset using the specified columns, so we can run aggregation on them. See - * [[RelationalGroupedDataset]] for all the available aggregate functions. - * - * {{{ - * // Compute the average for all numeric columns grouped by department. - * ds.groupBy($"department").avg() - * - * // Compute the max age and average salary, grouped by department and gender. - * ds.groupBy($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 2.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def groupBy(cols: Column*): RelationalGroupedDataset = { RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType) @@ -912,13 +901,7 @@ class Dataset[T] private[sql]( rdd.reduce(func) } - /** - * (Scala-specific) - * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. - * - * @group typedrel - * @since 2.0.0 - */ + /** @inheritdoc */ def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { val withGroupingKey = AppendColumns(func, logicalPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) @@ -931,16 +914,6 @@ class Dataset[T] private[sql]( withGroupingKey.newColumns) } - /** - * (Java-specific) - * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. 
- * - * @group typedrel - * @since 2.0.0 - */ - def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = - groupByKey(ToScalaUDF(func))(encoder) - /** @inheritdoc */ def unpivot( ids: Array[Column], @@ -1225,7 +1198,7 @@ class Dataset[T] private[sql]( resolver(field.name, colName) } match { case Some((colName: String, col: Column)) => col.as(colName) - case _ => column(field) + case _ => Column(field) } } @@ -1295,7 +1268,7 @@ class Dataset[T] private[sql]( val allColumns = queryExecution.analyzed.output val remainingCols = allColumns.filter { attribute => colNames.forall(n => !resolver(attribute.name, n)) - }.map(attribute => column(attribute)) + }.map(attribute => Column(attribute)) if (remainingCols.size == allColumns.size) { toDF() } else { @@ -1556,12 +1529,7 @@ class Dataset[T] private[sql]( sparkSession.sessionState.executePlan(deserialized) } - /** - * Represents the content of the Dataset as an `RDD` of `T`. - * - * @group basic - * @since 1.6.0 - */ + /** @inheritdoc */ lazy val rdd: RDD[T] = { val objectType = exprEnc.deserializer.dataType rddQueryExecution.toRdd.mapPartitions { rows => @@ -1569,20 +1537,9 @@ class Dataset[T] private[sql]( } } - /** - * Returns the content of the Dataset as a `JavaRDD` of `T`s. - * @group basic - * @since 1.6.0 - */ + /** @inheritdoc */ def toJavaRDD: JavaRDD[T] = rdd.toJavaRDD() - /** - * Returns the content of the Dataset as a `JavaRDD` of `T`s. - * @group basic - * @since 1.6.0 - */ - def javaRDD: JavaRDD[T] = toJavaRDD - protected def createTempView( viewName: String, replace: Boolean, @@ -1638,28 +1595,7 @@ class Dataset[T] private[sql]( new DataFrameWriterV2Impl[T](table, this) } - /** - * Merges a set of updates, insertions, and deletions based on a source table into - * a target table. - * - * Scala Examples: - * {{{ - * spark.table("source") - * .mergeInto("target", $"source.id" === $"target.id") - * .whenMatched($"salary" === 100) - * .delete() - * .whenNotMatched() - * .insertAll() - * .whenNotMatchedBySource($"salary" === 100) - * .update(Map( - * "salary" -> lit(200) - * )) - * .merge() - * }}} - * - * @group basic - * @since 4.0.0 - */ + /** @inheritdoc */ def mergeInto(table: String, condition: Column): MergeIntoWriter[T] = { if (isStreaming) { logicalPlan.failAnalysis( @@ -1670,12 +1606,7 @@ class Dataset[T] private[sql]( new MergeIntoWriterImpl[T](table, this, condition) } - /** - * Interface for saving the content of the streaming Dataset out into external storage. 
- * - * @group basic - * @since 2.0.0 - */ + /** @inheritdoc */ def writeStream: DataStreamWriter[T] = { if (!isStreaming) { logicalPlan.failAnalysis( @@ -1868,6 +1799,10 @@ class Dataset[T] private[sql]( /** @inheritdoc */ override def localCheckpoint(eager: Boolean): Dataset[T] = super.localCheckpoint(eager) + /** @inheritdoc */ + override def localCheckpoint(eager: Boolean, storageLevel: StorageLevel): Dataset[T] = + super.localCheckpoint(eager, storageLevel) + /** @inheritdoc */ override def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = super.joinWith(other, condition) @@ -2022,10 +1957,24 @@ class Dataset[T] private[sql]( @scala.annotation.varargs override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*) + /** @inheritdoc */ + override def groupByKey[K]( + func: MapFunction[T, K], + encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = + super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]] + //////////////////////////////////////////////////////////////////////////// // For Python API //////////////////////////////////////////////////////////////////////////// + /** + * It adds a new long column with the name `name` that increases one by one. + * This is for 'distributed-sequence' default index in pandas API on Spark. + */ + private[sql] def withSequenceColumn(name: String) = { + select(Column(DistributedSequenceID()).alias(name), col("*")) + } + /** * Converts a JavaRDD to a PythonRDD. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index fcad1b721eaca..c645ba57e8f82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderF import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator import org.apache.spark.sql.internal.TypedAggUtils.{aggKeyColumn, withInputType} @@ -41,7 +42,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( @transient val queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) - extends api.KeyValueGroupedDataset[K, V, Dataset] { + extends api.KeyValueGroupedDataset[K, V] { type KVDS[KY, VL] = KeyValueGroupedDataset[KY, VL] private implicit def kEncoderImpl: Encoder[K] = kEncoder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index da4609135fd63..bd47a21a1e09b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.internal.ExpressionUtils.{column, generateAlias} @@ -52,8 +53,8 @@ class RelationalGroupedDataset protected[sql]( protected[sql] val df: DataFrame, private[sql] val groupingExprs: Seq[Expression], groupType: RelationalGroupedDataset.GroupType) - extends api.RelationalGroupedDataset[Dataset] { - type RGD = RelationalGroupedDataset + extends api.RelationalGroupedDataset { + import RelationalGroupedDataset._ import df.sparkSession._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index ffcc0b923f2cb..636899a7acb06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -251,8 +251,8 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group basic * @since 1.3.0 */ - object implicits extends SQLImplicits with Serializable { - protected override def session: SparkSession = self.sparkSession + object implicits extends SQLImplicits { + override protected def session: SparkSession = sparkSession } // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index a657836aafbea..b6ed50447109d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -17,259 +17,9 @@ package org.apache.spark.sql -import scala.collection.Map -import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder - -/** - * A collection of implicit methods for converting common Scala objects into [[Dataset]]s. - * - * @since 1.6.0 - */ -abstract class SQLImplicits extends LowPrioritySQLImplicits { +/** @inheritdoc */ +abstract class SQLImplicits extends api.SQLImplicits { + type DS[U] = Dataset[U] protected def session: SparkSession - - /** - * Converts $"col name" into a [[Column]]. 
- * - * @since 2.0.0 - */ - implicit class StringToColumn(val sc: StringContext) { - def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args: _*)) - } - } - - // Primitives - - /** @since 1.6.0 */ - implicit def newIntEncoder: Encoder[Int] = Encoders.scalaInt - - /** @since 1.6.0 */ - implicit def newLongEncoder: Encoder[Long] = Encoders.scalaLong - - /** @since 1.6.0 */ - implicit def newDoubleEncoder: Encoder[Double] = Encoders.scalaDouble - - /** @since 1.6.0 */ - implicit def newFloatEncoder: Encoder[Float] = Encoders.scalaFloat - - /** @since 1.6.0 */ - implicit def newByteEncoder: Encoder[Byte] = Encoders.scalaByte - - /** @since 1.6.0 */ - implicit def newShortEncoder: Encoder[Short] = Encoders.scalaShort - - /** @since 1.6.0 */ - implicit def newBooleanEncoder: Encoder[Boolean] = Encoders.scalaBoolean - - /** @since 1.6.0 */ - implicit def newStringEncoder: Encoder[String] = Encoders.STRING - - /** @since 2.2.0 */ - implicit def newJavaDecimalEncoder: Encoder[java.math.BigDecimal] = Encoders.DECIMAL - - /** @since 2.2.0 */ - implicit def newScalaDecimalEncoder: Encoder[scala.math.BigDecimal] = ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newDateEncoder: Encoder[java.sql.Date] = Encoders.DATE - - /** @since 3.0.0 */ - implicit def newLocalDateEncoder: Encoder[java.time.LocalDate] = Encoders.LOCALDATE - - /** @since 3.4.0 */ - implicit def newLocalDateTimeEncoder: Encoder[java.time.LocalDateTime] = Encoders.LOCALDATETIME - - /** @since 2.2.0 */ - implicit def newTimeStampEncoder: Encoder[java.sql.Timestamp] = Encoders.TIMESTAMP - - /** @since 3.0.0 */ - implicit def newInstantEncoder: Encoder[java.time.Instant] = Encoders.INSTANT - - /** @since 3.2.0 */ - implicit def newDurationEncoder: Encoder[java.time.Duration] = Encoders.DURATION - - /** @since 3.2.0 */ - implicit def newPeriodEncoder: Encoder[java.time.Period] = Encoders.PERIOD - - /** @since 3.2.0 */ - implicit def newJavaEnumEncoder[A <: java.lang.Enum[_] : TypeTag]: Encoder[A] = - ExpressionEncoder() - - // Boxed primitives - - /** @since 2.0.0 */ - implicit def newBoxedIntEncoder: Encoder[java.lang.Integer] = Encoders.INT - - /** @since 2.0.0 */ - implicit def newBoxedLongEncoder: Encoder[java.lang.Long] = Encoders.LONG - - /** @since 2.0.0 */ - implicit def newBoxedDoubleEncoder: Encoder[java.lang.Double] = Encoders.DOUBLE - - /** @since 2.0.0 */ - implicit def newBoxedFloatEncoder: Encoder[java.lang.Float] = Encoders.FLOAT - - /** @since 2.0.0 */ - implicit def newBoxedByteEncoder: Encoder[java.lang.Byte] = Encoders.BYTE - - /** @since 2.0.0 */ - implicit def newBoxedShortEncoder: Encoder[java.lang.Short] = Encoders.SHORT - - /** @since 2.0.0 */ - implicit def newBoxedBooleanEncoder: Encoder[java.lang.Boolean] = Encoders.BOOLEAN - - // Seqs - - /** - * @since 1.6.1 - * @deprecated use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder() - - /** - * @since 1.6.1 - * @deprecated use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder() - - /** - * @since 1.6.1 - * @deprecated use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder() - - /** - * @since 1.6.1 - * @deprecated use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newFloatSeqEncoder: Encoder[Seq[Float]] = 
ExpressionEncoder() - - /** - * @since 1.6.1 - * @deprecated use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder() - - /** - * @since 1.6.1 - * @deprecated use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder() - - /** - * @since 1.6.1 - * @deprecated use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder() - - /** - * @since 1.6.1 - * @deprecated use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder() - - /** - * @since 1.6.1 - * @deprecated use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder() - - // Maps - /** @since 2.3.0 */ - implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder() - - /** - * Notice that we serialize `Set` to Catalyst array. The set property is only kept when - * manipulating the domain objects. The serialization format doesn't keep the set property. - * When we have a Catalyst array which contains duplicated elements and convert it to - * `Dataset[Set[T]]` by using the encoder, the elements will be de-duplicated. - * - * @since 2.3.0 - */ - implicit def newSetEncoder[T <: Set[_] : TypeTag]: Encoder[T] = ExpressionEncoder() - - // Arrays - - /** @since 1.6.1 */ - implicit def newIntArrayEncoder: Encoder[Array[Int]] = ExpressionEncoder() - - /** @since 1.6.1 */ - implicit def newLongArrayEncoder: Encoder[Array[Long]] = ExpressionEncoder() - - /** @since 1.6.1 */ - implicit def newDoubleArrayEncoder: Encoder[Array[Double]] = ExpressionEncoder() - - /** @since 1.6.1 */ - implicit def newFloatArrayEncoder: Encoder[Array[Float]] = ExpressionEncoder() - - /** @since 1.6.1 */ - implicit def newByteArrayEncoder: Encoder[Array[Byte]] = Encoders.BINARY - - /** @since 1.6.1 */ - implicit def newShortArrayEncoder: Encoder[Array[Short]] = ExpressionEncoder() - - /** @since 1.6.1 */ - implicit def newBooleanArrayEncoder: Encoder[Array[Boolean]] = ExpressionEncoder() - - /** @since 1.6.1 */ - implicit def newStringArrayEncoder: Encoder[Array[String]] = ExpressionEncoder() - - /** @since 1.6.1 */ - implicit def newProductArrayEncoder[A <: Product : TypeTag]: Encoder[Array[A]] = - ExpressionEncoder() - - /** - * Creates a [[Dataset]] from an RDD. - * - * @since 1.6.0 - */ - implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T] = { - DatasetHolder(session.createDataset(rdd)) - } - - /** - * Creates a [[Dataset]] from a local Seq. - * @since 1.6.0 - */ - implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = { - DatasetHolder(session.createDataset(s)) - } - - /** - * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. - * @since 1.3.0 - */ - implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) - -} - -/** - * Lower priority implicit methods for converting Scala objects into [[Dataset]]s. - * Conflicting implicits are placed here to disambiguate resolution. 
- * - * Reasons for including specific implicits: - * newProductEncoder - to disambiguate for `List`s which are both `Seq` and `Product` - */ -trait LowPrioritySQLImplicits { - /** @since 1.6.0 */ - implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T] - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 720b77b0b9fe5..99ab3ca69fb20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -21,7 +21,7 @@ import java.net.URI import java.nio.file.Paths import java.util.{ServiceLoader, UUID} import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} +import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.duration.DurationInt import scala.jdk.CollectionConverters._ @@ -96,7 +96,7 @@ class SparkSession private( @transient private[sql] val extensions: SparkSessionExtensions, @transient private[sql] val initialSessionOptions: Map[String, String], @transient private val parentManagedJobTags: Map[String, String]) - extends api.SparkSession[Dataset] with Logging { self => + extends api.SparkSession with Logging { self => // The call site where this SparkSession was constructed. private val creationSite: CallSite = Utils.getCallSite() @@ -229,12 +229,7 @@ class SparkSession private( @Unstable def dataSource: DataSourceRegistration = sessionState.dataSourceRegistration - /** - * Returns a `StreamingQueryManager` that allows managing all the - * `StreamingQuery`s active on `this`. - * - * @since 2.0.0 - */ + /** @inheritdoc */ @Unstable def streams: StreamingQueryManager = sessionState.streamingQueryManager @@ -299,11 +294,7 @@ class SparkSession private( new Dataset(self, LocalRelation(encoder.schema), encoder) } - /** - * Creates a `DataFrame` from an RDD of Product (e.g. case classes, tuples). - * - * @since 2.0.0 - */ + /** @inheritdoc */ def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = withActive { val encoder = Encoders.product[A] Dataset.ofRows(self, ExternalRDD(rdd, self)(encoder)) @@ -316,37 +307,7 @@ class SparkSession private( Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data)) } - /** - * :: DeveloperApi :: - * Creates a `DataFrame` from an `RDD` containing [[Row]]s using the given schema. - * It is important to make sure that the structure of every [[Row]] of the provided RDD matches - * the provided schema. Otherwise, there will be runtime exception. 
- * Example: - * {{{ - * import org.apache.spark.sql._ - * import org.apache.spark.sql.types._ - * val sparkSession = new org.apache.spark.sql.SparkSession(sc) - * - * val schema = - * StructType( - * StructField("name", StringType, false) :: - * StructField("age", IntegerType, true) :: Nil) - * - * val people = - * sc.textFile("examples/src/main/resources/people.txt").map( - * _.split(",")).map(p => Row(p(0), p(1).trim.toInt)) - * val dataFrame = sparkSession.createDataFrame(people, schema) - * dataFrame.printSchema - * // root - * // |-- name: string (nullable = false) - * // |-- age: integer (nullable = true) - * - * dataFrame.createOrReplaceTempView("people") - * sparkSession.sql("select name from people").collect.foreach(println) - * }}} - * - * @since 2.0.0 - */ + /** @inheritdoc */ @DeveloperApi def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = withActive { val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] @@ -358,14 +319,7 @@ class SparkSession private( internalCreateDataFrame(catalystRows.setName(rowRDD.name), schema) } - /** - * :: DeveloperApi :: - * Creates a `DataFrame` from a `JavaRDD` containing [[Row]]s using the given schema. - * It is important to make sure that the structure of every [[Row]] of the provided RDD matches - * the provided schema. Otherwise, there will be runtime exception. - * - * @since 2.0.0 - */ + /** @inheritdoc */ @DeveloperApi def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] @@ -379,14 +333,7 @@ class SparkSession private( Dataset.ofRows(self, LocalRelation.fromExternalRows(toAttributes(replaced), rows.asScala.toSeq)) } - /** - * Applies a schema to an RDD of Java Beans. - * - * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, - * SELECT * queries will return the columns in an undefined order. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = withActive { val attributeSeq: Seq[AttributeReference] = getSchema(beanClass) val className = beanClass.getName @@ -397,14 +344,7 @@ class SparkSession private( Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd.setName(rdd.name))(self)) } - /** - * Applies a schema to an RDD of Java Beans. - * - * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, - * SELECT * queries will return the columns in an undefined order. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { createDataFrame(rdd.rdd, beanClass) } @@ -439,14 +379,7 @@ class SparkSession private( Dataset[T](self, plan) } - /** - * Creates a [[Dataset]] from an RDD of a given type. This method requires an - * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL representation) - * that is generally created automatically through implicits from a `SparkSession`, or can be - * created explicitly by calling static methods on [[Encoders]]. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { Dataset[T](self, ExternalRDD(data, self)) } @@ -739,32 +672,13 @@ class SparkSession private( /** @inheritdoc */ def read: DataFrameReader = new DataFrameReader(self) - /** - * Returns a `DataStreamReader` that can be used to read streaming data in as a `DataFrame`. 
- * {{{ - * sparkSession.readStream.parquet("/path/to/directory/of/parquet/files") - * sparkSession.readStream.schema(schema).json("/path/to/directory/of/json/files") - * }}} - * - * @since 2.0.0 - */ + /** @inheritdoc */ def readStream: DataStreamReader = new DataStreamReader(self) // scalastyle:off // Disable style checker so "implicits" object can start with lowercase i - /** - * (Scala-specific) Implicit methods available in Scala for converting - * common Scala objects into `DataFrame`s. - * - * {{{ - * val sparkSession = SparkSession.builder.getOrCreate() - * import sparkSession.implicits._ - * }}} - * - * @since 2.0.0 - */ - object implicits extends SQLImplicits with Serializable { - protected override def session: SparkSession = SparkSession.this + object implicits extends SQLImplicits { + override protected def session: SparkSession = self } // scalastyle:on @@ -829,7 +743,7 @@ class SparkSession private( // Use the active session thread local directly to make sure we get the session that is actually // set and not the default session. This to prevent that we promote the default session to the // active session once we are done. - val old = SparkSession.activeThreadSession.get() + val old = SparkSession.getActiveSession.orNull SparkSession.setActiveSession(this) try block finally { SparkSession.setActiveSession(old) @@ -860,133 +774,71 @@ class SparkSession private( } private[sql] lazy val observationManager = new ObservationManager(this) + + override private[sql] def isUsable: Boolean = !sparkContext.isStopped } @Stable -object SparkSession extends Logging { +object SparkSession extends api.BaseSparkSessionCompanion with Logging { + override private[sql] type Session = SparkSession /** * Builder for [[SparkSession]]. */ @Stable - class Builder extends Logging { - - private[this] val options = new scala.collection.mutable.HashMap[String, String] + class Builder extends api.SparkSessionBuilder { private[this] val extensions = new SparkSessionExtensions private[this] var userSuppliedContext: Option[SparkContext] = None - private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized { + private[spark] def sparkContext(sparkContext: SparkContext): this.type = synchronized { userSuppliedContext = Option(sparkContext) this } - /** - * Sets a name for the application, which will be shown in the Spark web UI. - * If no application name is set, a randomly generated name will be used. - * - * @since 2.0.0 - */ - def appName(name: String): Builder = config("spark.app.name", name) + /** @inheritdoc */ + override def remote(connectionString: String): this.type = this - /** - * Sets a config option. Options set using this method are automatically propagated to - * both `SparkConf` and SparkSession's own configuration. - * - * @since 2.0.0 - */ - def config(key: String, value: String): Builder = synchronized { - options += key -> value - this - } + /** @inheritdoc */ + override def appName(name: String): this.type = super.appName(name) - /** - * Sets a config option. Options set using this method are automatically propagated to - * both `SparkConf` and SparkSession's own configuration. - * - * @since 2.0.0 - */ - def config(key: String, value: Long): Builder = synchronized { - options += key -> value.toString - this - } + /** @inheritdoc */ + override def config(key: String, value: String): this.type = super.config(key, value) - /** - * Sets a config option. 
Options set using this method are automatically propagated to - * both `SparkConf` and SparkSession's own configuration. - * - * @since 2.0.0 - */ - def config(key: String, value: Double): Builder = synchronized { - options += key -> value.toString - this - } + /** @inheritdoc */ + override def config(key: String, value: Long): this.type = super.config(key, value) - /** - * Sets a config option. Options set using this method are automatically propagated to - * both `SparkConf` and SparkSession's own configuration. - * - * @since 2.0.0 - */ - def config(key: String, value: Boolean): Builder = synchronized { - options += key -> value.toString - this - } + /** @inheritdoc */ + override def config(key: String, value: Double): this.type = super.config(key, value) - /** - * Sets a config option. Options set using this method are automatically propagated to - * both `SparkConf` and SparkSession's own configuration. - * - * @since 3.4.0 - */ - def config(map: Map[String, Any]): Builder = synchronized { - map.foreach { - kv: (String, Any) => { - options += kv._1 -> kv._2.toString - } - } - this - } + /** @inheritdoc */ + override def config(key: String, value: Boolean): this.type = super.config(key, value) - /** - * Sets a config option. Options set using this method are automatically propagated to - * both `SparkConf` and SparkSession's own configuration. - * - * @since 3.4.0 - */ - def config(map: java.util.Map[String, Any]): Builder = synchronized { - config(map.asScala.toMap) - } + /** @inheritdoc */ + override def config(map: Map[String, Any]): this.type = super.config(map) + + /** @inheritdoc */ + override def config(map: java.util.Map[String, Any]): this.type = super.config(map) /** * Sets a list of config options based on the given `SparkConf`. * * @since 2.0.0 */ - def config(conf: SparkConf): Builder = synchronized { + def config(conf: SparkConf): this.type = synchronized { conf.getAll.foreach { case (k, v) => options += k -> v } this } - /** - * Sets the Spark master URL to connect to, such as "local" to run locally, "local[4]" to - * run locally with 4 cores, or "spark://master:7077" to run on a Spark standalone cluster. - * - * @since 2.0.0 - */ - def master(master: String): Builder = config("spark.master", master) + /** @inheritdoc */ + override def master(master: String): this.type = super.master(master) - /** - * Enables Hive support, including connectivity to a persistent Hive metastore, support for - * Hive serdes, and Hive user-defined functions. - * - * @since 2.0.0 - */ - def enableHiveSupport(): Builder = synchronized { + /** @inheritdoc */ + override def enableHiveSupport(): this.type = synchronized { if (hiveClassesArePresent) { - config(CATALOG_IMPLEMENTATION.key, "hive") + super.enableHiveSupport() } else { throw new IllegalArgumentException( "Unable to instantiate SparkSession with Hive support because " + @@ -1000,27 +852,12 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def withExtensions(f: SparkSessionExtensions => Unit): Builder = synchronized { + def withExtensions(f: SparkSessionExtensions => Unit): this.type = synchronized { f(extensions) this } - /** - * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new - * one based on the options set in this builder. - * - * This method first checks whether there is a valid thread-local SparkSession, - * and if yes, return that one. It then checks whether there is a valid global - * default SparkSession, and if yes, return that one. 
If no valid global default - * SparkSession exists, the method creates a new SparkSession and assigns the - * newly created SparkSession as the global default. - * - * In case an existing SparkSession is returned, the non-static config options specified in - * this builder will be applied to the existing SparkSession. - * - * @since 2.0.0 - */ - def getOrCreate(): SparkSession = synchronized { + private def build(forceCreate: Boolean): SparkSession = synchronized { val sparkConf = new SparkConf() options.foreach { case (k, v) => sparkConf.set(k, v) } @@ -1029,8 +866,9 @@ object SparkSession extends Logging { } // Get the session from current thread's active session. - var session = activeThreadSession.get() - if ((session ne null) && !session.sparkContext.isStopped) { + val active = getActiveSession + if (!forceCreate && active.isDefined) { + val session = active.get applyModifiableSettings(session, new java.util.HashMap[String, String](options.asJava)) return session } @@ -1038,8 +876,9 @@ object SparkSession extends Logging { // Global synchronization so we will only set the default session once. SparkSession.synchronized { // If the current thread does not have an active session, get it from the global session. - session = defaultSession.get() - if ((session ne null) && !session.sparkContext.isStopped) { + val default = getDefaultSession + if (!forceCreate && default.isDefined) { + val session = default.get applyModifiableSettings(session, new java.util.HashMap[String, String](options.asJava)) return session } @@ -1058,109 +897,43 @@ object SparkSession extends Logging { loadExtensions(extensions) applyExtensions(sparkContext, extensions) - session = new SparkSession(sparkContext, + val session = new SparkSession(sparkContext, existingSharedState = None, parentSessionState = None, extensions, initialSessionOptions = options.toMap, parentManagedJobTags = Map.empty) - setDefaultSession(session) - setActiveSession(session) + setDefaultAndActiveSession(session) registerContextListener(sparkContext) + session } - - return session } - } - /** - * Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]]. - * - * @since 2.0.0 - */ - def builder(): Builder = new Builder - - /** - * Changes the SparkSession that will be returned in this thread and its children when - * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives - * a SparkSession with an isolated session, instead of the global (first created) context. - * - * @since 2.0.0 - */ - def setActiveSession(session: SparkSession): Unit = { - activeThreadSession.set(session) - } + /** @inheritdoc */ + def getOrCreate(): SparkSession = build(forceCreate = false) - /** - * Clears the active SparkSession for current thread. Subsequent calls to getOrCreate will - * return the first created context instead of a thread-local override. - * - * @since 2.0.0 - */ - def clearActiveSession(): Unit = { - activeThreadSession.remove() + /** @inheritdoc */ + def create(): SparkSession = build(forceCreate = true) } /** - * Sets the default SparkSession that is returned by the builder. + * Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]]. * * @since 2.0.0 */ - def setDefaultSession(session: SparkSession): Unit = { - defaultSession.set(session) - } + def builder(): Builder = new Builder - /** - * Clears the default SparkSession that is returned by the builder. 
- * - * @since 2.0.0 - */ - def clearDefaultSession(): Unit = { - defaultSession.set(null) - } + /** @inheritdoc */ + override def getActiveSession: Option[SparkSession] = super.getActiveSession - /** - * Returns the active SparkSession for the current thread, returned by the builder. - * - * @note Return None, when calling this function on executors - * - * @since 2.2.0 - */ - def getActiveSession: Option[SparkSession] = { - if (Utils.isInRunningSparkTask) { - // Return None when running on executors. - None - } else { - Option(activeThreadSession.get) - } - } + /** @inheritdoc */ + override def getDefaultSession: Option[SparkSession] = super.getDefaultSession - /** - * Returns the default SparkSession that is returned by the builder. - * - * @note Return None, when calling this function on executors - * - * @since 2.2.0 - */ - def getDefaultSession: Option[SparkSession] = { - if (Utils.isInRunningSparkTask) { - // Return None when running on executors. - None - } else { - Option(defaultSession.get) - } - } + /** @inheritdoc */ + override def active: SparkSession = super.active - /** - * Returns the currently active SparkSession, otherwise the default one. If there is no default - * SparkSession, throws an exception. - * - * @since 2.4.0 - */ - def active: SparkSession = { - getActiveSession.getOrElse(getDefaultSession.getOrElse( - throw SparkException.internalError("No active or default Spark session found"))) - } + override protected def canUseSession(session: SparkSession): Boolean = + session.isUsable && !Utils.isInRunningSparkTask /** * Apply modifiable settings to an existing [[SparkSession]]. This method are used @@ -1231,7 +1004,8 @@ object SparkSession extends Logging { if (!listenerRegistered.get()) { sparkContext.addSparkListener(new SparkListener { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { - defaultSession.set(null) + clearDefaultSession() + clearActiveSession() listenerRegistered.set(false) } }) @@ -1239,12 +1013,6 @@ object SparkSession extends Logging { } } - /** The active SparkSession for the current thread. */ - private val activeThreadSession = new InheritableThreadLocal[SparkSession] - - /** Reference to the root SparkSession. 
*/ - private val defaultSession = new AtomicReference[SparkSession] - private val HIVE_SESSION_STATE_BUILDER_CLASS_NAME = "org.apache.spark.sql.hive.HiveSessionStateBuilder" @@ -1332,7 +1100,7 @@ object SparkSession extends Logging { private def applyExtensions( sparkContext: SparkContext, extensions: SparkSessionExtensions): SparkSessionExtensions = { - val extensionConfClassNames = sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS) + val extensionConfClassNames = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS) .getOrElse(Seq.empty) extensionConfClassNames.foreach { extensionConfClassName => try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 93082740cca64..08395ef4c347c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -29,14 +29,13 @@ import org.apache.spark.internal.LogKeys.CLASS_LOADER import org.apache.spark.security.SocketAuthServer import org.apache.spark.sql.{internal, Column, DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRegistry} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.execution.{ExplainMode, QueryExecution} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.python.EvaluatePython -import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.ExpressionUtils.{column, expression} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} @@ -69,7 +68,10 @@ private[sql] object PythonSQLUtils extends Logging { // This is needed when generating SQL documentation for built-in functions. def listBuiltinFunctionInfos(): Array[ExpressionInfo] = { - FunctionRegistry.functionSet.flatMap(f => FunctionRegistry.builtin.lookupFunction(f)).toArray + (FunctionRegistry.functionSet.flatMap(f => FunctionRegistry.builtin.lookupFunction(f)) ++ + TableFunctionRegistry.functionSet.flatMap( + f => TableFunctionRegistry.builtin.lookupFunction(f))). 
+ groupBy(_.getName).map(v => v._2.head).toArray } private def listAllSQLConfigs(): Seq[(String, String, String, String)] = { @@ -141,48 +143,6 @@ private[sql] object PythonSQLUtils extends Logging { } } - def castTimestampNTZToLong(c: Column): Column = - Column.internalFn("timestamp_ntz_to_long", c) - - def ewm(e: Column, alpha: Double, ignoreNA: Boolean): Column = - Column.internalFn("ewm", e, lit(alpha), lit(ignoreNA)) - - def nullIndex(e: Column): Column = Column.internalFn("null_index", e) - - def collect_top_k(e: Column, num: Int, reverse: Boolean): Column = - Column.internalFn("collect_top_k", e, lit(num), lit(reverse)) - - def binary_search(e: Column, value: Column): Column = - Column.internalFn("array_binary_search", e, value) - - def pandasProduct(e: Column, ignoreNA: Boolean): Column = - Column.internalFn("pandas_product", e, lit(ignoreNA)) - - def pandasStddev(e: Column, ddof: Int): Column = - Column.internalFn("pandas_stddev", e, lit(ddof)) - - def pandasVariance(e: Column, ddof: Int): Column = - Column.internalFn("pandas_var", e, lit(ddof)) - - def pandasSkewness(e: Column): Column = - Column.internalFn("pandas_skew", e) - - def pandasKurtosis(e: Column): Column = - Column.internalFn("pandas_kurt", e) - - def pandasMode(e: Column, ignoreNA: Boolean): Column = - Column.internalFn("pandas_mode", e, lit(ignoreNA)) - - def pandasCovar(col1: Column, col2: Column, ddof: Int): Column = - Column.internalFn("pandas_covar", col1, col2, lit(ddof)) - - /** - * A long column that increases one by one. - * This is for 'distributed-sequence' default index in pandas API on Spark. - */ - def distributed_sequence_id(): Column = - Column.internalFn("distributed_sequence_id") - def unresolvedNamedLambdaVariable(name: String): Column = Column(internal.UnresolvedNamedLambdaVariable.apply(name)) @@ -202,6 +162,9 @@ private[sql] object PythonSQLUtils extends Logging { @scala.annotation.varargs def fn(name: String, arguments: Column*): Column = Column.fn(name, arguments: _*) + + @scala.annotation.varargs + def internalFn(name: String, inputs: Column*): Column = Column.internalFn(name, inputs: _*) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 661e43fe73cae..c39018ff06fca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.catalog import java.util import org.apache.spark.sql.{api, DataFrame, Dataset} +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.types.StructType /** @inheritdoc */ -abstract class Catalog extends api.Catalog[Dataset] { +abstract class Catalog extends api.Catalog { /** @inheritdoc */ override def listDatabases(): Dataset[Database] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala new file mode 100644 index 0000000000000..c7320d350a7ff --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import scala.jdk.CollectionConverters.IteratorHasAsScala + +import org.apache.spark.SparkException +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow} +import org.apache.spark.sql.catalyst.plans.logical.{Call, LocalRelation, LogicalPlan, MultiResult} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure +import org.apache.spark.sql.connector.read.{LocalScan, Scan} +import org.apache.spark.util.ArrayImplicits._ + +class InvokeProcedures(session: SparkSession) extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case c: Call if c.resolved && c.bound && c.execute && c.checkArgTypes().isSuccess => + session.sessionState.optimizer.execute(c) match { + case Call(ResolvedProcedure(_, _, procedure: BoundProcedure), args, _) => + invoke(procedure, args) + case _ => + throw SparkException.internalError("Unexpected plan for optimized CALL statement") + } + } + + private def invoke(procedure: BoundProcedure, args: Seq[Expression]): LogicalPlan = { + val input = toInternalRow(args) + val scanIterator = procedure.call(input) + val relations = scanIterator.asScala.map(toRelation).toSeq + relations match { + case Nil => LocalRelation(Nil) + case Seq(relation) => relation + case _ => MultiResult(relations) + } + } + + private def toRelation(scan: Scan): LogicalPlan = scan match { + case s: LocalScan => + val attrs = DataTypeUtils.toAttributes(s.readSchema) + val data = s.rows.toImmutableArraySeq + LocalRelation(attrs, data) + case _ => + throw SparkException.internalError( + s"Only local scans are temporarily supported as procedure output: ${scan.getClass.getName}") + } + + private def toInternalRow(args: Seq[Expression]): InternalRow = { + require(args.forall(_.foldable), "args must be foldable") + val values = args.map(_.eval()).toArray + new GenericInternalRow(values) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 02ad2e79a5645..884c870e8eed3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -26,9 +26,9 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, toPrettySQL, ResolveDefaultColumns => DefaultCols} 
+import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, toPrettySQL, CharVarcharUtils, ResolveDefaultColumns => DefaultCols} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, DelegatingCatalogExtension, LookupCatalog, SupportsNamespaces, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogExtension, CatalogManager, CatalogPlugin, CatalogV2Util, LookupCatalog, SupportsNamespaces, V1Table} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command._ @@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.internal.connector.V1Function -import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType} +import org.apache.spark.sql.types.{MetadataBuilder, StringType, StructField, StructType} import org.apache.spark.util.ArrayImplicits._ /** @@ -87,7 +87,11 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) val colName = a.column.name(0) val dataType = a.dataType.getOrElse { table.schema.findNestedField(Seq(colName), resolver = conf.resolver) - .map(_._2.dataType) + .map { + case (_, StructField(_, st: StringType, _, metadata)) => + CharVarcharUtils.getRawType(metadata).getOrElse(st) + case (_, field) => field.dataType + } .getOrElse { throw QueryCompilationErrors.unresolvedColumnError( toSQLId(a.column.name), table.schema.fieldNames) @@ -706,6 +710,6 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) private def supportsV1Command(catalog: CatalogPlugin): Boolean = { isSessionCatalog(catalog) && ( SQLConf.get.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isEmpty || - catalog.isInstanceOf[DelegatingCatalogExtension]) + catalog.isInstanceOf[CatalogExtension]) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala new file mode 100644 index 0000000000000..8c3223fa72f55 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.classic + +import scala.language.implicitConversions + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.internal.ExpressionUtils + +/** + * Conversions from sql interfaces to the Classic specific implementation. + * + * This class is mainly used by the implementation. It is also meant to be used by extension + * developers. + * + * We provide both a trait and an object. The trait is useful in situations where an extension + * developer needs to use these conversions in a project covering multiple Spark versions. They can + * create a shim for these conversions, the Spark 4+ version of the shim implements this trait, and + * shims for older versions do not. + */ +@DeveloperApi +trait ClassicConversions { + implicit def castToImpl(session: api.SparkSession): SparkSession = + session.asInstanceOf[SparkSession] + + implicit def castToImpl[T](ds: api.Dataset[T]): Dataset[T] = + ds.asInstanceOf[Dataset[T]] + + implicit def castToImpl(rgds: api.RelationalGroupedDataset): RelationalGroupedDataset = + rgds.asInstanceOf[RelationalGroupedDataset] + + implicit def castToImpl[K, V](kvds: api.KeyValueGroupedDataset[K, V]) + : KeyValueGroupedDataset[K, V] = kvds.asInstanceOf[KeyValueGroupedDataset[K, V]] + + /** + * Helper that makes it easy to construct a Column from an Expression. + */ + implicit class ColumnConstructorExt(val c: Column.type) { + def apply(e: Expression): Column = ExpressionUtils.column(e) + } +} + +object ClassicConversions extends ClassicConversions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala index dc918e51d0550..0a487bac77696 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala @@ -60,9 +60,13 @@ case class CollectMetricsExec( override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def resetMetrics(): Unit = { + accumulator.reset() + super.resetMetrics() + } + override protected def doExecute(): RDD[InternalRow] = { val collector = accumulator - collector.reset() child.execute().mapPartitions { rows => // Only publish the value of the accumulator when the task has completed. 
This is done by // updating a task local accumulator ('updater') which will be merged with the actual diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/EmptyRelationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/EmptyRelationExec.scala index 8a544de7567e8..a0c3d7b51c2c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/EmptyRelationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/EmptyRelationExec.scala @@ -71,13 +71,15 @@ case class EmptyRelationExec(@transient logical: LogicalPlan) extends LeafExecNo maxFields, printNodeId, indent) - lastChildren.add(true) - logical.generateTreeString( - depth + 1, lastChildren, append, verbose, "", false, maxFields, printNodeId, indent) - lastChildren.remove(lastChildren.size() - 1) + Option(logical).foreach { _ => + lastChildren.add(true) + logical.generateTreeString( + depth + 1, lastChildren, append, verbose, "", false, maxFields, printNodeId, indent) + lastChildren.remove(lastChildren.size() - 1) + } } override def doCanonicalize(): SparkPlan = { - this.copy(logical = LocalRelation(logical.output).canonicalized) + this.copy(logical = LocalRelation(output).canonicalized) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala new file mode 100644 index 0000000000000..c2b12b053c927 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class MultiResultExec(children: Seq[SparkPlan]) extends SparkPlan { + + override def output: Seq[Attribute] = children.lastOption.map(_.output).getOrElse(Nil) + + override protected def doExecute(): RDD[InternalRow] = { + children.lastOption.map(_.execute()).getOrElse(sparkContext.emptyRDD) + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[SparkPlan]): MultiResultExec = { + copy(children = newChildren) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 5c894eb7555b1..6ff2c5d4b9d32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -46,8 +46,8 @@ import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, WatermarkPropagator} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.util.{LazyTry, Utils} import org.apache.spark.util.ArrayImplicits._ -import org.apache.spark.util.Utils /** * The primary workflow for executing relational queries using Spark. Designed to allow easy @@ -86,7 +86,7 @@ class QueryExecution( } } - lazy val analyzed: LogicalPlan = { + private val lazyAnalyzed = LazyTry { val plan = executePhase(QueryPlanningTracker.ANALYSIS) { // We can't clone `logical` here, which will reset the `_analyzed` flag. sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker) @@ -95,12 +95,18 @@ class QueryExecution( plan } - lazy val commandExecuted: LogicalPlan = mode match { - case CommandExecutionMode.NON_ROOT => analyzed.mapChildren(eagerlyExecuteCommands) - case CommandExecutionMode.ALL => eagerlyExecuteCommands(analyzed) - case CommandExecutionMode.SKIP => analyzed + def analyzed: LogicalPlan = lazyAnalyzed.get + + private val lazyCommandExecuted = LazyTry { + mode match { + case CommandExecutionMode.NON_ROOT => analyzed.mapChildren(eagerlyExecuteCommands) + case CommandExecutionMode.ALL => eagerlyExecuteCommands(analyzed) + case CommandExecutionMode.SKIP => analyzed + } } + def commandExecuted: LogicalPlan = lazyCommandExecuted.get + private def commandExecutionName(command: Command): String = command match { case _: CreateTableAsSelect => "create" case _: ReplaceTableAsSelect => "replace" @@ -141,22 +147,28 @@ class QueryExecution( } } - // The plan that has been normalized by custom rules, so that it's more likely to hit cache. - lazy val normalized: LogicalPlan = { + private val lazyNormalized = LazyTry { QueryExecution.normalize(sparkSession, commandExecuted, Some(tracker)) } - lazy val withCachedData: LogicalPlan = sparkSession.withActive { - assertAnalyzed() - assertSupported() - // clone the plan to avoid sharing the plan instance between different stages like analyzing, - // optimizing and planning. - sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) + // The plan that has been normalized by custom rules, so that it's more likely to hit cache. 
+ def normalized: LogicalPlan = lazyNormalized.get + + private val lazyWithCachedData = LazyTry { + sparkSession.withActive { + assertAnalyzed() + assertSupported() + // clone the plan to avoid sharing the plan instance between different stages like analyzing, + // optimizing and planning. + sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) + } } + def withCachedData: LogicalPlan = lazyWithCachedData.get + def assertCommandExecuted(): Unit = commandExecuted - lazy val optimizedPlan: LogicalPlan = { + private val lazyOptimizedPlan = LazyTry { // We need to materialize the commandExecuted here because optimizedPlan is also tracked under // the optimizing phase assertCommandExecuted() @@ -174,9 +186,11 @@ class QueryExecution( } } + def optimizedPlan: LogicalPlan = lazyOptimizedPlan.get + def assertOptimized(): Unit = optimizedPlan - lazy val sparkPlan: SparkPlan = { + private val lazySparkPlan = LazyTry { // We need to materialize the optimizedPlan here because sparkPlan is also tracked under // the planning phase assertOptimized() @@ -187,11 +201,11 @@ class QueryExecution( } } + def sparkPlan: SparkPlan = lazySparkPlan.get + def assertSparkPlanPrepared(): Unit = sparkPlan - // executedPlan should not be used to initialize any SparkPlan. It should be - // only used for execution. - lazy val executedPlan: SparkPlan = { + private val lazyExecutedPlan = LazyTry { // We need to materialize the optimizedPlan here, before tracking the planning phase, to ensure // that the optimization time is not counted as part of the planning phase. assertOptimized() @@ -206,8 +220,16 @@ class QueryExecution( plan } + // executedPlan should not be used to initialize any SparkPlan. It should be + // only used for execution. + def executedPlan: SparkPlan = lazyExecutedPlan.get + def assertExecutedPlanPrepared(): Unit = executedPlan + val lazyToRdd = LazyTry { + new SQLExecutionRDD(executedPlan.execute(), sparkSession.sessionState.conf) + } + /** * Internal version of the RDD. Avoids copies and has no schema. * Note for callers: Spark may apply various optimization including reusing object: this means @@ -218,8 +240,7 @@ class QueryExecution( * Given QueryExecution is not a public class, end users are discouraged to use this: please * use `Dataset.rdd` instead where conversion will be applied. */ - lazy val toRdd: RDD[InternalRow] = new SQLExecutionRDD( - executedPlan.execute(), sparkSession.sessionState.conf) + def toRdd: RDD[InternalRow] = lazyToRdd.get /** Get the metrics observed during the execution of the query plan. 
*/ def observedMetrics: Map[String, Row] = CollectMetricsExec.collect(executedPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 7bc770a0c9e33..fb3ec3ad41812 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.datasources.WriteFilesSpec import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.NextIterator +import org.apache.spark.util.{LazyTry, NextIterator} import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} object SparkPlan { @@ -182,6 +182,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Specifies sort order for each partition requirements on the input data for this operator. */ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) + @transient + private val executeRDD = LazyTry { + doExecute() + } + /** * Returns the result of this query as an RDD[InternalRow] by delegating to `doExecute` after * preparations. @@ -192,7 +197,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ if (isCanonicalizedPlan) { throw SparkException.internalError("A canonicalized plan is not supposed to be executed.") } - doExecute() + executeRDD.get + } + + private val executeBroadcastBcast = LazyTry { + doExecuteBroadcast() } /** @@ -205,7 +214,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ if (isCanonicalizedPlan) { throw SparkException.internalError("A canonicalized plan is not supposed to be executed.") } - doExecuteBroadcast() + executeBroadcastBcast.get.asInstanceOf[broadcast.Broadcast[T]] + } + + private val executeColumnarRDD = LazyTry { + doExecuteColumnar() } /** @@ -219,7 +232,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ if (isCanonicalizedPlan) { throw SparkException.internalError("A canonicalized plan is not supposed to be executed.") } - doExecuteColumnar() + executeColumnarRDD.get } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index a8261e5d98ba0..9fbe400a555fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -27,7 +27,7 @@ import org.antlr.v4.runtime.tree.TerminalNode import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, PersistedView, SchemaEvolution, SchemaTypeEvolution, UnresolvedFunctionName, UnresolvedIdentifier, UnresolvedNamespace} +import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, PersistedView, PlanWithUnresolvedIdentifier, SchemaEvolution, SchemaTypeEvolution, UnresolvedFunctionName, UnresolvedIdentifier, UnresolvedNamespace} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.catalyst.parser._ @@ -67,6 +67,25 @@ class SparkSqlAstBuilder extends AstBuilder { private val 
configValueDef = """([^;]*);*""".r private val strLiteralDef = """(".*?[^\\]"|'.*?[^\\]'|[^ \n\r\t"']+)""".r + private def withCatalogIdentClause( + ctx: CatalogIdentifierReferenceContext, + builder: Seq[String] => LogicalPlan): LogicalPlan = { + val exprCtx = ctx.expression + if (exprCtx != null) { + // resolve later in analyzer + PlanWithUnresolvedIdentifier(withOrigin(exprCtx) { expression(exprCtx) }, Nil, + (ident, _) => builder(ident)) + } else if (ctx.errorCapturingIdentifier() != null) { + // resolve immediately + builder.apply(Seq(ctx.errorCapturingIdentifier().getText)) + } else if (ctx.stringLit() != null) { + // resolve immediately + builder.apply(Seq(string(visitStringLit(ctx.stringLit())))) + } else { + throw SparkException.internalError("Invalid catalog name") + } + } + /** * Create a [[SetCommand]] logical plan. * @@ -149,6 +168,10 @@ class SparkSqlAstBuilder extends AstBuilder { * }}} */ override def visitSetCollation(ctx: SetCollationContext): LogicalPlan = withOrigin(ctx) { + val collationName = ctx.collationName.getText + if (!SQLConf.get.trimCollationEnabled && collationName.toUpperCase().contains("TRIM")) { + throw QueryCompilationErrors.trimCollationNotEnabledError() + } val key = SQLConf.DEFAULT_COLLATION.key SetCommand(Some(key -> Some(ctx.identifier.getText.toUpperCase(Locale.ROOT)))) } @@ -166,10 +189,29 @@ class SparkSqlAstBuilder extends AstBuilder { val key = SQLConf.SESSION_LOCAL_TIMEZONE.key if (ctx.interval != null) { val interval = parseIntervalLiteral(ctx.interval) - if (interval.months != 0 || interval.days != 0 || - math.abs(interval.microseconds) > 18 * DateTimeConstants.MICROS_PER_HOUR || - interval.microseconds % DateTimeConstants.MICROS_PER_SECOND != 0) { - throw QueryParsingErrors.intervalValueOutOfRangeError(ctx.interval()) + if (interval.months != 0) { + throw QueryParsingErrors.intervalValueOutOfRangeError( + toSQLValue(interval.months), + ctx.interval() + ) + } + else if (interval.days != 0) { + throw QueryParsingErrors.intervalValueOutOfRangeError( + toSQLValue(interval.days), + ctx.interval() + ) + } + else if (math.abs(interval.microseconds) > 18 * DateTimeConstants.MICROS_PER_HOUR) { + throw QueryParsingErrors.intervalValueOutOfRangeError( + toSQLValue((math.abs(interval.microseconds) / DateTimeConstants.MICROS_PER_HOUR).toInt), + ctx.interval() + ) + } + else if (interval.microseconds % DateTimeConstants.MICROS_PER_SECOND != 0) { + throw QueryParsingErrors.intervalValueOutOfRangeError( + toSQLValue((interval.microseconds / DateTimeConstants.MICROS_PER_SECOND).toInt), + ctx.interval() + ) } else { val seconds = (interval.microseconds / DateTimeConstants.MICROS_PER_SECOND).toInt SetCommand(Some(key -> Some(ZoneOffset.ofTotalSeconds(seconds).toString))) @@ -276,13 +318,13 @@ class SparkSqlAstBuilder extends AstBuilder { * Create a [[SetCatalogCommand]] logical command. */ override def visitSetCatalog(ctx: SetCatalogContext): LogicalPlan = withOrigin(ctx) { - if (ctx.errorCapturingIdentifier() != null) { - SetCatalogCommand(ctx.errorCapturingIdentifier().getText) - } else if (ctx.stringLit() != null) { - SetCatalogCommand(string(visitStringLit(ctx.stringLit()))) - } else { - throw SparkException.internalError("Invalid catalog name") - } + withCatalogIdentClause(ctx.catalogIdentifierReference, identifiers => { + if (identifiers.size > 1) { + // can occur when user put multipart string in IDENTIFIER(...) 
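The SET TIME ZONE branch above now reports exactly which interval component is out of range before turning the interval into a fixed offset. A standalone sketch of that validation using plain `java.time`; the constants and the `intervalToZoneOffset` helper are local to this example, not Spark's utilities:

import java.time.ZoneOffset

object TimeZoneOffsetCheck {
  private val MicrosPerSecond = 1000000L
  private val MicrosPerHour = 3600L * MicrosPerSecond

  // Mirrors the checks above: no month or day component, at most +/-18 hours, whole seconds only.
  def intervalToZoneOffset(months: Int, days: Int, micros: Long): Either[String, ZoneOffset] = {
    if (months != 0) Left(s"month field must be 0, got $months")
    else if (days != 0) Left(s"day field must be 0, got $days")
    else if (math.abs(micros) > 18 * MicrosPerHour) Left("offset is outside the +/-18:00 range")
    else if (micros % MicrosPerSecond != 0) Left("offset must be a whole number of seconds")
    else Right(ZoneOffset.ofTotalSeconds((micros / MicrosPerSecond).toInt))
  }

  // intervalToZoneOffset(0, 0, 2 * MicrosPerHour) == Right(ZoneOffset.ofHours(2))
  // intervalToZoneOffset(0, 1, 0)                 == Left("day field must be 0, got 1")
}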
clause + throw QueryParsingErrors.invalidNameForSetCatalog(identifiers, ctx) + } + SetCatalogCommand(identifiers.head) + }) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6d940a30619fb..53c335c1eced6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -269,8 +269,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + def canMerge(joinType: JoinType): Boolean = joinType match { + case LeftSingle => false + case _ => true + } + def createSortMergeJoin() = { - if (RowOrdering.isOrderable(leftKeys)) { + if (canMerge(joinType) && RowOrdering.isOrderable(leftKeys)) { Some(Seq(joins.SortMergeJoinExec( leftKeys, rightKeys, joinType, nonEquiCond, planLater(left), planLater(right)))) } else { @@ -297,7 +302,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // This join could be very slow or OOM // Build the smaller side unless the join requires a particular build side // (e.g. NO_BROADCAST_AND_REPLICATION hint) - val requiredBuildSide = getBroadcastNestedLoopJoinBuildSide(hint) + val requiredBuildSide = getBroadcastNestedLoopJoinBuildSide(hint, joinType) val buildSide = requiredBuildSide.getOrElse(getSmallerSide(left, right)) Seq(joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, j.condition)) @@ -390,7 +395,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // This join could be very slow or OOM // Build the desired side unless the join requires a particular build side // (e.g. NO_BROADCAST_AND_REPLICATION hint) - val requiredBuildSide = getBroadcastNestedLoopJoinBuildSide(hint) + val requiredBuildSide = getBroadcastNestedLoopJoinBuildSide(hint, joinType) val buildSide = requiredBuildSide.getOrElse(desiredBuildSide) Seq(joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, condition)) @@ -1041,6 +1046,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case WriteFiles(child, fileFormat, partitionColumns, bucket, options, staticPartitions) => WriteFilesExec(planLater(child), fileFormat, partitionColumns, bucket, options, staticPartitions) :: Nil + case MultiResult(children) => + MultiResultExec(children.map(planLater)) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala index df4d895867586..5f2638655c37c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala @@ -30,7 +30,7 @@ case class PlanAdaptiveSubqueries( def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressionsWithPruning( _.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) { - case expressions.ScalarSubquery(_, _, exprId, _, _, _) => + case expressions.ScalarSubquery(_, _, exprId, _, _, _, _) => val subquery = SubqueryExec.createForScalarSubquery( s"subquery#${exprId.id}", subqueryMap(exprId.id)) execution.ScalarSubquery(subquery, exprId) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 45a71b4da7287..19a36483abe6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -173,7 +173,8 @@ abstract class HashMapGenerator( ${hashBytes(bytes)} """ } - case st: StringType if st.supportsBinaryEquality => hashBytes(s"$input.getBytes()") + case st: StringType if st.supportsBinaryEquality => + hashBytes(s"$input.getBytes()") case st: StringType if !st.supportsBinaryEquality => hashLong(s"CollationFactory.fetchCollation(${st.collationId})" + s".hashFunction.applyAsLong($input)") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index cfcfd282e5480..cbd60804b27e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -99,48 +99,6 @@ case class InMemoryTableScanExec( relation.cacheBuilder.serializer.supportsColumnarOutput(relation.schema) } - private lazy val columnarInputRDD: RDD[ColumnarBatch] = { - val numOutputRows = longMetric("numOutputRows") - val buffers = filteredCachedBatches() - relation.cacheBuilder.serializer.convertCachedBatchToColumnarBatch( - buffers, - relation.output, - attributes, - conf).map { cb => - numOutputRows += cb.numRows() - cb - } - } - - private lazy val inputRDD: RDD[InternalRow] = { - if (enableAccumulatorsForTest) { - readPartitions.setValue(0) - readBatches.setValue(0) - } - - val numOutputRows = longMetric("numOutputRows") - // Using these variables here to avoid serialization of entire objects (if referenced - // directly) within the map Partitions closure. - val relOutput = relation.output - val serializer = relation.cacheBuilder.serializer - - // update SQL metrics - val withMetrics = - filteredCachedBatches().mapPartitionsInternal { iter => - if (enableAccumulatorsForTest && iter.hasNext) { - readPartitions.add(1) - } - iter.map { batch => - if (enableAccumulatorsForTest) { - readBatches.add(1) - } - numOutputRows += batch.numRows - batch - } - } - serializer.convertCachedBatchToInternalRow(withMetrics, relOutput, attributes, conf) - } - override def output: Seq[Attribute] = attributes private def cachedPlan = relation.cachedPlan match { @@ -191,11 +149,47 @@ case class InMemoryTableScanExec( } protected override def doExecute(): RDD[InternalRow] = { - inputRDD + // Resulting RDD is cached and reused by SparkPlan.executeRDD + if (enableAccumulatorsForTest) { + readPartitions.setValue(0) + readBatches.setValue(0) + } + + val numOutputRows = longMetric("numOutputRows") + // Using these variables here to avoid serialization of entire objects (if referenced + // directly) within the map Partitions closure. 
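The comment above refers to a common closure-serialization trick: copy the fields the lambda needs into local vals so the closure handed to `mapPartitionsInternal` captures only those values instead of the whole operator. A small illustration, independent of Spark; the class and method names here are made up for the example:

import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import scala.util.Try

// A "plan node" that is not serializable itself, but whose fields are.
class NotSerializableOp(val columns: Seq[String]) {
  // Captures `this` (the whole operator), because `columns` is a field access.
  def badClosure: Int => String = i => columns(i % columns.size)

  // Copies the field into a local val first, so only the Seq is captured.
  def goodClosure: Int => String = {
    val cols = columns
    i => cols(i % cols.size)
  }
}

object ClosureCaptureCheck {
  def serializes(obj: AnyRef): Boolean =
    Try(new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(obj)).isSuccess

  // val op = new NotSerializableOp(Seq("a", "b"))
  // serializes(op.goodClosure)  // true: the closure only holds a Seq[String]
  // serializes(op.badClosure)   // false: the closure holds the non-serializable operator
}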
+ val relOutput = relation.output + val serializer = relation.cacheBuilder.serializer + + // update SQL metrics + val withMetrics = + filteredCachedBatches().mapPartitionsInternal { iter => + if (enableAccumulatorsForTest && iter.hasNext) { + readPartitions.add(1) + } + iter.map { batch => + if (enableAccumulatorsForTest) { + readBatches.add(1) + } + numOutputRows += batch.numRows + batch + } + } + serializer.convertCachedBatchToInternalRow(withMetrics, relOutput, attributes, conf) } protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { - columnarInputRDD + // Resulting RDD is cached and reused by SparkPlan.executeColumnarRDD + val numOutputRows = longMetric("numOutputRows") + val buffers = filteredCachedBatches() + relation.cacheBuilder.serializer.convertCachedBatchToColumnarBatch( + buffers, + relation.output, + attributes, + conf).map { cb => + numOutputRows += cb.numRows() + cb + } } override def isMaterialized: Boolean = relation.cacheBuilder.isCachedColumnBuffersLoaded diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index ea2736b2c1266..ea9d53190546e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, SupervisingCommand} +import org.apache.spark.sql.catalyst.plans.logical.{Command, ExecutableDuringAnalysis, LogicalPlan, SupervisingCommand} import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike} import org.apache.spark.sql.connector.ExternalCommandRunner import org.apache.spark.sql.execution.{CommandExecutionMode, ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode} @@ -165,14 +165,19 @@ case class ExplainCommand( // Run through the optimizer to generate the physical plan. 
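The `ExplainCommand` change that follows rewrites the logical plan with a partial function (`transform { case p: ExecutableDuringAnalysis => p.stageForExplain() }`), replacing only the nodes the pattern matches. A toy sketch of that rewrite-by-partial-function pattern on a minimal tree type; `Plan`, `Parent` and `AnalysisTimeCommand` are illustrative stand-ins, not Catalyst classes:

object ExplainStagingSketch {
  sealed trait Plan {
    def children: Seq[Plan]
    def withChildren(newChildren: Seq[Plan]): Plan
    // Top-down rewrite: apply the rule where it matches, keep other nodes, then recurse.
    def transformDown(rule: PartialFunction[Plan, Plan]): Plan = {
      val rewritten = rule.applyOrElse(this, identity[Plan])
      rewritten.withChildren(rewritten.children.map(_.transformDown(rule)))
    }
  }

  case class Parent(child: Plan) extends Plan {
    def children: Seq[Plan] = Seq(child)
    def withChildren(newChildren: Seq[Plan]): Plan = Parent(newChildren.head)
  }

  case class AnalysisTimeCommand(staged: Boolean = false) extends Plan {
    def children: Seq[Plan] = Nil
    def withChildren(newChildren: Seq[Plan]): Plan = this
    def stageForExplain(): AnalysisTimeCommand = copy(staged = true)
  }

  // Parent(AnalysisTimeCommand()).transformDown { case c: AnalysisTimeCommand => c.stageForExplain() }
  //   == Parent(AnalysisTimeCommand(staged = true))
}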
override def run(sparkSession: SparkSession): Seq[Row] = try { - val outputString = sparkSession.sessionState.executePlan(logicalPlan, CommandExecutionMode.SKIP) - .explainString(mode) + val stagedLogicalPlan = stageForAnalysis(logicalPlan) + val qe = sparkSession.sessionState.executePlan(stagedLogicalPlan, CommandExecutionMode.SKIP) + val outputString = qe.explainString(mode) Seq(Row(outputString)) } catch { case NonFatal(cause) => ("Error occurred during query planning: \n" + cause.getMessage).split("\n") .map(Row(_)).toImmutableArraySeq } + private def stageForAnalysis(plan: LogicalPlan): LogicalPlan = plan transform { + case p: ExecutableDuringAnalysis => p.stageForExplain() + } + def withTransformedSupervisedPlan(transformer: LogicalPlan => LogicalPlan): LogicalPlan = copy(logicalPlan = transformer(logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala index 4fa1e0c1f2c58..fd47feef25d57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.util.SchemaUtils object BucketingUtils { // The file name of bucketed data should have 3 parts: @@ -53,10 +54,7 @@ object BucketingUtils { bucketIdGenerator(mutableInternalRow).getInt(0) } - def canBucketOn(dataType: DataType): Boolean = dataType match { - case st: StringType => st.supportsBinaryOrdering - case other => true - } + def canBucketOn(dataType: DataType): Boolean = !SchemaUtils.hasNonUTF8BinaryCollation(dataType) def bucketIdToString(id: Int): String = f"_$id%05d" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 968c204841e46..3698dc2f0808e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -514,7 +514,8 @@ case class DataSource( dataSource.createRelation( sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data)) case format: FileFormat => - disallowWritingIntervals(outputColumns.map(_.dataType), forbidAnsiIntervals = false) + disallowWritingIntervals( + outputColumns.toStructType.asNullable, format.toString, forbidAnsiIntervals = false) val cmd = planForWritingFileFormat(format, mode, data) val qe = sparkSession.sessionState.executePlan(cmd) qe.assertCommandExecuted() @@ -539,8 +540,8 @@ case class DataSource( } SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode) case format: FileFormat => - disallowWritingIntervals(data.schema.map(_.dataType), forbidAnsiIntervals = false) - DataSource.validateSchema(data.schema, sparkSession.sessionState.conf) + disallowWritingIntervals(data.schema, format.toString, forbidAnsiIntervals = false) + DataSource.validateSchema(format.toString, data.schema, sparkSession.sessionState.conf) planForWritingFileFormat(format, mode, 
data) case _ => throw SparkException.internalError( s"${providingClass.getCanonicalName} does not allow create table as select.") @@ -566,12 +567,15 @@ case class DataSource( } private def disallowWritingIntervals( - dataTypes: Seq[DataType], + outputColumns: Seq[StructField], + format: String, forbidAnsiIntervals: Boolean): Unit = { - dataTypes.foreach( - TypeUtils.invokeOnceForInterval(_, forbidAnsiIntervals) { - throw QueryCompilationErrors.cannotSaveIntervalIntoExternalStorageError() - }) + outputColumns.foreach { field => + TypeUtils.invokeOnceForInterval(field.dataType, forbidAnsiIntervals) { + throw QueryCompilationErrors.dataTypeUnsupportedByDataSourceError( + format, field + )} + } } } @@ -838,7 +842,7 @@ object DataSource extends Logging { * @param schema * @param conf */ - def validateSchema(schema: StructType, conf: SQLConf): Unit = { + def validateSchema(formatName: String, schema: StructType, conf: SQLConf): Unit = { val shouldAllowEmptySchema = conf.getConf(SQLConf.ALLOW_EMPTY_SCHEMAS_FOR_WRITES) def hasEmptySchema(schema: StructType): Boolean = { schema.size == 0 || schema.exists { @@ -849,7 +853,7 @@ object DataSource extends Logging { if (!shouldAllowEmptySchema && hasEmptySchema(schema)) { - throw QueryCompilationErrors.writeEmptySchemasUnsupportedByDataSourceError() + throw QueryCompilationErrors.writeEmptySchemasUnsupportedByDataSourceError(formatName) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2be4b236872f0..a2707da2d1023 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -28,7 +28,7 @@ import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.PREDICATES import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, FullQualifiedTableName, InternalRow, SQLConfHelper} +import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow, QualifiedTableName, SQLConfHelper} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ @@ -249,7 +249,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] private def readDataSourceTable( table: CatalogTable, extraOptions: CaseInsensitiveStringMap): LogicalPlan = { val qualifiedTableName = - FullQualifiedTableName(table.identifier.catalog.get, table.database, table.identifier.table) + QualifiedTableName(table.identifier.catalog.get, table.database, table.identifier.table) val catalog = sparkSession.sessionState.catalog val dsOptions = DataSourceUtils.generateDatasourceOptions(extraOptions, table) catalog.getCachedPlan(qualifiedTableName, () => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 91749ddd794fb..5e6107c4f49c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -383,32 +383,41 @@ object FileFormatWriter extends Logging { 
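The `FileFormatWriter` change that follows constructs `dataWriter` inside the guarded block, so the failure path has to cope with a writer that was never created (falling back to a committer-level abort), and the finally block must null-check before closing. A minimal sketch of that shape, with illustrative names only:

object WriteTaskSketch {
  trait TaskWriter extends AutoCloseable {
    def writeAll(rows: Iterator[String]): Unit
    def commit(): Unit
    def abort(): Unit
  }

  def runWriteTask(
      rows: Iterator[String],
      createWriter: () => TaskWriter,     // construction itself may throw
      abortTaskAttempt: () => Unit): Unit = {
    var writer: TaskWriter = null
    try {
      writer = createWriter()
      writer.writeAll(rows)
      writer.commit()
    } catch {
      case e: Throwable =>
        // Let the writer clean up its own output if it exists; otherwise abort at the task level.
        if (writer != null) writer.abort() else abortTaskAttempt()
        throw e
    } finally {
      if (writer != null) writer.close()
    }
  }
}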
committer.setupTask(taskAttemptContext) - val dataWriter = - if (sparkPartitionId != 0 && !iterator.hasNext) { - // In case of empty job, leave first partition to save meta for file format like parquet. - new EmptyDirectoryDataWriter(description, taskAttemptContext, committer) - } else if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) { - new SingleDirectoryDataWriter(description, taskAttemptContext, committer) - } else { - concurrentOutputWriterSpec match { - case Some(spec) => - new DynamicPartitionDataConcurrentWriter( - description, taskAttemptContext, committer, spec) - case _ => - new DynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) - } - } + var dataWriter: FileFormatDataWriter = null Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + dataWriter = + if (sparkPartitionId != 0 && !iterator.hasNext) { + // In case of empty job, leave first partition to save meta for file format like parquet. + new EmptyDirectoryDataWriter(description, taskAttemptContext, committer) + } else if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) { + new SingleDirectoryDataWriter(description, taskAttemptContext, committer) + } else { + concurrentOutputWriterSpec match { + case Some(spec) => + new DynamicPartitionDataConcurrentWriter( + description, taskAttemptContext, committer, spec) + case _ => + new DynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) + } + } + // Execute the task to write rows out and commit the task. dataWriter.writeWithIterator(iterator) dataWriter.commit() })(catchBlock = { // If there is an error, abort the task - dataWriter.abort() - logError(log"Job ${MDC(JOB_ID, jobId)} aborted.") + if (dataWriter != null) { + dataWriter.abort() + } else { + committer.abortTask(taskAttemptContext) + } + logError(log"Job: ${MDC(JOB_ID, jobId)}, Task: ${MDC(TASK_ID, taskId)}, " + + log"Task attempt ${MDC(TASK_ATTEMPT_ID, taskAttemptId)} aborted.") }, finallyBlock = { - dataWriter.close() + if (dataWriter != null) { + dataWriter.close() + } }) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index ffdca65151052..402b70065d8e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -173,14 +173,9 @@ object PartitioningUtils extends SQLConfHelper { // "hdfs://host:9000/path" // TODO: Selective case sensitivity. val discoveredBasePaths = optDiscoveredBasePaths.flatten.map(_.toString.toLowerCase()) - assert( - ignoreInvalidPartitionPaths || discoveredBasePaths.distinct.size == 1, - "Conflicting directory structures detected. Suspicious paths:\b" + - discoveredBasePaths.distinct.mkString("\n\t", "\n\t", "\n\n") + - "If provided paths are partition directories, please set " + - "\"basePath\" in the options of the data source to specify the " + - "root directory of the table. 
If there are multiple root directories, " + - "please load them separately and then union them.") + if (!ignoreInvalidPartitionPaths && discoveredBasePaths.distinct.size != 1) { + throw QueryExecutionErrors.conflictingDirectoryStructuresError(discoveredBasePaths) + } val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues, caseSensitive) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index f7d2d61eab653..7946068b9452e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -1262,13 +1262,14 @@ object JdbcUtils extends Logging with SQLConfHelper { errorClass: String, messageParameters: Map[String, String], dialect: JdbcDialect, - description: String)(f: => T): T = { + description: String, + isRuntime: Boolean)(f: => T): T = { try { f } catch { case e: SparkThrowable with Throwable => throw e case e: Throwable => - throw dialect.classifyException(e, errorClass, messageParameters, description) + throw dialect.classifyException(e, errorClass, messageParameters, description, isRuntime) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 7c98c31bba220..cb4c4f5290880 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -164,7 +164,8 @@ object MultiLineJsonDataSource extends JsonDataSource { .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) SQLExecution.withSQLConfPropagated(sparkSession) { - new JsonInferSchema(parsedOptions).infer[PortableDataStream](sampled, parser) + new JsonInferSchema(parsedOptions) + .infer[PortableDataStream](sampled, parser, isReadFile = true) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 29385904a7525..cbbf9f88f89d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -89,9 +89,9 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { LogicalRelation(ds.resolveRelation()) } catch { case _: ClassNotFoundException => u - case e: SparkIllegalArgumentException if e.getErrorClass != null => + case e: SparkIllegalArgumentException if e.getCondition != null => u.failAnalysis( - errorClass = e.getErrorClass, + errorClass = e.getCondition, messageParameters = e.getMessageParameters.asScala.toMap, cause = e) case e: Exception if !e.isInstanceOf[AnalysisException] => @@ -469,8 +469,8 @@ object PreprocessTableInsertion extends ResolveInsertionBase { supportColDefaultValue = true) } catch { case e: AnalysisException if staticPartCols.nonEmpty && - (e.getErrorClass == "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS" || - e.getErrorClass == "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS") => + (e.getCondition == "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS" || + e.getCondition == "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS") => val newException = e.copy( 
errorClass = Some("INSERT_PARTITION_COLUMN_ARITY_MISMATCH"), messageParameters = e.messageParameters ++ Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index d7f46c32f99a0..76cd33b815edd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -32,8 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.{toPrettySQL, GeneratedColumn, - IdentityColumn, ResolveDefaultColumns, V2ExpressionBuilder} +import org.apache.spark.sql.catalyst.util.{toPrettySQL, GeneratedColumn, IdentityColumn, ResolveDefaultColumns, V2ExpressionBuilder} import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsDeleteV2, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog, TruncatableTable} import org.apache.spark.sql.connector.catalog.index.SupportsIndex import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue} @@ -554,6 +553,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat systemScope, pattern) :: Nil + case c: Call => + ExplainOnlySparkPlan(c) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala index 7f7f280d8cdca..7cfd601ef774f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, ResolveDefaultColumns} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsMetadataColumns, SupportsRead, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.{ClusterByTransform, IdentityTransform} import org.apache.spark.sql.connector.read.SupportsReportStatistics +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.ArrayImplicits._ @@ -156,9 +157,12 @@ case class DescribeTableExec( .map(_.asInstanceOf[IdentityTransform].ref.fieldNames()) .map { fieldNames => val nestedField = table.schema.findNestedField(fieldNames.toImmutableArraySeq) - assert(nestedField.isDefined, - s"Not found the partition column ${fieldNames.map(quoteIfNeeded).mkString(".")} " + - s"in the table schema ${table.schema().catalogString}.") + if (nestedField.isEmpty) { + throw QueryExecutionErrors.partitionColumnNotFoundInTheTableSchemaError( + fieldNames.toSeq, + table.schema() + ) + } nestedField.get }.map { case (path, field) => toCatalystRow( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala new file mode 100644 index 0000000000000..bbf56eaa71184 --- /dev/null +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.LeafLike +import org.apache.spark.sql.execution.SparkPlan + +case class ExplainOnlySparkPlan(toExplain: LogicalPlan) extends SparkPlan with LeafLike[SparkPlan] { + + override def output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + toExplain.simpleString(maxFields) + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala index 168aea5b041f8..4242fc5d8510a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala @@ -131,7 +131,7 @@ object FileDataSourceV2 { // The error is already FAILED_READ_FILE, throw it directly. To be consistent, schema // inference code path throws `FAILED_READ_FILE`, but the file reading code path can reach // that code path as well and we should not double-wrap the error. 
- case e: SparkException if e.getErrorClass == "FAILED_READ_FILE.CANNOT_READ_FILE_FOOTER" => + case e: SparkException if e.getCondition == "FAILED_READ_FILE.CANNOT_READ_FILE_FOOTER" => throw e case e: SchemaColumnConvertNotSupportedException => throw QueryExecutionErrors.parquetColumnDataTypeMismatchError( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index d890107277d6c..5c0f8c0a4afd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -164,7 +164,7 @@ trait FileScan extends Scan if (splitFiles.length == 1) { val path = splitFiles(0).toPath if (!isSplitable(path) && splitFiles(0).length > - sparkSession.sparkContext.getConf.get(IO_WARNING_LARGEFILETHRESHOLD)) { + sparkSession.sparkContext.conf.get(IO_WARNING_LARGEFILETHRESHOLD)) { logWarning(log"Loading one large unsplittable file ${MDC(PATH, path.toString)} with only " + log"one partition, the reason is: ${MDC(REASON, getFileUnSplittableReason(path))}") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala index cdcf6f21fd008..f4cabcb69d08c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala @@ -96,7 +96,7 @@ trait FileWrite extends Write { SchemaUtils.checkColumnNameDuplication( schema.fields.map(_.name).toImmutableArraySeq, caseSensitiveAnalysis) } - DataSource.validateSchema(schema, sqlConf) + DataSource.validateSchema(formatName, schema, sqlConf) // TODO: [SPARK-36340] Unify check schema filed of DataSource V2 Insert. schema.foreach { field => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index bd1df87d15c3c..22c13fd98ced1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import scala.jdk.CollectionConverters._ import org.apache.spark.SparkUnsupportedOperationException -import org.apache.spark.sql.catalyst.{FullQualifiedTableName, FunctionIdentifier, SQLConfHelper, TableIdentifier} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, QualifiedTableName, SQLConfHelper, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils, ClusterBySpec, SessionCatalog} import org.apache.spark.sql.catalyst.util.TypeUtils._ @@ -93,7 +93,7 @@ class V2SessionCatalog(catalog: SessionCatalog) // table here. To avoid breaking it we do not resolve the table provider and still return // `V1Table` if the custom session catalog is present. 
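The JDBC catalog changes further below thread an extra `isRuntime` flag into the dialect's exception classifier, so each call site can say whether a failure belongs to analysis or to execution. A hedged, standalone sketch of that wrap-and-classify pattern; the error types and helper are illustrative, not Spark's API:

object JdbcErrorSketch {
  sealed trait ErrorKind
  case object AnalysisError extends ErrorKind
  case object RuntimeError extends ErrorKind

  final class ClassifiedException(kind: ErrorKind, description: String, cause: Throwable)
    extends RuntimeException(s"[$kind] $description", cause)

  // Run `f`; rethrow errors that are already classified, wrap everything else with a
  // description plus the caller-supplied analysis/runtime flag.
  def classifyExceptions[T](description: String, isRuntime: Boolean)(f: => T): T = {
    try f catch {
      case e: ClassifiedException => throw e
      case e: Throwable =>
        val kind = if (isRuntime) RuntimeError else AnalysisError
        throw new ClassifiedException(kind, description, e)
    }
  }

  // classifyExceptions("Failed to create index idx in tbl", isRuntime = false) {
  //   // JDBC metadata call goes here
  // }
}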
if (table.provider.isDefined && !hasCustomSessionCatalog) { - val qualifiedTableName = FullQualifiedTableName( + val qualifiedTableName = QualifiedTableName( table.identifier.catalog.get, table.database, table.identifier.table) // Check if the table is in the v1 table cache to skip the v2 table lookup. if (catalog.getCachedTable(qualifiedTableName) != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala index 6828bb0f0c4d8..20283cc124596 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala @@ -70,7 +70,8 @@ case class JDBCTable(ident: Identifier, schema: StructType, jdbcOptions: JDBCOpt "indexName" -> toSQLId(indexName), "tableName" -> toSQLId(name)), dialect = JdbcDialects.get(jdbcOptions.url), - description = s"Failed to create index $indexName in ${name()}") { + description = s"Failed to create index $indexName in ${name()}", + isRuntime = false) { JdbcUtils.createIndex( conn, indexName, ident, columns, columnsProperties, properties, jdbcOptions) } @@ -92,7 +93,8 @@ case class JDBCTable(ident: Identifier, schema: StructType, jdbcOptions: JDBCOpt "indexName" -> toSQLId(indexName), "tableName" -> toSQLId(name)), dialect = JdbcDialects.get(jdbcOptions.url), - description = s"Failed to drop index $indexName in ${name()}") { + description = s"Failed to drop index $indexName in ${name()}", + isRuntime = false) { JdbcUtils.dropIndex(conn, indexName, ident, jdbcOptions) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala index 3871bdf501771..99e9abe965183 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala @@ -73,7 +73,8 @@ class JDBCTableCatalog extends TableCatalog "url" -> options.getRedactUrl(), "namespace" -> toSQLId(namespace.toSeq)), dialect, - description = s"Failed get tables from: ${namespace.mkString(".")}") { + description = s"Failed get tables from: ${namespace.mkString(".")}", + isRuntime = false) { conn.getMetaData.getTables(null, schemaPattern, "%", Array("TABLE")) } new Iterator[Identifier] { @@ -93,7 +94,8 @@ class JDBCTableCatalog extends TableCatalog "url" -> options.getRedactUrl(), "tableName" -> toSQLId(ident)), dialect, - description = s"Failed table existence check: $ident") { + description = s"Failed table existence check: $ident", + isRuntime = false) { JdbcUtils.withConnection(options)(JdbcUtils.tableExists(_, writeOptions)) } } @@ -120,7 +122,8 @@ class JDBCTableCatalog extends TableCatalog "oldName" -> toSQLId(oldIdent), "newName" -> toSQLId(newIdent)), dialect, - description = s"Failed table renaming from $oldIdent to $newIdent") { + description = s"Failed table renaming from $oldIdent to $newIdent", + isRuntime = false) { JdbcUtils.renameTable(conn, oldIdent, newIdent, options) } } @@ -136,7 +139,8 @@ class JDBCTableCatalog extends TableCatalog "url" -> options.getRedactUrl(), "tableName" -> toSQLId(ident)), dialect, - description = s"Failed to load table: $ident" + description = s"Failed to load table: $ident", + isRuntime = false ) { val schema = 
JDBCRDD.resolveTable(optionsWithTableName) JDBCTable(ident, schema, optionsWithTableName) @@ -192,7 +196,8 @@ class JDBCTableCatalog extends TableCatalog "url" -> options.getRedactUrl(), "tableName" -> toSQLId(ident)), dialect, - description = s"Failed table creation: $ident") { + description = s"Failed table creation: $ident", + isRuntime = false) { JdbcUtils.createTable(conn, getTableName(ident), schema, caseSensitive, writeOptions) } } @@ -209,7 +214,8 @@ class JDBCTableCatalog extends TableCatalog "url" -> options.getRedactUrl(), "tableName" -> toSQLId(ident)), dialect, - description = s"Failed table altering: $ident") { + description = s"Failed table altering: $ident", + isRuntime = false) { JdbcUtils.alterTable(conn, getTableName(ident), changes, options) } loadTable(ident) @@ -225,7 +231,8 @@ class JDBCTableCatalog extends TableCatalog "url" -> options.getRedactUrl(), "namespace" -> toSQLId(namespace.toSeq)), dialect, - description = s"Failed namespace exists: ${namespace.mkString}") { + description = s"Failed namespace exists: ${namespace.mkString}", + isRuntime = false) { JdbcUtils.schemaExists(conn, options, db) } } @@ -238,7 +245,8 @@ class JDBCTableCatalog extends TableCatalog errorClass = "FAILED_JDBC.LIST_NAMESPACES", messageParameters = Map("url" -> options.getRedactUrl()), dialect, - description = s"Failed list namespaces") { + description = s"Failed list namespaces", + isRuntime = false) { JdbcUtils.listSchemas(conn, options) } } @@ -292,7 +300,8 @@ class JDBCTableCatalog extends TableCatalog "url" -> options.getRedactUrl(), "namespace" -> toSQLId(db)), dialect, - description = s"Failed create name space: $db") { + description = s"Failed create name space: $db", + isRuntime = false) { JdbcUtils.createSchema(conn, options, db, comment) } } @@ -317,7 +326,8 @@ class JDBCTableCatalog extends TableCatalog "url" -> options.getRedactUrl(), "namespace" -> toSQLId(db)), dialect, - description = s"Failed create comment on name space: $db") { + description = s"Failed create comment on name space: $db", + isRuntime = false) { JdbcUtils.alterSchemaComment(conn, options, db, set.value) } } @@ -334,7 +344,8 @@ class JDBCTableCatalog extends TableCatalog "url" -> options.getRedactUrl(), "namespace" -> toSQLId(db)), dialect, - description = s"Failed remove comment on name space: $db") { + description = s"Failed remove comment on name space: $db", + isRuntime = false) { JdbcUtils.removeSchemaComment(conn, options, db) } } @@ -362,7 +373,8 @@ class JDBCTableCatalog extends TableCatalog "url" -> options.getRedactUrl(), "namespace" -> toSQLId(db)), dialect, - description = s"Failed drop name space: $db") { + description = s"Failed drop name space: $db", + isRuntime = false) { JdbcUtils.dropSchema(conn, options, db, cascade) true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 50b90641d309b..edddfbd6ccaef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -29,15 +29,16 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.DataSourceOptions import org.apache.spark.sql.connector.catalog.{Table, TableProvider} import org.apache.spark.sql.connector.expressions.Transform -import 
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues, STATE_VAR_NAME} +import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues, READ_REGISTERED_TIMERS, STATE_VAR_NAME} import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues import org.apache.spark.sql.execution.datasources.v2.state.metadata.{StateMetadataPartitionReader, StateMetadataTableEntry} import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil -import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata, TransformWithStateOperatorProperties, TransformWithStateVariableInfo} +import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata, StateVariableType, TimerStateUtils, TransformWithStateOperatorProperties, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE} import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateSchemaCompatibilityChecker, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId} import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.streaming.TimeMode import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -132,7 +133,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging sourceOptions: StateSourceOptions, stateStoreMetadata: Array[StateMetadataTableEntry]): Unit = { val twsShortName = "transformWithStateExec" - if (sourceOptions.stateVarName.isDefined) { + if (sourceOptions.stateVarName.isDefined || sourceOptions.readRegisteredTimers) { // Perform checks for transformWithState operator in case state variable name is provided require(stateStoreMetadata.size == 1) val opMetadata = stateStoreMetadata.head @@ -153,18 +154,31 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging "No state variable names are defined for the transformWithState operator") } + val twsOperatorProperties = TransformWithStateOperatorProperties.fromJson(operatorProperties) + val timeMode = twsOperatorProperties.timeMode + if (sourceOptions.readRegisteredTimers && timeMode == TimeMode.None().toString) { + throw StateDataSourceErrors.invalidOptionValue(READ_REGISTERED_TIMERS, + "Registered timers are not available in TimeMode=None.") + } + // if the state variable is not one of the defined/available state variables, then we // fail the query - val stateVarName = sourceOptions.stateVarName.get - val twsOperatorProperties = TransformWithStateOperatorProperties.fromJson(operatorProperties) + val stateVarName = if (sourceOptions.readRegisteredTimers) { + TimerStateUtils.getTimerStateVarName(timeMode) + } else { + sourceOptions.stateVarName.get + } + val stateVars = twsOperatorProperties.stateVariables - if (stateVars.filter(stateVar => stateVar.stateName == stateVarName).size != 1) { + val stateVarInfo = stateVars.filter(stateVar => stateVar.stateName == stateVarName) + if (stateVarInfo.size != 1) { throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME, s"State variable $stateVarName is not defined for the 
transformWithState operator.") } - // TODO: Support change feed and transformWithState together - if (sourceOptions.readChangeFeed) { + // TODO: add support for list and map type + if (sourceOptions.readChangeFeed && + stateVarInfo.head.stateVariableType != StateVariableType.ValueState) { throw StateDataSourceErrors.conflictOptions(Seq(StateSourceOptions.READ_CHANGE_FEED, StateSourceOptions.STATE_VAR_NAME)) } @@ -196,9 +210,10 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging var keyStateEncoderSpecOpt: Option[KeyStateEncoderSpec] = None var stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema] = None var transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo] = None + var timeMode: String = TimeMode.None.toString if (sourceOptions.joinSide == JoinSideValues.none) { - val stateVarName = sourceOptions.stateVarName + var stateVarName = sourceOptions.stateVarName .getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME) // Read the schema file path from operator metadata version v2 onwards @@ -208,6 +223,12 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging val storeMetadataEntry = storeMetadata.head val operatorProperties = TransformWithStateOperatorProperties.fromJson( storeMetadataEntry.operatorPropertiesJson) + timeMode = operatorProperties.timeMode + + if (sourceOptions.readRegisteredTimers) { + stateVarName = TimerStateUtils.getTimerStateVarName(timeMode) + } + val stateVarInfoList = operatorProperties.stateVariables .filter(stateVar => stateVar.stateName == stateVarName) require(stateVarInfoList.size == 1, s"Failed to find unique state variable info " + @@ -303,13 +324,16 @@ case class StateSourceOptions( readChangeFeed: Boolean, fromSnapshotOptions: Option[FromSnapshotOptions], readChangeFeedOptions: Option[ReadChangeFeedOptions], - stateVarName: Option[String]) { + stateVarName: Option[String], + readRegisteredTimers: Boolean, + flattenCollectionTypes: Boolean) { def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE) override def toString: String = { var desc = s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " + s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " + - s"stateVarName=${stateVarName.getOrElse("None")}" + s"stateVarName=${stateVarName.getOrElse("None")}, +" + + s"flattenCollectionTypes=$flattenCollectionTypes" if (fromSnapshotOptions.isDefined) { desc += s", snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}" desc += s", snapshotPartitionId=${fromSnapshotOptions.get.snapshotPartitionId}" @@ -334,6 +358,8 @@ object StateSourceOptions extends DataSourceOptions { val CHANGE_START_BATCH_ID = newOption("changeStartBatchId") val CHANGE_END_BATCH_ID = newOption("changeEndBatchId") val STATE_VAR_NAME = newOption("stateVarName") + val READ_REGISTERED_TIMERS = newOption("readRegisteredTimers") + val FLATTEN_COLLECTION_TYPES = newOption("flattenCollectionTypes") object JoinSideValues extends Enumeration { type JoinSideValues = Value @@ -374,6 +400,28 @@ object StateSourceOptions extends DataSourceOptions { val stateVarName = Option(options.get(STATE_VAR_NAME)) .map(_.trim) + val readRegisteredTimers = try { + Option(options.get(READ_REGISTERED_TIMERS)) + .map(_.toBoolean).getOrElse(false) + } catch { + case _: IllegalArgumentException => + throw StateDataSourceErrors.invalidOptionValue(READ_REGISTERED_TIMERS, + "Boolean value is expected") + } + + if (readRegisteredTimers && stateVarName.isDefined) { + throw 
StateDataSourceErrors.conflictOptions(Seq(READ_REGISTERED_TIMERS, STATE_VAR_NAME)) + } + + val flattenCollectionTypes = try { + Option(options.get(FLATTEN_COLLECTION_TYPES)) + .map(_.toBoolean).getOrElse(true) + } catch { + case _: IllegalArgumentException => + throw StateDataSourceErrors.invalidOptionValue(FLATTEN_COLLECTION_TYPES, + "Boolean value is expected") + } + val joinSide = try { Option(options.get(JOIN_SIDE)) .map(JoinSideValues.withName).getOrElse(JoinSideValues.none) @@ -477,7 +525,8 @@ object StateSourceOptions extends DataSourceOptions { StateSourceOptions( resolvedCpLocation, batchId.get, operatorId, storeName, joinSide, - readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, stateVarName) + readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, + stateVarName, readRegisteredTimers, flattenCollectionTypes) } private def resolvedCheckpointLocation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 24166a46bbd39..b925aee5b627a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources.v2.state import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} -import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo} @@ -75,9 +74,11 @@ abstract class StatePartitionReaderBase( StructType(Array(StructField("__dummy__", NullType))) protected val keySchema = { - if (!SchemaUtil.isMapStateVariable(stateVariableInfoOpt)) { + if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) { + SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions) + } else { SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] - } else SchemaUtil.getCompositeKeySchema(schema) + } } protected val valueSchema = if (stateVariableInfoOpt.isDefined) { @@ -98,18 +99,16 @@ abstract class StatePartitionReaderBase( false } - val useMultipleValuesPerKey = if (stateVariableInfoOpt.isDefined && - stateVariableInfoOpt.get.stateVariableType == StateVariableType.ListState) { - true - } else { - false - } + val useMultipleValuesPerKey = SchemaUtil.checkVariableType(stateVariableInfoOpt, + StateVariableType.ListState) val provider = StateStoreProvider.createAndInit( stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec, useColumnFamilies = useColFamilies, storeConf, hadoopConf.value, useMultipleValuesPerKey = useMultipleValuesPerKey) + val isInternal = partition.sourceOptions.readRegisteredTimers + if (useColFamilies) { val store = provider.getStore(partition.sourceOptions.batchId + 1) require(stateStoreColFamilySchemaOpt.isDefined) @@ -120,7 +119,8 @@ abstract class StatePartitionReaderBase( stateStoreColFamilySchema.keySchema, stateStoreColFamilySchema.valueSchema, stateStoreColFamilySchema.keyStateEncoderSpec.get, - useMultipleValuesPerKey = 
useMultipleValuesPerKey) + useMultipleValuesPerKey = useMultipleValuesPerKey, + isInternal = isInternal) } provider } @@ -149,7 +149,7 @@ abstract class StatePartitionReaderBase( /** * An implementation of [[StatePartitionReaderBase]] for the normal mode of State Data - * Source. It reads the the state at a particular batchId. + * Source. It reads the state at a particular batchId. */ class StatePartitionReader( storeConf: StateStoreConf, @@ -181,41 +181,17 @@ class StatePartitionReader( override lazy val iter: Iterator[InternalRow] = { val stateVarName = stateVariableInfoOpt .map(_.stateName).getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME) - if (SchemaUtil.isMapStateVariable(stateVariableInfoOpt)) { - SchemaUtil.unifyMapStateRowPair( - store.iterator(stateVarName), keySchema, partition.partition) + + if (stateVariableInfoOpt.isDefined) { + val stateVariableInfo = stateVariableInfoOpt.get + val stateVarType = stateVariableInfo.stateVariableType + SchemaUtil.processStateEntries(stateVarType, stateVarName, store, + keySchema, partition.partition, partition.sourceOptions) } else { store .iterator(stateVarName) .map { pair => - stateVariableInfoOpt match { - case Some(stateVarInfo) => - val stateVarType = stateVarInfo.stateVariableType - - stateVarType match { - case StateVariableType.ValueState => - SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition) - - case StateVariableType.ListState => - val key = pair.key - val result = store.valuesIterator(key, stateVarName) - var unsafeRowArr: Seq[UnsafeRow] = Seq.empty - result.foreach { entry => - unsafeRowArr = unsafeRowArr :+ entry.copy() - } - // convert the list of values to array type - val arrData = new GenericArrayData(unsafeRowArr.toArray) - SchemaUtil.unifyStateRowPairWithMultipleValues((pair.key, arrData), - partition.partition) - - case _ => - throw new IllegalStateException( - s"Unsupported state variable type: $stateVarType") - } - - case None => - SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition) - } + SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition) } } } @@ -247,10 +223,18 @@ class StateStoreChangeDataPartitionReader( throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay( provider.getClass.toString) } + + val colFamilyNameOpt = if (stateVariableInfoOpt.isDefined) { + Some(stateVariableInfoOpt.get.stateName) + } else { + None + } + provider.asInstanceOf[SupportsFineGrainedReplay] .getStateStoreChangeDataReader( partition.sourceOptions.readChangeFeedOptions.get.changeStartBatchId + 1, - partition.sourceOptions.readChangeFeedOptions.get.changeEndBatchId + 1) + partition.sourceOptions.readChangeFeedOptions.get.changeEndBatchId + 1, + colFamilyNameOpt) } override lazy val iter: Iterator[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala index 88ea06d598e56..c337d548fa42b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.v2.state.utils import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import org.apache.spark.sql.AnalysisException @@ -24,9 +25,9 @@ import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceErrors, StateSourceOptions} +import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.StateVariableType._ -import org.apache.spark.sql.execution.streaming.TransformWithStateVariableInfo -import org.apache.spark.sql.execution.streaming.state.{StateStoreColFamilySchema, UnsafeRowPair} +import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStoreColFamilySchema, UnsafeRowPair} import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, LongType, MapType, StringType, StructType} import org.apache.spark.util.ArrayImplicits._ @@ -58,7 +59,7 @@ object SchemaUtil { } else if (transformWithStateVariableInfoOpt.isDefined) { require(stateStoreColFamilySchemaOpt.isDefined) generateSchemaForStateVar(transformWithStateVariableInfoOpt.get, - stateStoreColFamilySchemaOpt.get) + stateStoreColFamilySchemaOpt.get, sourceOptions) } else { new StructType() .add("key", keySchema) @@ -101,7 +102,8 @@ object SchemaUtil { def unifyMapStateRowPair( stateRows: Iterator[UnsafeRowPair], compositeKeySchema: StructType, - partitionId: Int): Iterator[InternalRow] = { + partitionId: Int, + stateSourceOptions: StateSourceOptions): Iterator[InternalRow] = { val groupingKeySchema = SchemaUtil.getSchemaAsDataType( compositeKeySchema, "key" ).asInstanceOf[StructType] @@ -130,61 +132,84 @@ object SchemaUtil { row } - // All of the rows with the same grouping key were co-located and were - // grouped together consecutively. - new Iterator[InternalRow] { - var curGroupingKey: UnsafeRow = _ - var curStateRowPair: UnsafeRowPair = _ - val curMap = mutable.Map.empty[Any, Any] - - override def hasNext: Boolean = - stateRows.hasNext || !curMap.isEmpty - - override def next(): InternalRow = { - var foundNewGroupingKey = false - while (stateRows.hasNext && !foundNewGroupingKey) { - curStateRowPair = stateRows.next() - if (curGroupingKey == null) { - // First time in the iterator - // Need to make a copy because we need to keep the - // value across function calls - curGroupingKey = curStateRowPair.key - .get(0, groupingKeySchema).asInstanceOf[UnsafeRow].copy() - appendKVPairToMap(curMap, curStateRowPair) - } else { - val curPairGroupingKey = - curStateRowPair.key.get(0, groupingKeySchema) - if (curPairGroupingKey == curGroupingKey) { + def createFlattenedRow( + groupingKey: UnsafeRow, + userMapKey: UnsafeRow, + userMapValue: UnsafeRow, + partitionId: Int): GenericInternalRow = { + val row = new GenericInternalRow(4) + row.update(0, groupingKey) + row.update(1, userMapKey) + row.update(2, userMapValue) + row.update(3, partitionId) + row + } + + if (stateSourceOptions.flattenCollectionTypes) { + stateRows + .map { pair => + val groupingKey = pair.key.get(0, groupingKeySchema).asInstanceOf[UnsafeRow] + val userMapKey = pair.key.get(1, userKeySchema).asInstanceOf[UnsafeRow] + val userMapValue = pair.value + createFlattenedRow(groupingKey, userMapKey, userMapValue, partitionId) + } + } else { + // All of the rows with the same grouping key were co-located and were + // grouped together consecutively. 
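+        // Non-flattened path (flattenCollectionTypes = false): the iterator below folds the
+        // consecutive entries sharing a grouping key into a single Map(userKey -> value) per
+        // output row, whereas the flattened branch above emits one row per
+        // (groupingKey, userMapKey, userMapValue) entry.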
+ new Iterator[InternalRow] { + var curGroupingKey: UnsafeRow = _ + var curStateRowPair: UnsafeRowPair = _ + val curMap = mutable.Map.empty[Any, Any] + + override def hasNext: Boolean = + stateRows.hasNext || !curMap.isEmpty + + override def next(): InternalRow = { + var foundNewGroupingKey = false + while (stateRows.hasNext && !foundNewGroupingKey) { + curStateRowPair = stateRows.next() + if (curGroupingKey == null) { + // First time in the iterator + // Need to make a copy because we need to keep the + // value across function calls + curGroupingKey = curStateRowPair.key + .get(0, groupingKeySchema).asInstanceOf[UnsafeRow].copy() appendKVPairToMap(curMap, curStateRowPair) } else { - // find a different grouping key, exit loop and return a row - foundNewGroupingKey = true + val curPairGroupingKey = + curStateRowPair.key.get(0, groupingKeySchema) + if (curPairGroupingKey == curGroupingKey) { + appendKVPairToMap(curMap, curStateRowPair) + } else { + // find a different grouping key, exit loop and return a row + foundNewGroupingKey = true + } } } - } - if (foundNewGroupingKey) { - // found a different grouping key - val row = createDataRow(curGroupingKey, curMap) - // update vars - curGroupingKey = - curStateRowPair.key.get(0, groupingKeySchema) - .asInstanceOf[UnsafeRow].copy() - // empty the map, append current row - curMap.clear() - appendKVPairToMap(curMap, curStateRowPair) - // return map value of previous grouping key - row - } else { - if (curMap.isEmpty) { - throw new NoSuchElementException("Please check if the iterator hasNext(); Likely " + - "user is trying to get element from an exhausted iterator.") - } - else { - // reach the end of the state rows + if (foundNewGroupingKey) { + // found a different grouping key val row = createDataRow(curGroupingKey, curMap) - // clear the map to end the iterator + // update vars + curGroupingKey = + curStateRowPair.key.get(0, groupingKeySchema) + .asInstanceOf[UnsafeRow].copy() + // empty the map, append current row curMap.clear() + appendKVPairToMap(curMap, curStateRowPair) + // return map value of previous grouping key row + } else { + if (curMap.isEmpty) { + throw new NoSuchElementException("Please check if the iterator hasNext(); Likely " + + "user is trying to get element from an exhausted iterator.") + } + else { + // reach the end of the state rows + val row = createDataRow(curGroupingKey, curMap) + // clear the map to end the iterator + curMap.clear() + row + } } } } @@ -200,9 +225,12 @@ object SchemaUtil { "change_type" -> classOf[StringType], "key" -> classOf[StructType], "value" -> classOf[StructType], - "single_value" -> classOf[StructType], + "list_element" -> classOf[StructType], "list_value" -> classOf[ArrayType], "map_value" -> classOf[MapType], + "user_map_key" -> classOf[StructType], + "user_map_value" -> classOf[StructType], + "expiration_timestamp_ms" -> classOf[LongType], "partition_id" -> classOf[IntegerType]) val expectedFieldNames = if (sourceOptions.readChangeFeed) { @@ -213,13 +241,24 @@ object SchemaUtil { stateVarType match { case ValueState => - Seq("key", "single_value", "partition_id") + Seq("key", "value", "partition_id") case ListState => - Seq("key", "list_value", "partition_id") + if (sourceOptions.flattenCollectionTypes) { + Seq("key", "list_element", "partition_id") + } else { + Seq("key", "list_value", "partition_id") + } case MapState => - Seq("key", "map_value", "partition_id") + if (sourceOptions.flattenCollectionTypes) { + Seq("key", "user_map_key", "user_map_value", "partition_id") + } else { + 
Seq("key", "map_value", "partition_id") + } + + case TimerState => + Seq("key", "expiration_timestamp_ms", "partition_id") case _ => throw StateDataSourceErrors @@ -241,21 +280,29 @@ object SchemaUtil { private def generateSchemaForStateVar( stateVarInfo: TransformWithStateVariableInfo, - stateStoreColFamilySchema: StateStoreColFamilySchema): StructType = { + stateStoreColFamilySchema: StateStoreColFamilySchema, + stateSourceOptions: StateSourceOptions): StructType = { val stateVarType = stateVarInfo.stateVariableType stateVarType match { case ValueState => new StructType() .add("key", stateStoreColFamilySchema.keySchema) - .add("single_value", stateStoreColFamilySchema.valueSchema) + .add("value", stateStoreColFamilySchema.valueSchema) .add("partition_id", IntegerType) case ListState => - new StructType() - .add("key", stateStoreColFamilySchema.keySchema) - .add("list_value", ArrayType(stateStoreColFamilySchema.valueSchema)) - .add("partition_id", IntegerType) + if (stateSourceOptions.flattenCollectionTypes) { + new StructType() + .add("key", stateStoreColFamilySchema.keySchema) + .add("list_element", stateStoreColFamilySchema.valueSchema) + .add("partition_id", IntegerType) + } else { + new StructType() + .add("key", stateStoreColFamilySchema.keySchema) + .add("list_value", ArrayType(stateStoreColFamilySchema.valueSchema)) + .add("partition_id", IntegerType) + } case MapState => val groupingKeySchema = SchemaUtil.getSchemaAsDataType( @@ -266,9 +313,25 @@ object SchemaUtil { valueType = stateStoreColFamilySchema.valueSchema ) + if (stateSourceOptions.flattenCollectionTypes) { + new StructType() + .add("key", groupingKeySchema) + .add("user_map_key", userKeySchema) + .add("user_map_value", stateStoreColFamilySchema.valueSchema) + .add("partition_id", IntegerType) + } else { + new StructType() + .add("key", groupingKeySchema) + .add("map_value", valueMapSchema) + .add("partition_id", IntegerType) + } + + case TimerState => + val groupingKeySchema = SchemaUtil.getSchemaAsDataType( + stateStoreColFamilySchema.keySchema, "key") new StructType() .add("key", groupingKeySchema) - .add("map_value", valueMapSchema) + .add("expiration_timestamp_ms", LongType) .add("partition_id", IntegerType) case _ => @@ -276,33 +339,29 @@ object SchemaUtil { } } - /** - * Helper functions for map state data source reader. - * - * Map state variables are stored in RocksDB state store has the schema of - * `TransformWithStateKeyValueRowSchemaUtils.getCompositeKeySchema()`; - * But for state store reader, we need to return in format of: - * "key": groupingKey, "map_value": Map(userKey -> value). - * - * The following functions help to translate between two schema. 
- */ - def isMapStateVariable( - stateVariableInfoOpt: Option[TransformWithStateVariableInfo]): Boolean = { + def checkVariableType( + stateVariableInfoOpt: Option[TransformWithStateVariableInfo], + varType: StateVariableType): Boolean = { stateVariableInfoOpt.isDefined && - stateVariableInfoOpt.get.stateVariableType == MapState + stateVariableInfoOpt.get.stateVariableType == varType } /** * Given key-value schema generated from `generateSchemaForStateVar()`, * returns the compositeKey schema that key is stored in the state store */ - def getCompositeKeySchema(schema: StructType): StructType = { + def getCompositeKeySchema( + schema: StructType, + stateSourceOptions: StateSourceOptions): StructType = { val groupingKeySchema = SchemaUtil.getSchemaAsDataType( schema, "key").asInstanceOf[StructType] val userKeySchema = try { - Option( - SchemaUtil.getSchemaAsDataType(schema, "map_value").asInstanceOf[MapType] + if (stateSourceOptions.flattenCollectionTypes) { + Option(SchemaUtil.getSchemaAsDataType(schema, "user_map_key").asInstanceOf[StructType]) + } else { + Option(SchemaUtil.getSchemaAsDataType(schema, "map_value").asInstanceOf[MapType] .keyType.asInstanceOf[StructType]) + } } catch { case NonFatal(e) => throw StateDataSourceErrors.internalError(s"No such field named as 'map_value' " + @@ -312,4 +371,78 @@ object SchemaUtil { .add("key", groupingKeySchema) .add("userKey", userKeySchema.get) } + + def processStateEntries( + stateVarType: StateVariableType, + stateVarName: String, + store: ReadStateStore, + compositeKeySchema: StructType, + partitionId: Int, + stateSourceOptions: StateSourceOptions): Iterator[InternalRow] = { + stateVarType match { + case StateVariableType.ValueState => + store + .iterator(stateVarName) + .map { pair => + unifyStateRowPair((pair.key, pair.value), partitionId) + } + + case StateVariableType.ListState => + if (stateSourceOptions.flattenCollectionTypes) { + store + .iterator(stateVarName) + .flatMap { pair => + val key = pair.key + val result = store.valuesIterator(key, stateVarName) + result.map { entry => + SchemaUtil.unifyStateRowPair((key, entry), partitionId) + } + } + } else { + store + .iterator(stateVarName) + .map { pair => + val key = pair.key + val result = store.valuesIterator(key, stateVarName) + val unsafeRowArr = ArrayBuffer[UnsafeRow]() + result.foreach { entry => + unsafeRowArr += entry.copy() + } + // convert the list of values to array type + val arrData = new GenericArrayData(unsafeRowArr.toArray) + // convert the list of values to a single row + SchemaUtil.unifyStateRowPairWithMultipleValues((key, arrData), partitionId) + } + } + + case StateVariableType.MapState => + unifyMapStateRowPair(store.iterator(stateVarName), + compositeKeySchema, partitionId, stateSourceOptions) + + case StateVariableType.TimerState => + store + .iterator(stateVarName) + .map { pair => + unifyTimerRow(pair.key, compositeKeySchema, partitionId) + } + + case _ => + throw new IllegalStateException( + s"Unsupported state variable type: $stateVarType") + } + } + + private def unifyTimerRow( + rowKey: UnsafeRow, + groupingKeySchema: StructType, + partitionId: Int): InternalRow = { + val groupingKey = rowKey.get(0, groupingKeySchema).asInstanceOf[UnsafeRow] + val expirationTimestamp = rowKey.getLong(1) + + val row = new GenericInternalRow(3) + row.update(0, groupingKey) + row.update(1, expirationTimestamp) + row.update(2, partitionId) + row + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index e669165f4f2f8..8ec903f8e61da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -70,7 +70,16 @@ case class EnsureRequirements( case (child, distribution) => val numPartitions = distribution.requiredNumPartitions .getOrElse(conf.numShufflePartitions) - ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child, shuffleOrigin) + distribution match { + case _: StatefulOpClusteredDistribution => + ShuffleExchangeExec( + distribution.createPartitioning(numPartitions), child, + REQUIRED_BY_STATEFUL_OPERATOR) + + case _ => + ShuffleExchangeExec( + distribution.createPartitioning(numPartitions), child, shuffleOrigin) + } } // Get the indexes of children which have specified distribution requirements and need to be diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 90f00a5035e15..31a3f53eb7191 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -177,6 +177,11 @@ case object REBALANCE_PARTITIONS_BY_NONE extends ShuffleOrigin // the output needs to be partitioned by the given columns. case object REBALANCE_PARTITIONS_BY_COL extends ShuffleOrigin +// Indicates that the shuffle operator was added by the internal `EnsureRequirements` rule, but +// was required by a stateful operator. The physical partitioning is static and Spark shouldn't +// change it. +case object REQUIRED_BY_STATEFUL_OPERATOR extends ShuffleOrigin + /** * Performs a shuffle that will result in the desired partitioning. */ @@ -249,17 +254,10 @@ case class ShuffleExchangeExec( dep } - /** - * Caches the created ShuffleRowRDD so we can reuse that. - */ - private var cachedShuffleRDD: ShuffledRowRDD = null - protected override def doExecute(): RDD[InternalRow] = { - // Returns the same ShuffleRowRDD if this plan is used by multiple plans. - if (cachedShuffleRDD == null) { - cachedShuffleRDD = new ShuffledRowRDD(shuffleDependency, readMetrics) - } - cachedShuffleRDD + // The ShuffleRowRDD will be cached in SparkPlan.executeRDD and reused if this plan is used by + // multiple plans. 
+ new ShuffledRowRDD(shuffleDependency, readMetrics) } override protected def withNewChildInternal(newChild: SparkPlan): ShuffleExchangeExec = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index 6dd41aca3a5e1..a7292ee1f8fa7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{CodegenSupport, ExplainUtils, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.ArrayImplicits._ @@ -63,13 +64,15 @@ case class BroadcastNestedLoopJoinExec( override def outputPartitioning: Partitioning = (joinType, buildSide) match { case (_: InnerLike, _) | (LeftOuter, BuildRight) | (RightOuter, BuildLeft) | - (LeftSemi, BuildRight) | (LeftAnti, BuildRight) => streamed.outputPartitioning + (LeftSingle, BuildRight) | (LeftSemi, BuildRight) | (LeftAnti, BuildRight) => + streamed.outputPartitioning case _ => super.outputPartitioning } override def outputOrdering: Seq[SortOrder] = (joinType, buildSide) match { case (_: InnerLike, _) | (LeftOuter, BuildRight) | (RightOuter, BuildLeft) | - (LeftSemi, BuildRight) | (LeftAnti, BuildRight) => streamed.outputOrdering + (LeftSingle, BuildRight) | (LeftSemi, BuildRight) | (LeftAnti, BuildRight) => + streamed.outputOrdering case _ => Nil } @@ -87,7 +90,7 @@ case class BroadcastNestedLoopJoinExec( joinType match { case _: InnerLike => left.output ++ right.output - case LeftOuter => + case LeftOuter | LeftSingle => left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output @@ -135,8 +138,14 @@ case class BroadcastNestedLoopJoinExec( * * LeftOuter with BuildRight * RightOuter with BuildLeft + * LeftSingle with BuildRight + * + * For the (LeftSingle, BuildRight) case we pass the 'singleJoin' flag, which + * makes sure there is at most one matching build row for every probe tuple. 
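+ * If a second matching build row is found for a probe tuple, the join fails with
+ * QueryExecutionErrors.scalarSubqueryReturnsMultipleRows().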
*/ - private def outerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + private def outerJoin( + relation: Broadcast[Array[InternalRow]], + singleJoin: Boolean = false): RDD[InternalRow] = { streamed.execute().mapPartitionsInternal { streamedIter => val buildRows = relation.value val joinedRow = new JoinedRow @@ -167,6 +176,9 @@ case class BroadcastNestedLoopJoinExec( resultRow = joinedRow(streamRow, buildRows(nextIndex)) nextIndex += 1 if (boundCondition(resultRow)) { + if (foundMatch && singleJoin) { + throw QueryExecutionErrors.scalarSubqueryReturnsMultipleRows(); + } foundMatch = true return true } @@ -382,12 +394,18 @@ case class BroadcastNestedLoopJoinExec( innerJoin(broadcastedRelation) case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => outerJoin(broadcastedRelation) + case (LeftSingle, BuildRight) => + outerJoin(broadcastedRelation, singleJoin = true) case (LeftSemi, _) => leftExistenceJoin(broadcastedRelation, exists = true) case (LeftAnti, _) => leftExistenceJoin(broadcastedRelation, exists = false) case (_: ExistenceJoin, _) => existenceJoin(broadcastedRelation) + case (LeftSingle, BuildLeft) => + throw new IllegalArgumentException( + s"BroadcastNestedLoopJoin should not use the left side as build when " + + s"executing a LeftSingle join") case _ => /** * LeftOuter with BuildLeft @@ -410,7 +428,7 @@ case class BroadcastNestedLoopJoinExec( override def supportCodegen: Boolean = (joinType, buildSide) match { case (_: InnerLike, _) | (LeftOuter, BuildRight) | (RightOuter, BuildLeft) | - (LeftSemi | LeftAnti, BuildRight) => true + (LeftSemi | LeftAnti, BuildRight) | (LeftSingle, BuildRight) => true case _ => false } @@ -428,6 +446,7 @@ case class BroadcastNestedLoopJoinExec( (joinType, buildSide) match { case (_: InnerLike, _) => codegenInner(ctx, input) case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => codegenOuter(ctx, input) + case (LeftSingle, BuildRight) => codegenOuter(ctx, input) case (LeftSemi, BuildRight) => codegenLeftExistence(ctx, input, exists = true) case (LeftAnti, BuildRight) => codegenLeftExistence(ctx, input, exists = false) case _ => @@ -473,7 +492,9 @@ case class BroadcastNestedLoopJoinExec( """.stripMargin } - private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = { + private def codegenOuter( + ctx: CodegenContext, + input: Seq[ExprCode]): String = { val (buildRowArray, buildRowArrayTerm) = prepareBroadcast(ctx) val (buildRow, checkCondition, _) = getJoinCondition(ctx, input, streamed, broadcast) val buildVars = genOneSideJoinVars(ctx, buildRow, broadcast, setDefaultValue = true) @@ -494,12 +515,23 @@ case class BroadcastNestedLoopJoinExec( |${consume(ctx, resultVars)} """.stripMargin } else { + // For LeftSingle joins, generate the check on the number of matches. 
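+      // The generated check throws QueryExecutionErrors.scalarSubqueryReturnsMultipleRows() as
+      // soon as a second build row matches the current probe row, mirroring the interpreted
+      // outerJoin(..., singleJoin = true) path above.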
+ val evaluateSingleCheck = if (joinType == LeftSingle) { + s""" + |if ($foundMatch) { + | throw QueryExecutionErrors.scalarSubqueryReturnsMultipleRows(); + |} + |""".stripMargin + } else { + "" + } s""" |boolean $foundMatch = false; |for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) { | UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex]; | boolean $shouldOutputRow = false; | $checkCondition { + | $evaluateSingleCheck | $shouldOutputRow = true; | $foundMatch = true; | } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 5d59a48d544a0..ce7d48babc91e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{CodegenSupport, ExplainUtils, RowIterator} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.{BooleanType, IntegralType, LongType} @@ -52,7 +53,7 @@ trait HashJoin extends JoinCodegenSupport { joinType match { case _: InnerLike => left.output ++ right.output - case LeftOuter => + case LeftOuter | LeftSingle => left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output @@ -75,7 +76,7 @@ trait HashJoin extends JoinCodegenSupport { } case BuildRight => joinType match { - case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => + case _: InnerLike | LeftOuter | LeftSingle | LeftSemi | LeftAnti | _: ExistenceJoin => left.outputPartitioning case x => throw new IllegalArgumentException( @@ -93,7 +94,7 @@ trait HashJoin extends JoinCodegenSupport { } case BuildRight => joinType match { - case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => + case _: InnerLike | LeftOuter | LeftSingle | LeftSemi | LeftAnti | _: ExistenceJoin => left.outputOrdering case x => throw new IllegalArgumentException( @@ -191,7 +192,8 @@ trait HashJoin extends JoinCodegenSupport { private def outerJoin( streamedIter: Iterator[InternalRow], - hashedRelation: HashedRelation): Iterator[InternalRow] = { + hashedRelation: HashedRelation, + singleJoin: Boolean = false): Iterator[InternalRow] = { val joinedRow = new JoinedRow() val keyGenerator = streamSideKeyGenerator() val nullRow = new GenericInternalRow(buildPlan.output.length) @@ -218,6 +220,9 @@ trait HashJoin extends JoinCodegenSupport { while (buildIter != null && buildIter.hasNext) { val nextBuildRow = buildIter.next() if (boundCondition(joinedRow.withRight(nextBuildRow))) { + if (found && singleJoin) { + throw QueryExecutionErrors.scalarSubqueryReturnsMultipleRows(); + } found = true return true } @@ -329,6 +334,8 @@ trait HashJoin extends JoinCodegenSupport { innerJoin(streamedIter, hashed) case LeftOuter | RightOuter => outerJoin(streamedIter, hashed) + case LeftSingle => + outerJoin(streamedIter, hashed, singleJoin = true) case LeftSemi => semiJoin(streamedIter, hashed) case LeftAnti => @@ -354,7 +361,7 @@ trait HashJoin extends JoinCodegenSupport { override def doConsume(ctx: 
CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { joinType match { case _: InnerLike => codegenInner(ctx, input) - case LeftOuter | RightOuter => codegenOuter(ctx, input) + case LeftOuter | RightOuter | LeftSingle => codegenOuter(ctx, input) case LeftSemi => codegenSemi(ctx, input) case LeftAnti => codegenAnti(ctx, input) case _: ExistenceJoin => codegenExistence(ctx, input) @@ -492,6 +499,17 @@ trait HashJoin extends JoinCodegenSupport { val matches = ctx.freshName("matches") val iteratorCls = classOf[Iterator[UnsafeRow]].getName val found = ctx.freshName("found") + // For LeftSingle joins generate the check on the number of build rows that match every + // probe row. Return an error for >1 matches. + val evaluateSingleCheck = if (joinType == LeftSingle) { + s""" + |if ($found) { + | throw QueryExecutionErrors.scalarSubqueryReturnsMultipleRows(); + |} + |""".stripMargin + } else { + "" + } s""" |// generate join key for stream side @@ -505,6 +523,7 @@ trait HashJoin extends JoinCodegenSupport { | (UnsafeRow) $matches.next() : null; | ${checkCondition.trim} | if ($conditionPassed) { + | $evaluateSingleCheck | $found = true; | $numOutput.add(1); | ${consume(ctx, resultVars)} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala index 7c4628c8576c5..60e5a7769a503 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, LeftExistence, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, LeftExistence, LeftOuter, LeftSingle, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning, PartitioningCollection, UnknownPartitioning, UnspecifiedDistribution} /** @@ -47,7 +47,7 @@ trait ShuffledJoin extends JoinCodegenSupport { override def outputPartitioning: Partitioning = joinType match { case _: InnerLike => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) - case LeftOuter => left.outputPartitioning + case LeftOuter | LeftSingle => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) case LeftExistence(_) => left.outputPartitioning @@ -60,7 +60,7 @@ trait ShuffledJoin extends JoinCodegenSupport { joinType match { case _: InnerLike => left.output ++ right.output - case LeftOuter => + case LeftOuter | LeftSingle => left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasDeserializer.scala new file mode 100644 index 0000000000000..82d4978853cb6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasDeserializer.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.DataInputStream + +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ + +import org.apache.arrow.vector.ipc.ArrowStreamReader + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} + +/** + * A helper class to deserialize state Arrow batches from the state socket in + * TransformWithStateInPandas. + */ +class TransformWithStateInPandasDeserializer(deserializer: ExpressionEncoder.Deserializer[Row]) + extends Logging { + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdin reader for transformWithStateInPandas state socket", 0, Long.MaxValue) + + /** + * Read Arrow batches from the given stream and deserialize them into rows. + */ + def readArrowBatches(stream: DataInputStream): Seq[Row] = { + val reader = new ArrowStreamReader(stream, allocator) + val root = reader.getVectorSchemaRoot + val vectors = root.getFieldVectors.asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + val rows = ArrayBuffer[Row]() + while (reader.loadNextBatch()) { + val batch = new ColumnarBatch(vectors) + batch.setNumRows(root.getRowCount) + rows.appendAll(batch.rowIterator().asScala.map(r => deserializer(r.copy()))) + } + reader.close(false) + rows.toSeq + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala index 7d0c177d1df8f..b4b516ba9e5a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala @@ -103,7 +103,8 @@ class TransformWithStateInPandasPythonRunner( executionContext.execute( new TransformWithStateInPandasStateServer(stateServerSocket, processorHandle, - groupingKeySchema)) + groupingKeySchema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes, + sqlConf.arrowTransformWithStateInPandasMaxRecordsPerBatch)) context.addTaskCompletionListener[Unit] { _ => logInfo(log"completion listener called") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala index b5ec26b401d28..fed1843acfa56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala @@ -24,15 +24,18 @@ import java.time.Duration import scala.collection.mutable import com.google.protobuf.ByteString +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamWriter import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.sql.{Encoders, Row} import org.apache.spark.sql.api.python.PythonSQLUtils import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} -import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, StateVariableRequest, ValueStateCall} -import org.apache.spark.sql.streaming.{TTLConfig, ValueState} +import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState, StateVariableType} +import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, ListStateCall, StatefulProcessorCall, StateRequest, StateResponse, StateVariableRequest, ValueStateCall} +import org.apache.spark.sql.streaming.{ListState, TTLConfig, ValueState} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils /** * This class is used to handle the state requests from the Python side. It runs on a separate @@ -48,9 +51,16 @@ class TransformWithStateInPandasStateServer( stateServerSocket: ServerSocket, statefulProcessorHandle: StatefulProcessorHandleImpl, groupingKeySchema: StructType, + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean, + arrowTransformWithStateInPandasMaxRecordsPerBatch: Int, outputStreamForTest: DataOutputStream = null, - valueStateMapForTest: mutable.HashMap[String, - (ValueState[Row], StructType, ExpressionEncoder.Deserializer[Row])] = null) + valueStateMapForTest: mutable.HashMap[String, ValueStateInfo] = null, + deserializerForTest: TransformWithStateInPandasDeserializer = null, + arrowStreamWriterForTest: BaseStreamingArrowWriter = null, + listStatesMapForTest : mutable.HashMap[String, ListStateInfo] = null, + listStateIteratorMapForTest: mutable.HashMap[String, Iterator[Row]] = null) extends Runnable with Logging { private val keyRowDeserializer: ExpressionEncoder.Deserializer[Row] = ExpressionEncoder(groupingKeySchema).resolveAndBind().createDeserializer() @@ -60,8 +70,22 @@ class TransformWithStateInPandasStateServer( private val valueStates = if (valueStateMapForTest != null) { valueStateMapForTest } else { - new mutable.HashMap[String, (ValueState[Row], StructType, - ExpressionEncoder.Deserializer[Row])]() + new mutable.HashMap[String, ValueStateInfo]() + } + // A map to store the list state name -> (list state, schema, list state row deserializer, + // list state row serializer) mapping. + private val listStates = if (listStatesMapForTest != null) { + listStatesMapForTest + } else { + new mutable.HashMap[String, ListStateInfo]() + } + // A map to store the iterator id -> iterator mapping. This is to keep track of the + // current iterator position for each list state in a grouping key in case user tries to fetch + // another list state before the current iterator is exhausted. 
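+  // The map is keyed by the iterator id carried in each LISTSTATEGET request. It is cleared
+  // whenever the implicit grouping key is set or removed, so cached iterators never leak
+  // across grouping keys.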
+ private var listStateIterators = if (listStateIteratorMapForTest != null) { + listStateIteratorMapForTest + } else { + new mutable.HashMap[String, Iterator[Row]]() } def run(): Unit = { @@ -125,9 +149,13 @@ class TransformWithStateInPandasStateServer( // The key row is serialized as a byte array, we need to convert it back to a Row val keyRow = PythonSQLUtils.toJVMRow(keyBytes, groupingKeySchema, keyRowDeserializer) ImplicitGroupingKeyTracker.setImplicitKey(keyRow) + // Reset the list state iterators for a new grouping key. + listStateIterators = new mutable.HashMap[String, Iterator[Row]]() sendResponse(0) case ImplicitGroupingKeyRequest.MethodCase.REMOVEIMPLICITKEY => ImplicitGroupingKeyTracker.removeImplicitKey() + // Reset the list state iterators for a new grouping key. + listStateIterators = new mutable.HashMap[String, Iterator[Row]]() sendResponse(0) case _ => throw new IllegalArgumentException("Invalid method call") @@ -157,7 +185,16 @@ class TransformWithStateInPandasStateServer( val ttlDurationMs = if (message.getGetValueState.hasTtl) { Some(message.getGetValueState.getTtl.getDurationMs) } else None - initializeValueState(stateName, schema, ttlDurationMs) + initializeStateVariable(stateName, schema, StateVariableType.ValueState, ttlDurationMs) + case StatefulProcessorCall.MethodCase.GETLISTSTATE => + val stateName = message.getGetListState.getStateName + val schema = message.getGetListState.getSchema + val ttlDurationMs = if (message.getGetListState.hasTtl) { + Some(message.getGetListState.getTtl.getDurationMs) + } else { + None + } + initializeStateVariable(stateName, schema, StateVariableType.ListState, ttlDurationMs) case _ => throw new IllegalArgumentException("Invalid method call") } @@ -167,6 +204,8 @@ class TransformWithStateInPandasStateServer( message.getMethodCase match { case StateVariableRequest.MethodCase.VALUESTATECALL => handleValueStateRequest(message.getValueStateCall) + case StateVariableRequest.MethodCase.LISTSTATECALL => + handleListStateRequest(message.getListStateCall) case _ => throw new IllegalArgumentException("Invalid method call") } @@ -179,16 +218,17 @@ class TransformWithStateInPandasStateServer( sendResponse(1, s"Value state $stateName is not initialized.") return } + val valueStateInfo = valueStates(stateName) message.getMethodCase match { case ValueStateCall.MethodCase.EXISTS => - if (valueStates(stateName)._1.exists()) { + if (valueStateInfo.valueState.exists()) { sendResponse(0) } else { // Send status code 2 to indicate that the value state doesn't have a value yet. 
sendResponse(2, s"state $stateName doesn't exist") } case ValueStateCall.MethodCase.GET => - val valueOption = valueStates(stateName)._1.getOption() + val valueOption = valueStateInfo.valueState.getOption() if (valueOption.isDefined) { // Serialize the value row as a byte array val valueBytes = PythonSQLUtils.toPyRow(valueOption.get) @@ -201,13 +241,95 @@ class TransformWithStateInPandasStateServer( } case ValueStateCall.MethodCase.VALUESTATEUPDATE => val byteArray = message.getValueStateUpdate.getValue.toByteArray - val valueStateTuple = valueStates(stateName) // The value row is serialized as a byte array, we need to convert it back to a Row - val valueRow = PythonSQLUtils.toJVMRow(byteArray, valueStateTuple._2, valueStateTuple._3) - valueStateTuple._1.update(valueRow) + val valueRow = PythonSQLUtils.toJVMRow(byteArray, valueStateInfo.schema, + valueStateInfo.deserializer) + valueStateInfo.valueState.update(valueRow) sendResponse(0) case ValueStateCall.MethodCase.CLEAR => - valueStates(stateName)._1.clear() + valueStateInfo.valueState.clear() + sendResponse(0) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private[sql] def handleListStateRequest(message: ListStateCall): Unit = { + val stateName = message.getStateName + if (!listStates.contains(stateName)) { + logWarning(log"List state ${MDC(LogKeys.STATE_NAME, stateName)} is not initialized.") + sendResponse(1, s"List state $stateName is not initialized.") + return + } + val listStateInfo = listStates(stateName) + val deserializer = if (deserializerForTest != null) { + deserializerForTest + } else { + new TransformWithStateInPandasDeserializer(listStateInfo.deserializer) + } + message.getMethodCase match { + case ListStateCall.MethodCase.EXISTS => + if (listStateInfo.listState.exists()) { + sendResponse(0) + } else { + // Send status code 2 to indicate that the list state doesn't have a value yet. + sendResponse(2, s"state $stateName doesn't exist") + } + case ListStateCall.MethodCase.LISTSTATEPUT => + val rows = deserializer.readArrowBatches(inputStream) + listStateInfo.listState.put(rows.toArray) + sendResponse(0) + case ListStateCall.MethodCase.LISTSTATEGET => + val iteratorId = message.getListStateGet.getIteratorId + var iteratorOption = listStateIterators.get(iteratorId) + if (iteratorOption.isEmpty) { + iteratorOption = Some(listStateInfo.listState.get()) + listStateIterators.put(iteratorId, iteratorOption.get) + } + if (!iteratorOption.get.hasNext) { + sendResponse(2, s"List state $stateName doesn't contain any value.") + return + } else { + sendResponse(0) + } + outputStream.flush() + val arrowStreamWriter = if (arrowStreamWriterForTest != null) { + arrowStreamWriterForTest + } else { + val arrowSchema = ArrowUtils.toArrowSchema(listStateInfo.schema, timeZoneId, + errorOnDuplicatedFieldNames, largeVarTypes) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdout writer for transformWithStateInPandas state socket", 0, Long.MaxValue) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + new BaseStreamingArrowWriter(root, new ArrowStreamWriter(root, null, outputStream), + arrowTransformWithStateInPandasMaxRecordsPerBatch) + } + val listRowSerializer = listStateInfo.serializer + // Only write a single batch in each GET request. Stops writing row if rowCount reaches + // the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. 
This is to handle a case + // when there are multiple state variables, user tries to access a different state variable + // while the current state variable is not exhausted yet. + var rowCount = 0 + while (iteratorOption.get.hasNext && + rowCount < arrowTransformWithStateInPandasMaxRecordsPerBatch) { + val row = iteratorOption.get.next() + val internalRow = listRowSerializer(row) + arrowStreamWriter.writeRow(internalRow) + rowCount += 1 + } + arrowStreamWriter.finalizeCurrentArrowBatch() + case ListStateCall.MethodCase.APPENDVALUE => + val byteArray = message.getAppendValue.getValue.toByteArray + val newRow = PythonSQLUtils.toJVMRow(byteArray, listStateInfo.schema, + listStateInfo.deserializer) + listStateInfo.listState.appendValue(newRow) + sendResponse(0) + case ListStateCall.MethodCase.APPENDLIST => + val rows = deserializer.readArrowBatches(inputStream) + listStateInfo.listState.appendList(rows.toArray) + sendResponse(0) + case ListStateCall.MethodCase.CLEAR => + listStates(stateName).listState.clear() sendResponse(0) case _ => throw new IllegalArgumentException("Invalid method call") @@ -232,23 +354,58 @@ class TransformWithStateInPandasStateServer( outputStream.write(responseMessageBytes) } - private def initializeValueState( + private def initializeStateVariable( stateName: String, schemaString: String, + stateType: StateVariableType.StateVariableType, ttlDurationMs: Option[Int]): Unit = { - if (!valueStates.contains(stateName)) { - val schema = StructType.fromString(schemaString) - val state = if (ttlDurationMs.isEmpty) { - statefulProcessorHandle.getValueState[Row](stateName, Encoders.row(schema)) - } else { - statefulProcessorHandle.getValueState( - stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get))) - } - val valueRowDeserializer = ExpressionEncoder(schema).resolveAndBind().createDeserializer() - valueStates.put(stateName, (state, schema, valueRowDeserializer)) - sendResponse(0) - } else { - sendResponse(1, s"state $stateName already exists") + val schema = StructType.fromString(schemaString) + val expressionEncoder = ExpressionEncoder(schema).resolveAndBind() + stateType match { + case StateVariableType.ValueState => if (!valueStates.contains(stateName)) { + val state = if (ttlDurationMs.isEmpty) { + statefulProcessorHandle.getValueState[Row](stateName, Encoders.row(schema)) + } else { + statefulProcessorHandle.getValueState( + stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get))) + } + valueStates.put(stateName, + ValueStateInfo(state, schema, expressionEncoder.createDeserializer())) + sendResponse(0) + } else { + sendResponse(1, s"Value state $stateName already exists") + } + case StateVariableType.ListState => if (!listStates.contains(stateName)) { + val state = if (ttlDurationMs.isEmpty) { + statefulProcessorHandle.getListState[Row](stateName, Encoders.row(schema)) + } else { + statefulProcessorHandle.getListState( + stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get))) + } + listStates.put(stateName, + ListStateInfo(state, schema, expressionEncoder.createDeserializer(), + expressionEncoder.createSerializer())) + sendResponse(0) + } else { + sendResponse(1, s"List state $stateName already exists") + } } } } + +/** + * Case class to store the information of a value state. + */ +case class ValueStateInfo( + valueState: ValueState[Row], + schema: StructType, + deserializer: ExpressionEncoder.Deserializer[Row]) + +/** + * Case class to store the information of a list state. 
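+ * The serializer converts Rows back into InternalRows when list contents are streamed to the
+ * Python worker as Arrow batches.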
+ */ +case class ListStateInfo( + listState: ListState[Row], + schema: StructType, + deserializer: ExpressionEncoder.Deserializer[Row], + serializer: ExpressionEncoder.Serializer[Row]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index d5facc245e72f..e1e5b3a7ef88e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -101,7 +101,9 @@ object OffsetSeqMetadata extends Logging { SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY, FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION, STREAMING_JOIN_STATE_FORMAT_VERSION, STATE_STORE_COMPRESSION_CODEC, - STATE_STORE_ROCKSDB_FORMAT_VERSION, STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION) + STATE_STORE_ROCKSDB_FORMAT_VERSION, STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION, + PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN + ) /** * Default values of relevant configurations that are used for backward compatibility. @@ -122,7 +124,8 @@ object OffsetSeqMetadata extends Logging { STREAMING_JOIN_STATE_FORMAT_VERSION.key -> SymmetricHashJoinStateManager.legacyVersion.toString, STATE_STORE_COMPRESSION_CODEC.key -> CompressionCodec.LZ4, - STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "false" + STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "false", + PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "true" ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala index 99229c6132eb2..7da8408f98b0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala @@ -20,6 +20,7 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateStoreColFamilySchema} +import org.apache.spark.sql.types.StructType object StateStoreColumnFamilySchemaUtils { @@ -61,4 +62,15 @@ object StateStoreColumnFamilySchemaUtils { Some(PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1)), Some(userKeyEnc.schema)) } + + def getTimerStateSchema( + stateName: String, + keySchema: StructType, + valSchema: StructType): StateStoreColFamilySchema = { + StateStoreColFamilySchema( + stateName, + keySchema, + valSchema, + Some(PrefixKeyScanStateEncoderSpec(keySchema, 1))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala index 1f5ad2fc85470..b70f9699195d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala @@ -288,6 +288,9 @@ class TimerKeyEncoder(keyExprEnc: ExpressionEncoder[Any]) { .add("key", new 
StructType(keyExprEnc.schema.fields)) .add("expiryTimestampMs", LongType, nullable = false) + val schemaForValueRow: StructType = + StructType(Array(StructField("__dummy__", NullType))) + private val keySerializer = keyExprEnc.createSerializer() private val keyDeserializer = keyExprEnc.resolveAndBind().createDeserializer() private val prefixKeyProjection = UnsafeProjection.create(schemaForPrefixKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 942d395dec0e2..8beacbec7e6ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -308,6 +308,12 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi private val stateVariableInfos: mutable.Map[String, TransformWithStateVariableInfo] = new mutable.HashMap[String, TransformWithStateVariableInfo]() + // If timeMode is not None, add a timer column family schema to the operator metadata so that + // registered timers can be read using the state data source reader. + if (timeMode != TimeMode.None()) { + addTimerColFamily() + } + def getColumnFamilySchemas: Map[String, StateStoreColFamilySchema] = columnFamilySchemas.toMap def getStateVariableInfos: Map[String, TransformWithStateVariableInfo] = stateVariableInfos.toMap @@ -318,6 +324,16 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi } } + private def addTimerColFamily(): Unit = { + val stateName = TimerStateUtils.getTimerStateVarName(timeMode.toString) + val timerEncoder = new TimerKeyEncoder(keyExprEnc) + val colFamilySchema = StateStoreColumnFamilySchemaUtils. + getTimerStateSchema(stateName, timerEncoder.schemaForKeyRow, timerEncoder.schemaForValueRow) + columnFamilySchemas.put(stateName, colFamilySchema) + val stateVariableInfo = TransformWithStateVariableUtils.getTimerState(stateName) + stateVariableInfos.put(stateName, stateVariableInfo) + } + override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = { verifyStateVarOperations("get_value_state", PRE_INIT) val colFamilySchema = StateStoreColumnFamilySchemaUtils. 
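With the timer column family registered above, registered timers (and, from the earlier hunks, flattened list/map state) become readable through the state data source. A minimal usage sketch follows, assuming the data source short name "statestore", the camelCase option keys matching the constants referenced in this patch (readRegisteredTimers, flattenCollectionTypes, stateVarName, operatorId, path), and placeholder values for the SparkSession, checkpoint path, and state variable name:

// Read registered timers: each row is (key, expiration_timestamp_ms, partition_id).
val timers = spark.read
  .format("statestore")
  .option("path", "/tmp/checkpoint")      // placeholder checkpoint location
  .option("operatorId", "0")
  .option("readRegisteredTimers", "true")
  .load()

// Read a map state variable as one row per entry instead of one map per grouping key.
val mapEntries = spark.read
  .format("statestore")
  .option("path", "/tmp/checkpoint")
  .option("stateVarName", "mapState")     // placeholder state variable name
  .option("flattenCollectionTypes", "true")
  .load()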
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 8f030884ad33b..14adf951f07e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -374,7 +374,7 @@ abstract class StreamExecution( "message" -> message)) errorClassOpt = e match { - case t: SparkThrowable => Option(t.getErrorClass) + case t: SparkThrowable => Option(t.getCondition) case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala index 82a4226fcfd54..d0fbaf6600609 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala @@ -34,6 +34,15 @@ object TimerStateUtils { val EVENT_TIMERS_STATE_NAME = "$eventTimers" val KEY_TO_TIMESTAMP_CF = "_keyToTimestamp" val TIMESTAMP_TO_KEY_CF = "_timestampToKey" + + def getTimerStateVarName(timeMode: String): String = { + assert(timeMode == TimeMode.EventTime.toString || timeMode == TimeMode.ProcessingTime.toString) + if (timeMode == TimeMode.EventTime.toString) { + TimerStateUtils.EVENT_TIMERS_STATE_NAME + TimerStateUtils.KEY_TO_TIMESTAMP_CF + } else { + TimerStateUtils.PROC_TIMERS_STATE_NAME + TimerStateUtils.KEY_TO_TIMESTAMP_CF + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala index 0a32564f973a3..4a192b3e51c71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala @@ -43,12 +43,16 @@ object TransformWithStateVariableUtils { def getMapState(stateName: String, ttlEnabled: Boolean): TransformWithStateVariableInfo = { TransformWithStateVariableInfo(stateName, StateVariableType.MapState, ttlEnabled) } + + def getTimerState(stateName: String): TransformWithStateVariableInfo = { + TransformWithStateVariableInfo(stateName, StateVariableType.TimerState, ttlEnabled = false) + } } // Enum of possible State Variable types object StateVariableType extends Enumeration { type StateVariableType = Value - val ValueState, ListState, MapState = Value + val ValueState, ListState, MapState, TimerState = Value } case class TransformWithStateVariableInfo( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index d9f4443b79618..3df63c41dbf97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -282,7 +282,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with newMap } catch { - case e: SparkException if e.getErrorClass.contains("CANNOT_LOAD_STATE_STORE") => + case e: SparkException if e.getCondition.contains("CANNOT_LOAD_STATE_STORE") 
=> throw e case e: OutOfMemoryError => throw QueryExecutionErrors.notEnoughMemoryToLoadStore( @@ -991,8 +991,16 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with result } - override def getStateStoreChangeDataReader(startVersion: Long, endVersion: Long): + override def getStateStoreChangeDataReader( + startVersion: Long, + endVersion: Long, + colFamilyNameOpt: Option[String] = None): StateStoreChangeDataReader = { + // Multiple column families are not supported with HDFSBackedStateStoreProvider + if (colFamilyNameOpt.isDefined) { + throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName) + } + new HDFSBackedStateStoreChangeDataReader(fm, baseDir, startVersion, endVersion, CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec), keySchema, valueSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 81e80629092a0..c7f8434e5345b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -115,37 +115,35 @@ class RocksDB( tableFormatConfig.setPinL0FilterAndIndexBlocksInCache(true) } - private[state] val columnFamilyOptions = new ColumnFamilyOptions() + private[state] val rocksDbOptions = new Options() // options to open the RocksDB + + rocksDbOptions.setCreateIfMissing(true) // Set RocksDB options around MemTable memory usage. By default, we let RocksDB // use its internal default values for these settings. if (conf.writeBufferSizeMB > 0L) { - columnFamilyOptions.setWriteBufferSize(conf.writeBufferSizeMB * 1024 * 1024) + rocksDbOptions.setWriteBufferSize(conf.writeBufferSizeMB * 1024 * 1024) } if (conf.maxWriteBufferNumber > 0L) { - columnFamilyOptions.setMaxWriteBufferNumber(conf.maxWriteBufferNumber) + rocksDbOptions.setMaxWriteBufferNumber(conf.maxWriteBufferNumber) } - columnFamilyOptions.setCompressionType(getCompressionType(conf.compression)) - columnFamilyOptions.setMergeOperator(new StringAppendOperator()) - - private val dbOptions = - new Options(new DBOptions(), columnFamilyOptions) // options to open the RocksDB + rocksDbOptions.setCompressionType(getCompressionType(conf.compression)) - dbOptions.setCreateIfMissing(true) - dbOptions.setTableFormatConfig(tableFormatConfig) - dbOptions.setMaxOpenFiles(conf.maxOpenFiles) - dbOptions.setAllowFAllocate(conf.allowFAllocate) - dbOptions.setMergeOperator(new StringAppendOperator()) + rocksDbOptions.setTableFormatConfig(tableFormatConfig) + rocksDbOptions.setMaxOpenFiles(conf.maxOpenFiles) + rocksDbOptions.setAllowFAllocate(conf.allowFAllocate) + rocksDbOptions.setAvoidFlushDuringShutdown(true) + rocksDbOptions.setMergeOperator(new StringAppendOperator()) if (conf.boundedMemoryUsage) { - dbOptions.setWriteBufferManager(writeBufferManager) + rocksDbOptions.setWriteBufferManager(writeBufferManager) } private val dbLogger = createLogger() // for forwarding RocksDB native logs to log4j - dbOptions.setStatistics(new Statistics()) - private val nativeStats = dbOptions.statistics() + rocksDbOptions.setStatistics(new Statistics()) + private val nativeStats = rocksDbOptions.statistics() private val workingDir = createTempDir("workingDir") private val fileManager = new RocksDBFileManager(dfsRootDir, createTempDir("fileManager"), @@ -646,15 +644,15 @@ class RocksDB( // is enabled. 
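// With this change the changelog writer is always committed, even when a snapshot upload is
// forced; previously it was discarded without a commit in that branch.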
if (shouldForceSnapshot.get()) { uploadSnapshot() + shouldForceSnapshot.set(false) + } + + // ensure that changelog files are always written + try { + assert(changelogWriter.isDefined) + changelogWriter.foreach(_.commit()) + } finally { changelogWriter = None - changelogWriter.foreach(_.abort()) - } else { - try { - assert(changelogWriter.isDefined) - changelogWriter.foreach(_.commit()) - } finally { - changelogWriter = None - } } } else { assert(changelogWriter.isEmpty) @@ -782,7 +780,7 @@ class RocksDB( readOptions.close() writeOptions.close() flushOptions.close() - dbOptions.close() + rocksDbOptions.close() dbLogger.close() synchronized { latestSnapshot.foreach(_.close()) @@ -941,7 +939,7 @@ class RocksDB( private def openDB(): Unit = { assert(db == null) - db = NativeRocksDB.open(dbOptions, workingDir.toString) + db = NativeRocksDB.open(rocksDbOptions, workingDir.toString) logInfo(log"Opened DB with conf ${MDC(LogKeys.CONFIG, conf)}") } @@ -962,7 +960,7 @@ class RocksDB( /** Create a native RocksDB logger that forwards native logs to log4j with correct log levels. */ private def createLogger(): Logger = { - val dbLogger = new Logger(dbOptions.infoLogLevel()) { + val dbLogger = new Logger(rocksDbOptions.infoLogLevel()) { override def log(infoLogLevel: InfoLogLevel, logMsg: String) = { // Map DB log level to log4j levels // Warn is mapped to info because RocksDB warn is too verbose @@ -985,8 +983,8 @@ class RocksDB( dbLogger.setInfoLogLevel(dbLogLevel) // The log level set in dbLogger is effective and the one to dbOptions isn't applied to // customized logger. We still set it as it might show up in RocksDB config file or logging. - dbOptions.setInfoLogLevel(dbLogLevel) - dbOptions.setLogger(dbLogger) + rocksDbOptions.setInfoLogLevel(dbLogLevel) + rocksDbOptions.setLogger(dbLogger) logInfo(log"Set RocksDB native logging level to ${MDC(LogKeys.ROCKS_DB_LOG_LEVEL, dbLogLevel)}") dbLogger } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 85f80ce9eb1ae..870ed79ec1747 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -389,7 +389,7 @@ private[sql] class RocksDBStateStoreProvider new RocksDBStateStore(version) } catch { - case e: SparkException if e.getErrorClass.contains("CANNOT_LOAD_STATE_STORE") => + case e: SparkException if e.getCondition.contains("CANNOT_LOAD_STATE_STORE") => throw e case e: OutOfMemoryError => throw QueryExecutionErrors.notEnoughMemoryToLoadStore( @@ -409,7 +409,7 @@ private[sql] class RocksDBStateStoreProvider new RocksDBStateStore(version) } catch { - case e: SparkException if e.getErrorClass.contains("CANNOT_LOAD_STATE_STORE") => + case e: SparkException if e.getCondition.contains("CANNOT_LOAD_STATE_STORE") => throw e case e: OutOfMemoryError => throw QueryExecutionErrors.notEnoughMemoryToLoadStore( @@ -498,7 +498,10 @@ private[sql] class RocksDBStateStoreProvider } } - override def getStateStoreChangeDataReader(startVersion: Long, endVersion: Long): + override def getStateStoreChangeDataReader( + startVersion: Long, + endVersion: Long, + colFamilyNameOpt: Option[String] = None): StateStoreChangeDataReader = { val statePath = stateStoreId.storeCheckpointLocation() val sparkConf = 
Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) @@ -508,7 +511,8 @@ private[sql] class RocksDBStateStoreProvider startVersion, endVersion, CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec), - keyValueEncoderMap) + keyValueEncoderMap, + colFamilyNameOpt) } /** @@ -676,27 +680,70 @@ class RocksDBStateStoreChangeDataReader( endVersion: Long, compressionCodec: CompressionCodec, keyValueEncoderMap: - ConcurrentHashMap[String, (RocksDBKeyStateEncoder, RocksDBValueStateEncoder)]) + ConcurrentHashMap[String, (RocksDBKeyStateEncoder, RocksDBValueStateEncoder)], + colFamilyNameOpt: Option[String] = None) extends StateStoreChangeDataReader( - fm, stateLocation, startVersion, endVersion, compressionCodec) { + fm, stateLocation, startVersion, endVersion, compressionCodec, colFamilyNameOpt) { override protected var changelogSuffix: String = "changelog" + private def getColFamilyIdBytes: Option[Array[Byte]] = { + if (colFamilyNameOpt.isDefined) { + val colFamilyName = colFamilyNameOpt.get + if (!keyValueEncoderMap.containsKey(colFamilyName)) { + throw new IllegalStateException( + s"Column family $colFamilyName not found in the key value encoder map") + } + Some(keyValueEncoderMap.get(colFamilyName)._1.getColumnFamilyIdBytes()) + } else { + None + } + } + + private val colFamilyIdBytesOpt: Option[Array[Byte]] = getColFamilyIdBytes + override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = { - val reader = currentChangelogReader() - if (reader == null) { - return null + var currRecord: (RecordType.Value, Array[Byte], Array[Byte]) = null + val currEncoder: (RocksDBKeyStateEncoder, RocksDBValueStateEncoder) = + keyValueEncoderMap.get(colFamilyNameOpt + .getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME)) + + if (colFamilyIdBytesOpt.isDefined) { + // If we are reading records for a particular column family, the corresponding vcf id + // will be encoded in the key byte array. We need to extract that and compare for the + // expected column family id. If it matches, we return the record. If not, we move to + // the next record. Note that this has be handled across multiple changelog files and we + // rely on the currentChangelogReader to move to the next changelog file when needed. 
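(Aside, not part of the patch: the per-column-family filtering described in the comment above boils down to a byte-prefix check using the Java 9+ range overload of java.util.Arrays.equals. A minimal standalone sketch of that check, with made-up column family ids and key bytes:)

    // Standalone sketch only; the ids and keys below are invented for illustration.
    object ColumnFamilyPrefixSketch {
      // True when `keyBytes` starts with the bytes of `cfIdBytes`.
      def matchesColumnFamily(keyBytes: Array[Byte], cfIdBytes: Array[Byte]): Boolean =
        java.util.Arrays.equals(keyBytes, 0, cfIdBytes.length, cfIdBytes, 0, cfIdBytes.length)

      def main(args: Array[String]): Unit = {
        val cfId = Array[Byte](0, 0, 0, 7)            // pretend virtual column family id
        val sameCf = cfId ++ Array[Byte](1, 2, 3)     // key encoded with that id
        val otherCf = Array[Byte](0, 0, 0, 8, 1, 2)   // key from another column family
        println(matchesColumnFamily(sameCf, cfId))    // true  -> record would be returned
        println(matchesColumnFamily(otherCf, cfId))   // false -> reader skips to the next record
      }
    }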
+ while (currRecord == null) { + val reader = currentChangelogReader() + if (reader == null) { + return null + } + + val nextRecord = reader.next() + val colFamilyIdBytes: Array[Byte] = colFamilyIdBytesOpt.get + val endIndex = colFamilyIdBytes.size + // Function checks for byte arrays being equal + // from index 0 to endIndex - 1 (both inclusive) + if (java.util.Arrays.equals(nextRecord._2, 0, endIndex, + colFamilyIdBytes, 0, endIndex)) { + currRecord = nextRecord + } + } + } else { + val reader = currentChangelogReader() + if (reader == null) { + return null + } + currRecord = reader.next() } - val (recordType, keyArray, valueArray) = reader.next() - // Todo: does not support multiple virtual column families - val (rocksDBKeyStateEncoder, rocksDBValueStateEncoder) = - keyValueEncoderMap.get(StateStore.DEFAULT_COL_FAMILY_NAME) - val keyRow = rocksDBKeyStateEncoder.decodeKey(keyArray) - if (valueArray == null) { - (recordType, keyRow, null, currentChangelogVersion - 1) + + val keyRow = currEncoder._1.decodeKey(currRecord._2) + if (currRecord._3 == null) { + (currRecord._1, keyRow, null, currentChangelogVersion - 1) } else { - val valueRow = rocksDBValueStateEncoder.decodeValue(valueArray) - (recordType, keyRow, valueRow, currentChangelogVersion - 1) + val valueRow = currEncoder._2.decodeValue(currRecord._3) + (currRecord._1, keyRow, valueRow, currentChangelogVersion - 1) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index d55a973a14e16..6e616cc71a80c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -519,10 +519,14 @@ trait SupportsFineGrainedReplay { * * @param startVersion starting changelog version * @param endVersion ending changelog version + * @param colFamilyNameOpt optional column family name to read from * @return iterator that gives tuple(recordType: [[RecordType.Value]], nested key: [[UnsafeRow]], * nested value: [[UnsafeRow]], batchId: [[Long]]) */ - def getStateStoreChangeDataReader(startVersion: Long, endVersion: Long): + def getStateStoreChangeDataReader( + startVersion: Long, + endVersion: Long, + colFamilyNameOpt: Option[String] = None): NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala index 651d72da16095..e89550da37e03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala @@ -397,13 +397,15 @@ class StateStoreChangelogReaderV2( * @param startVersion start version of the changelog file to read * @param endVersion end version of the changelog file to read * @param compressionCodec de-compression method using for reading changelog file + * @param colFamilyNameOpt optional column family name to read from */ abstract class StateStoreChangeDataReader( fm: CheckpointFileManager, stateLocation: Path, startVersion: Long, endVersion: Long, - compressionCodec: CompressionCodec) + compressionCodec: CompressionCodec, + colFamilyNameOpt: Option[String] = None) extends NextIterator[(RecordType.Value, UnsafeRow, 
UnsafeRow, Long)] with Logging { assert(startVersion >= 1) @@ -451,9 +453,12 @@ abstract class StateStoreChangeDataReader( finished = true return null } - // Todo: Does not support StateStoreChangelogReaderV2 - changelogReader = + + changelogReader = if (colFamilyNameOpt.isDefined) { + new StateStoreChangelogReaderV2(fm, fileIterator.next(), compressionCodec) + } else { new StateStoreChangelogReaderV1(fm, fileIterator.next(), compressionCodec) + } } changelogReader } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index a2539828733fc..0d0258f11efb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.internal import org.apache.spark.annotation.Unstable import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} import org.apache.spark.sql.artifact.ArtifactManager -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.Optimizer @@ -206,6 +206,7 @@ abstract class BaseSessionStateBuilder( ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: new ResolveTranspose(session) +: + new InvokeProcedures(session) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 52b8d35e2fbf8..64689e75e2e5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -177,7 +177,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { try { Some(makeTable(catalogName +: ns :+ tableName)) } catch { - case e: AnalysisException if e.getErrorClass == "UNSUPPORTED_FEATURE.HIVE_TABLE_TYPE" => + case e: AnalysisException if e.getCondition == "UNSUPPORTED_FEATURE.HIVE_TABLE_TYPE" => Some(new Table( name = tableName, catalog = catalogName, @@ -189,7 +189,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } } } catch { - case e: AnalysisException if e.getErrorClass == "TABLE_OR_VIEW_NOT_FOUND" => None + case e: AnalysisException if e.getCondition == "TABLE_OR_VIEW_NOT_FOUND" => None } } @@ -203,7 +203,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { case _ => false } } catch { - case e: AnalysisException if e.getErrorClass == "TABLE_OR_VIEW_NOT_FOUND" => false + case e: AnalysisException if e.getCondition == "TABLE_OR_VIEW_NOT_FOUND" => false } } @@ -323,7 +323,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { case _ => false } } catch { - case e: AnalysisException if e.getErrorClass == "UNRESOLVED_ROUTINE" => false + case e: AnalysisException if e.getCondition == "UNRESOLVED_ROUTINE" 
=> false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala index f0eef9ae1cbb0..8164d33f46fee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala @@ -429,7 +429,7 @@ final class DataFrameWriterImpl[T] private[sql](ds: Dataset[T]) extends DataFram val canUseV2 = lookupV2Provider().isDefined || (df.sparkSession.sessionState.conf.getConf( SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined && !df.sparkSession.sessionState.catalogManager.catalog(CatalogManager.SESSION_CATALOG_NAME) - .isInstanceOf[DelegatingCatalogExtension]) + .isInstanceOf[CatalogExtension]) session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match { case nameParts @ NonSessionCatalogAndIdentifier(catalog, ident) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala index ca439cdb89958..0ef879387727a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala @@ -42,7 +42,7 @@ class RuntimeConfigImpl private[sql](val sqlConf: SQLConf = new SQLConf) extends } /** @inheritdoc */ - @throws[NoSuchElementException]("if the key is not set") + @throws[NoSuchElementException]("if the key is not set and there is no default value") def get(key: String): String = { sqlConf.getConfString(key) } @@ -84,7 +84,7 @@ class RuntimeConfigImpl private[sql](val sqlConf: SQLConf = new SQLConf) extends sqlConf.contains(key) } - private def requireNonStaticConf(key: String): Unit = { + private[sql] def requireNonStaticConf(key: String): Unit = { if (SQLConf.isStaticConfigKey(key)) { throw QueryCompilationErrors.cannotModifyValueOfStaticConfigError(key) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala index 920c0371292c9..476956e58e8e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala @@ -54,8 +54,8 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres case Literal(value, None, _) => expressions.Literal(value) - case UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) => - convertUnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn) + case UnresolvedAttribute(nameParts, planId, isMetadataColumn, _) => + convertUnresolvedAttribute(nameParts, planId, isMetadataColumn) case UnresolvedStar(unparsedTarget, None, _) => val target = unparsedTarget.map { t => @@ -74,7 +74,7 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres analysis.UnresolvedRegex(columnNameRegex, Some(nameParts), conf.caseSensitiveAnalysis) case UnresolvedRegex(unparsedIdentifier, planId, _) => - convertUnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn = false) + convertUnresolvedRegex(unparsedIdentifier, planId) case UnresolvedFunction(functionName, arguments, isDistinct, isUDF, isInternal, _) => val nameParts = if (isUDF) { @@ -223,10 +223,10 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres } private def 
convertUnresolvedAttribute( - unparsedIdentifier: String, + nameParts: Seq[String], planId: Option[Long], isMetadataColumn: Boolean): analysis.UnresolvedAttribute = { - val attribute = analysis.UnresolvedAttribute.quotedString(unparsedIdentifier) + val attribute = analysis.UnresolvedAttribute(nameParts) if (planId.isDefined) { attribute.setTagValue(LogicalPlan.PLAN_ID_TAG, planId.get) } @@ -235,6 +235,16 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres } attribute } + + private def convertUnresolvedRegex( + unparsedIdentifier: String, + planId: Option[Long]): analysis.UnresolvedAttribute = { + val attribute = analysis.UnresolvedAttribute.quotedString(unparsedIdentifier) + if (planId.isDefined) { + attribute.setTagValue(LogicalPlan.PLAN_ID_TAG, planId.get) + } + attribute + } } private[sql] object ColumnNodeToExpressionConverter extends ColumnNodeToExpressionConverter { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index f7cf70ac957ba..2f54f1f62fde1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -22,8 +22,7 @@ import java.util.Locale import scala.util.control.NonFatal -import org.apache.spark.SparkUnsupportedOperationException -import org.apache.spark.sql.AnalysisException +import org.apache.spark.{SparkThrowable, SparkUnsupportedOperationException} import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException import org.apache.spark.sql.connector.catalog.Identifier @@ -153,12 +152,12 @@ private case class DB2Dialect() extends JdbcDialect with SQLConfHelper with NoLe override def removeSchemaCommentQuery(schema: String): String = { s"COMMENT ON SCHEMA ${quoteIdentifier(schema)} IS ''" } - override def classifyException( e: Throwable, errorClass: String, messageParameters: Map[String, String], - description: String): AnalysisException = { + description: String, + isRuntime: Boolean): Throwable with SparkThrowable = { e match { case sqlException: SQLException => sqlException.getSQLState match { @@ -171,9 +170,10 @@ private case class DB2Dialect() extends JdbcDialect with SQLConfHelper with NoLe case "42710" if errorClass == "FAILED_JDBC.RENAME_TABLE" => val newTable = messageParameters("newName") throw QueryCompilationErrors.tableAlreadyExistsError(newTable) - case _ => super.classifyException(e, errorClass, messageParameters, description) + case _ => + super.classifyException(e, errorClass, messageParameters, description, isRuntime) } - case _ => super.classifyException(e, errorClass, messageParameters, description) + case _ => super.classifyException(e, errorClass, messageParameters, description, isRuntime) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index 3ece44ece9e6a..798ecb5b36ff2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -27,8 +27,7 @@ import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils -import org.apache.spark.SparkUnsupportedOperationException -import org.apache.spark.sql.AnalysisException +import org.apache.spark.{SparkThrowable, SparkUnsupportedOperationException} import 
org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.catalog.functions.UnboundFunction @@ -200,7 +199,8 @@ private[sql] case class H2Dialect() extends JdbcDialect with NoLegacyJDBCError { e: Throwable, errorClass: String, messageParameters: Map[String, String], - description: String): AnalysisException = { + description: String, + isRuntime: Boolean): Throwable with SparkThrowable = { e match { case exception: SQLException => // Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html @@ -244,7 +244,7 @@ private[sql] case class H2Dialect() extends JdbcDialect with NoLegacyJDBCError { } case _ => // do nothing } - super.classifyException(e, errorClass, messageParameters, description) + super.classifyException(e, errorClass, messageParameters, description, isRuntime) } override def compileExpression(expr: Expression): Option[String] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 290665020f883..3bf1390cb664d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -28,7 +28,7 @@ import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils -import org.apache.spark.SparkUnsupportedOperationException +import org.apache.spark.{SparkRuntimeException, SparkThrowable, SparkUnsupportedOperationException} import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException @@ -741,13 +741,15 @@ abstract class JdbcDialect extends Serializable with Logging { * @param errorClass The error class assigned in the case of an unclassified `e` * @param messageParameters The message parameters of `errorClass` * @param description The error description - * @return `AnalysisException` or its sub-class. + * @param isRuntime Whether the exception is a runtime exception or not. + * @return `SparkThrowable + Throwable` or its sub-class. 
*/ def classifyException( e: Throwable, errorClass: String, messageParameters: Map[String, String], - description: String): AnalysisException = { + description: String, + isRuntime: Boolean): Throwable with SparkThrowable = { classifyException(description, e) } @@ -850,11 +852,19 @@ trait NoLegacyJDBCError extends JdbcDialect { e: Throwable, errorClass: String, messageParameters: Map[String, String], - description: String): AnalysisException = { - new AnalysisException( - errorClass = errorClass, - messageParameters = messageParameters, - cause = Some(e)) + description: String, + isRuntime: Boolean): Throwable with SparkThrowable = { + if (isRuntime) { + new SparkRuntimeException( + errorClass = errorClass, + messageParameters = messageParameters, + cause = e) + } else { + new AnalysisException( + errorClass = errorClass, + messageParameters = messageParameters, + cause = Some(e)) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 369f710edccf0..7d476d43e5c7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -22,7 +22,7 @@ import java.util.Locale import scala.util.control.NonFatal -import org.apache.spark.sql.AnalysisException +import org.apache.spark.SparkThrowable import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.expressions.{Expression, NullOrdering, SortDirection} @@ -207,7 +207,8 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr e: Throwable, errorClass: String, messageParameters: Map[String, String], - description: String): AnalysisException = { + description: String, + isRuntime: Boolean): Throwable with SparkThrowable = { e match { case sqlException: SQLException => sqlException.getErrorCode match { @@ -219,9 +220,10 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr case 15335 if errorClass == "FAILED_JDBC.RENAME_TABLE" => val newTable = messageParameters("newName") throw QueryCompilationErrors.tableAlreadyExistsError(newTable) - case _ => super.classifyException(e, errorClass, messageParameters, description) + case _ => + super.classifyException(e, errorClass, messageParameters, description, isRuntime) } - case _ => super.classifyException(e, errorClass, messageParameters, description) + case _ => super.classifyException(e, errorClass, messageParameters, description, isRuntime) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index 785bf5b13aa78..dd0118d875998 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -24,8 +24,7 @@ import java.util.Locale import scala.collection.mutable.ArrayBuilder import scala.util.control.NonFatal -import org.apache.spark.SparkUnsupportedOperationException -import org.apache.spark.sql.AnalysisException +import org.apache.spark.{SparkThrowable, SparkUnsupportedOperationException} import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} import org.apache.spark.sql.connector.catalog.Identifier @@ -353,7 +352,8 @@ private 
case class MySQLDialect() extends JdbcDialect with SQLConfHelper with No e: Throwable, errorClass: String, messageParameters: Map[String, String], - description: String): AnalysisException = { + description: String, + isRuntime: Boolean): Throwable with SparkThrowable = { e match { case sqlException: SQLException => sqlException.getErrorCode match { @@ -369,10 +369,11 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper with No val indexName = messageParameters("indexName") val tableName = messageParameters("tableName") throw new NoSuchIndexException(indexName, tableName, cause = Some(e)) - case _ => super.classifyException(e, errorClass, messageParameters, description) + case _ => + super.classifyException(e, errorClass, messageParameters, description, isRuntime) } case unsupported: UnsupportedOperationException => throw unsupported - case _ => super.classifyException(e, errorClass, messageParameters, description) + case _ => super.classifyException(e, errorClass, messageParameters, description, isRuntime) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 6175b5f659932..a73a34c646356 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -22,8 +22,7 @@ import java.util.Locale import scala.util.control.NonFatal -import org.apache.spark.SparkUnsupportedOperationException -import org.apache.spark.sql.AnalysisException +import org.apache.spark.{SparkThrowable, SparkUnsupportedOperationException} import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.connector.expressions.Expression import org.apache.spark.sql.errors.QueryCompilationErrors @@ -236,16 +235,18 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N e: Throwable, errorClass: String, messageParameters: Map[String, String], - description: String): AnalysisException = { + description: String, + isRuntime: Boolean): Throwable with SparkThrowable = { e match { case sqlException: SQLException => sqlException.getErrorCode match { case 955 if errorClass == "FAILED_JDBC.RENAME_TABLE" => val newTable = messageParameters("newName") throw QueryCompilationErrors.tableAlreadyExistsError(newTable) - case _ => super.classifyException(e, errorClass, messageParameters, description) + case _ => + super.classifyException(e, errorClass, messageParameters, description, isRuntime) } - case _ => super.classifyException(e, errorClass, messageParameters, description) + case _ => super.classifyException(e, errorClass, messageParameters, description, isRuntime) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 03fefd82802ef..8341063e09890 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -23,14 +23,15 @@ import java.util import java.util.Locale import scala.util.Using +import scala.util.control.NonFatal +import org.apache.spark.SparkThrowable import org.apache.spark.internal.LogKeys.COLUMN_NAME import org.apache.spark.internal.MDC -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NonEmptyNamespaceException, 
NoSuchIndexException} import org.apache.spark.sql.connector.catalog.Identifier -import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.{Expression, NamedReference} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo @@ -260,7 +261,8 @@ private case class PostgresDialect() e: Throwable, errorClass: String, messageParameters: Map[String, String], - description: String): AnalysisException = { + description: String, + isRuntime: Boolean): Throwable with SparkThrowable = { e match { case sqlException: SQLException => sqlException.getSQLState match { @@ -279,7 +281,7 @@ private case class PostgresDialect() if (tblRegexp.nonEmpty) { throw QueryCompilationErrors.tableAlreadyExistsError(tblRegexp.get.group(1)) } else { - super.classifyException(e, errorClass, messageParameters, description) + super.classifyException(e, errorClass, messageParameters, description, isRuntime) } } case "42704" if errorClass == "FAILED_JDBC.DROP_INDEX" => @@ -291,10 +293,33 @@ private case class PostgresDialect() namespace = messageParameters.get("namespace").toArray, details = sqlException.getMessage, cause = Some(e)) - case _ => super.classifyException(e, errorClass, messageParameters, description) + case _ => + super.classifyException(e, errorClass, messageParameters, description, isRuntime) } case unsupported: UnsupportedOperationException => throw unsupported - case _ => super.classifyException(e, errorClass, messageParameters, description) + case _ => super.classifyException(e, errorClass, messageParameters, description, isRuntime) + } + } + + class PostgresSQLBuilder extends JDBCSQLBuilder { + override def visitExtract(field: String, source: String): String = { + field match { + case "DAY_OF_YEAR" => s"EXTRACT(DOY FROM $source)" + case "YEAR_OF_WEEK" => s"EXTRACT(YEAR FROM $source)" + case "DAY_OF_WEEK" => s"EXTRACT(DOW FROM $source)" + case _ => super.visitExtract(field, source) + } + } + } + + override def compileExpression(expr: Expression): Option[String] = { + val postgresSQLBuilder = new PostgresSQLBuilder() + try { + Some(postgresSQLBuilder.build(expr)) + } catch { + case NonFatal(e) => + logWarning("Error occurs while compiling V2 expression", e) + None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index af9fd5464277c..3f8b0f1583adb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -592,3 +592,60 @@ class IterateStatementExec(val label: String) extends LeafStatementExec { var hasBeenMatched: Boolean = false override def reset(): Unit = hasBeenMatched = false } + +/** + * Executable node for LoopStatement. + * @param body Executable node for the body, executed on every loop iteration. + * @param label Label set to LoopStatement by user, None if not set. 
+ */ +class LoopStatementExec( + body: CompoundBodyExec, + val label: Option[String]) extends NonLeafStatementExec { + + /** + * Loop can be interrupted by LeaveStatementExec + */ + private var interrupted: Boolean = false + + /** + * Loop can be iterated by IterateStatementExec + */ + private var iterated: Boolean = false + + private lazy val treeIterator = + new Iterator[CompoundStatementExec] { + override def hasNext: Boolean = !interrupted + + override def next(): CompoundStatementExec = { + if (!body.getTreeIterator.hasNext || iterated) { + reset() + } + + val retStmt = body.getTreeIterator.next() + + retStmt match { + case leaveStatementExec: LeaveStatementExec if !leaveStatementExec.hasBeenMatched => + if (label.contains(leaveStatementExec.label)) { + leaveStatementExec.hasBeenMatched = true + } + interrupted = true + case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched => + if (label.contains(iterStatementExec.label)) { + iterStatementExec.hasBeenMatched = true + } + iterated = true + case _ => + } + + retStmt + } + } + + override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator + + override def reset(): Unit = { + interrupted = false + iterated = false + body.reset() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 917b4d6f45ee0..78ef715e18982 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier -import org.apache.spark.sql.catalyst.parser.{CaseStatement, CompoundBody, CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement, RepeatStatement, SingleStatement, WhileStatement} +import org.apache.spark.sql.catalyst.parser.{CaseStatement, CompoundBody, CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, RepeatStatement, SingleStatement, WhileStatement} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan} import org.apache.spark.sql.catalyst.trees.Origin @@ -120,6 +120,10 @@ case class SqlScriptingInterpreter() { transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec] new RepeatStatementExec(conditionExec, bodyExec, label, session) + case LoopStatement(body, label) => + val bodyExec = transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec] + new LoopStatementExec(bodyExec, label) + case leaveStatement: LeaveStatement => new LeaveStatementExec(leaveStatement.label) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 24d769fc8fc87..f42d8b667ab12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -22,12 +22,12 @@ import java.util.Locale import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Evolving -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.{api, DataFrame, Dataset, SparkSession} import 
org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.connector.catalog.{SupportsRead, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.errors.QueryCompilationErrors @@ -49,25 +49,15 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * @since 2.0.0 */ @Evolving -final class DataStreamReader private[sql](sparkSession: SparkSession) extends Logging { - /** - * Specifies the input data source format. - * - * @since 2.0.0 - */ - def format(source: String): DataStreamReader = { +final class DataStreamReader private[sql](sparkSession: SparkSession) extends api.DataStreamReader { + /** @inheritdoc */ + def format(source: String): this.type = { this.source = source this } - /** - * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema - * automatically from data. By specifying the schema here, the underlying data source can - * skip the schema inference step, and thus speed up data loading. - * - * @since 2.0.0 - */ - def schema(schema: StructType): DataStreamReader = { + /** @inheritdoc */ + def schema(schema: StructType): this.type = { if (schema != null) { val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] this.userSpecifiedSchema = Option(replaced) @@ -75,75 +65,19 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo this } - /** - * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) can - * infer the input schema automatically from data. By specifying the schema here, the underlying - * data source can skip the schema inference step, and thus speed up data loading. - * - * @since 2.3.0 - */ - def schema(schemaString: String): DataStreamReader = { - schema(StructType.fromDDL(schemaString)) - } - - /** - * Adds an input option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: String): DataStreamReader = { + /** @inheritdoc */ + def option(key: String, value: String): this.type = { this.extraOptions += (key -> value) this } - /** - * Adds an input option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: Boolean): DataStreamReader = option(key, value.toString) - - /** - * Adds an input option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: Long): DataStreamReader = option(key, value.toString) - - /** - * Adds an input option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: Double): DataStreamReader = option(key, value.toString) - - /** - * (Scala-specific) Adds input options for the underlying data source. - * - * @since 2.0.0 - */ - def options(options: scala.collection.Map[String, String]): DataStreamReader = { + /** @inheritdoc */ + def options(options: scala.collection.Map[String, String]): this.type = { this.extraOptions ++= options this } - /** - * (Java-specific) Adds input options for the underlying data source. 
- * - * @since 2.0.0 - */ - def options(options: java.util.Map[String, String]): DataStreamReader = { - this.options(options.asScala) - this - } - - - /** - * Loads input data stream in as a `DataFrame`, for data streams that don't require a path - * (e.g. external key-value stores). - * - * @since 2.0.0 - */ + /** @inheritdoc */ def load(): DataFrame = loadInternal(None) private def loadInternal(path: Option[String]): DataFrame = { @@ -205,11 +139,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } } - /** - * Loads input in as a `DataFrame`, for data streams that read from some path. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def load(path: String): DataFrame = { if (!sparkSession.sessionState.conf.legacyPathOptionBehavior && extraOptions.contains("path")) { @@ -218,133 +148,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo loadInternal(Some(path)) } - /** - * Loads a JSON file stream and returns the results as a `DataFrame`. - * - * JSON Lines (newline-delimited JSON) is supported by - * default. For JSON (one record per file), set the `multiLine` option to true. - * - * This function goes through the input once to determine the input schema. If you know the - * schema in advance, use the version that specifies the schema to avoid the extra scan. - * - * You can set the following option(s): - *
<ul> - * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be - * considered in every trigger.</li> - * <li>`maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files - * to be considered in every trigger.</li> - * </ul>
- * - * You can find the JSON-specific options for reading JSON file stream in - * - * Data Source Option in the version you use. - * - * @since 2.0.0 - */ - def json(path: String): DataFrame = { - userSpecifiedSchema.foreach(checkJsonSchema) - format("json").load(path) - } - - /** - * Loads a CSV file stream and returns the result as a `DataFrame`. - * - * This function will go through the input once to determine the input schema if `inferSchema` - * is enabled. To avoid going through the entire data once, disable `inferSchema` option or - * specify the schema explicitly using `schema`. - * - * You can set the following option(s): - *
<ul> - * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be - * considered in every trigger.</li> - * <li>`maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files - * to be considered in every trigger.</li> - * </ul>
- * - * You can find the CSV-specific options for reading CSV file stream in - * - * Data Source Option in the version you use. - * - * @since 2.0.0 - */ - def csv(path: String): DataFrame = format("csv").load(path) - - /** - * Loads a XML file stream and returns the result as a `DataFrame`. - * - * This function will go through the input once to determine the input schema if `inferSchema` - * is enabled. To avoid going through the entire data once, disable `inferSchema` option or - * specify the schema explicitly using `schema`. - * - * You can set the following option(s): - *
<ul> - * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be - * considered in every trigger.</li> - * <li>`maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files - * to be considered in every trigger.</li> - * </ul>
- * - * You can find the XML-specific options for reading XML file stream in - * - * Data Source Option in the version you use. - * - * @since 4.0.0 - */ - def xml(path: String): DataFrame = { - userSpecifiedSchema.foreach(checkXmlSchema) - format("xml").load(path) - } - - /** - * Loads a ORC file stream, returning the result as a `DataFrame`. - * - * You can set the following option(s): - *
<ul> - * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be - * considered in every trigger.</li> - * <li>`maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files - * to be considered in every trigger.</li> - * </ul>
- * - * ORC-specific option(s) for reading ORC file stream can be found in - * - * Data Source Option in the version you use. - * - * @since 2.3.0 - */ - def orc(path: String): DataFrame = { - format("orc").load(path) - } - - /** - * Loads a Parquet file stream, returning the result as a `DataFrame`. - * - * You can set the following option(s): - *
<ul> - * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be - * considered in every trigger.</li> - * <li>`maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files - * to be considered in every trigger.</li> - * </ul>
- * - * Parquet-specific option(s) for reading Parquet file stream can be found in - * - * Data Source Option in the version you use. - * - * @since 2.0.0 - */ - def parquet(path: String): DataFrame = { - format("parquet").load(path) - } - - /** - * Define a Streaming DataFrame on a Table. The DataSource corresponding to the table should - * support streaming mode. - * @param tableName The name of the table - * @since 3.1.0 - */ + /** @inheritdoc */ def table(tableName: String): DataFrame = { require(tableName != null, "The table name can't be null") val identifier = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(tableName) @@ -356,65 +160,56 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo isStreaming = true)) } - /** - * Loads text files and returns a `DataFrame` whose schema starts with a string column named - * "value", and followed by partitioned columns if there are any. - * The text files must be encoded as UTF-8. - * - * By default, each line in the text files is a new row in the resulting DataFrame. For example: - * {{{ - * // Scala: - * spark.readStream.text("/path/to/directory/") - * - * // Java: - * spark.readStream().text("/path/to/directory/") - * }}} - * - * You can set the following option(s): - *
<ul> - * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be - * considered in every trigger.</li> - * <li>`maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files - * to be considered in every trigger.</li> - * </ul>
- * - * You can find the text-specific options for reading text files in - * - * Data Source Option in the version you use. - * - * @since 2.0.0 - */ - def text(path: String): DataFrame = format("text").load(path) - - /** - * Loads text file(s) and returns a `Dataset` of String. The underlying schema of the Dataset - * contains a single string column named "value". - * The text files must be encoded as UTF-8. - * - * If the directory structure of the text files contains partitioning information, those are - * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. - * - * By default, each line in the text file is a new element in the resulting Dataset. For example: - * {{{ - * // Scala: - * spark.readStream.textFile("/path/to/spark/README.md") - * - * // Java: - * spark.readStream().textFile("/path/to/spark/README.md") - * }}} - * - * You can set the text-specific options as specified in `DataStreamReader.text`. - * - * @param path input path - * @since 2.1.0 - */ - def textFile(path: String): Dataset[String] = { + override protected def assertNoSpecifiedSchema(operation: String): Unit = { if (userSpecifiedSchema.nonEmpty) { - throw QueryCompilationErrors.userSpecifiedSchemaUnsupportedError("textFile") + throw QueryCompilationErrors.userSpecifiedSchemaUnsupportedError(operation) } - text(path).select("value").as[String](sparkSession.implicits.newStringEncoder) } + override protected def validateJsonSchema(): Unit = userSpecifiedSchema.foreach(checkJsonSchema) + + override protected def validateXmlSchema(): Unit = userSpecifiedSchema.foreach(checkXmlSchema) + + /////////////////////////////////////////////////////////////////////////////////////// + // Covariant overrides. + /////////////////////////////////////////////////////////////////////////////////////// + + /** @inheritdoc */ + override def schema(schemaString: String): this.type = super.schema(schemaString) + + /** @inheritdoc */ + override def option(key: String, value: Boolean): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Long): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Double): this.type = super.option(key, value) + + /** @inheritdoc */ + override def options(options: java.util.Map[String, String]): this.type = super.options(options) + + /** @inheritdoc */ + override def json(path: String): DataFrame = super.json(path) + + /** @inheritdoc */ + override def csv(path: String): DataFrame = super.csv(path) + + /** @inheritdoc */ + override def xml(path: String): DataFrame = super.xml(path) + + /** @inheritdoc */ + override def orc(path: String): DataFrame = super.orc(path) + + /** @inheritdoc */ + override def parquet(path: String): DataFrame = super.parquet(path) + + /** @inheritdoc */ + override def text(path: String): DataFrame = super.text(path) + + /** @inheritdoc */ + override def textFile(path: String): Dataset[String] = super.textFile(path) + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index ab4d350c1e68c..b0233d2c51b75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -55,253 +55,101 @@ import org.apache.spark.util.Utils * @since 2.0.0 */ @Evolving -final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { - import DataStreamWriter._ +final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends api.DataStreamWriter[T] { + type DS[U] = Dataset[U] - private val df = ds.toDF() - - /** - * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. - *
<ul> - * <li>`OutputMode.Append()`: only the new rows in the streaming DataFrame/Dataset will be - * written to the sink.</li> - * <li>`OutputMode.Complete()`: all the rows in the streaming DataFrame/Dataset will be written - * to the sink every time there are some updates.</li> - * <li>`OutputMode.Update()`: only the rows that were updated in the streaming - * DataFrame/Dataset will be written to the sink every time there are some updates. - * If the query doesn't contain aggregations, it will be equivalent to - * `OutputMode.Append()` mode.</li> - * </ul>
- * - * @since 2.0.0 - */ - def outputMode(outputMode: OutputMode): DataStreamWriter[T] = { + /** @inheritdoc */ + def outputMode(outputMode: OutputMode): this.type = { this.outputMode = outputMode this } - /** - * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. - *
<ul> - * <li>`append`: only the new rows in the streaming DataFrame/Dataset will be written to - * the sink.</li> - * <li>`complete`: all the rows in the streaming DataFrame/Dataset will be written to the sink - * every time there are some updates.</li> - * <li>`update`: only the rows that were updated in the streaming DataFrame/Dataset will - * be written to the sink every time there are some updates. If the query doesn't - * contain aggregations, it will be equivalent to `append` mode.</li> - * </ul>
- * - * @since 2.0.0 - */ - def outputMode(outputMode: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def outputMode(outputMode: String): this.type = { this.outputMode = InternalOutputModes(outputMode) this } - /** - * Set the trigger for the stream query. The default value is `ProcessingTime(0)` and it will run - * the query as fast as possible. - * - * Scala Example: - * {{{ - * df.writeStream.trigger(ProcessingTime("10 seconds")) - * - * import scala.concurrent.duration._ - * df.writeStream.trigger(ProcessingTime(10.seconds)) - * }}} - * - * Java Example: - * {{{ - * df.writeStream().trigger(ProcessingTime.create("10 seconds")) - * - * import java.util.concurrent.TimeUnit - * df.writeStream().trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) - * }}} - * - * @since 2.0.0 - */ - def trigger(trigger: Trigger): DataStreamWriter[T] = { + /** @inheritdoc */ + def trigger(trigger: Trigger): this.type = { this.trigger = trigger this } - /** - * Specifies the name of the [[StreamingQuery]] that can be started with `start()`. - * This name must be unique among all the currently active queries in the associated SQLContext. - * - * @since 2.0.0 - */ - def queryName(queryName: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def queryName(queryName: String): this.type = { this.extraOptions += ("queryName" -> queryName) this } - /** - * Specifies the underlying output data source. - * - * @since 2.0.0 - */ - def format(source: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def format(source: String): this.type = { this.source = source this } - /** - * Partitions the output by the given columns on the file system. If specified, the output is - * laid out on the file system similar to Hive's partitioning scheme. As an example, when we - * partition a dataset by year and then month, the directory layout would look like: - * - *
<ul> - * <li>year=2016/month=01/</li> - * <li>year=2016/month=02/</li> - * </ul>
- * - * Partitioning is one of the most widely used techniques to optimize physical data layout. - * It provides a coarse-grained index for skipping unnecessary data reads when queries have - * predicates on the partitioned columns. In order for partitioning to work well, the number - * of distinct values in each column should typically be less than tens of thousands. - * - * @since 2.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def partitionBy(colNames: String*): DataStreamWriter[T] = { + def partitionBy(colNames: String*): this.type = { this.partitioningColumns = Option(colNames) validatePartitioningAndClustering() this } - /** - * Clusters the output by the given columns. If specified, the output is laid out such that - * records with similar values on the clustering column are grouped together in the same file. - * - * Clustering improves query efficiency by allowing queries with predicates on the clustering - * columns to skip unnecessary data. Unlike partitioning, clustering can be used on very high - * cardinality columns. - * - * @since 4.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def clusterBy(colNames: String*): DataStreamWriter[T] = { + def clusterBy(colNames: String*): this.type = { this.clusteringColumns = Option(colNames) validatePartitioningAndClustering() this } - /** - * Adds an output option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def option(key: String, value: String): this.type = { this.extraOptions += (key -> value) this } - /** - * Adds an output option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: Boolean): DataStreamWriter[T] = option(key, value.toString) - - /** - * Adds an output option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: Long): DataStreamWriter[T] = option(key, value.toString) - - /** - * Adds an output option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: Double): DataStreamWriter[T] = option(key, value.toString) - - /** - * (Scala-specific) Adds output options for the underlying data source. - * - * @since 2.0.0 - */ - def options(options: scala.collection.Map[String, String]): DataStreamWriter[T] = { + /** @inheritdoc */ + def options(options: scala.collection.Map[String, String]): this.type = { this.extraOptions ++= options this } - /** - * Adds output options for the underlying data source. - * - * @since 2.0.0 - */ - def options(options: java.util.Map[String, String]): DataStreamWriter[T] = { + /** @inheritdoc */ + def options(options: java.util.Map[String, String]): this.type = { this.options(options.asScala) this } - /** - * Starts the execution of the streaming query, which will continually output results to the given - * path as new data arrives. The returned [[StreamingQuery]] object can be used to interact with - * the stream. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def start(path: String): StreamingQuery = { - if (!df.sparkSession.sessionState.conf.legacyPathOptionBehavior && + if (!ds.sparkSession.sessionState.conf.legacyPathOptionBehavior && extraOptions.contains("path")) { throw QueryCompilationErrors.setPathOptionAndCallWithPathParameterError("start") } startInternal(Some(path)) } - /** - * Starts the execution of the streaming query, which will continually output results to the given - * path as new data arrives. 
The returned [[StreamingQuery]] object can be used to interact with - * the stream. Throws a `TimeoutException` if the following conditions are met: - * - Another run of the same streaming query, that is a streaming query - * sharing the same checkpoint location, is already active on the same - * Spark Driver - * - The SQL configuration `spark.sql.streaming.stopActiveRunOnRestart` - * is enabled - * - The active run cannot be stopped within the timeout controlled by - * the SQL configuration `spark.sql.streaming.stopTimeout` - * - * @since 2.0.0 - */ + /** @inheritdoc */ @throws[TimeoutException] def start(): StreamingQuery = startInternal(None) - /** - * Starts the execution of the streaming query, which will continually output results to the given - * table as new data arrives. The returned [[StreamingQuery]] object can be used to interact with - * the stream. - * - * For v1 table, partitioning columns provided by `partitionBy` will be respected no matter the - * table exists or not. A new table will be created if the table not exists. - * - * For v2 table, `partitionBy` will be ignored if the table already exists. `partitionBy` will be - * respected only if the v2 table does not exist. Besides, the v2 table created by this API lacks - * some functionalities (e.g., customized properties, options, and serde info). If you need them, - * please create the v2 table manually before the execution to avoid creating a table with - * incomplete information. - * - * @since 3.1.0 - */ + /** @inheritdoc */ @Evolving @throws[TimeoutException] def toTable(tableName: String): StreamingQuery = { - this.tableName = tableName - import df.sparkSession.sessionState.analyzer.CatalogAndIdentifier + import ds.sparkSession.sessionState.analyzer.CatalogAndIdentifier import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - val parser = df.sparkSession.sessionState.sqlParser + val parser = ds.sparkSession.sessionState.sqlParser val originalMultipartIdentifier = parser.parseMultipartIdentifier(tableName) val CatalogAndIdentifier(catalog, identifier) = originalMultipartIdentifier // Currently we don't create a logical streaming writer node in logical plan, so cannot rely // on analyzer to resolve it. Directly lookup only for temp view to provide clearer message. // TODO (SPARK-27484): we should add the writing node before the plan is analyzed. 
- if (df.sparkSession.sessionState.catalog.isTempView(originalMultipartIdentifier)) { + if (ds.sparkSession.sessionState.catalog.isTempView(originalMultipartIdentifier)) { throw QueryCompilationErrors.tempViewNotSupportStreamingWriteError(tableName) } @@ -327,14 +175,14 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { extraOptions.get("path"), None, None, - false) + external = false) val cmd = CreateTable( UnresolvedIdentifier(originalMultipartIdentifier), - df.schema.asNullable.map(ColumnDefinition.fromV1Column(_, parser)), + ds.schema.asNullable.map(ColumnDefinition.fromV1Column(_, parser)), partitioningOrClusteringTransform, tableSpec, ignoreIfExists = false) - Dataset.ofRows(df.sparkSession, cmd) + Dataset.ofRows(ds.sparkSession, cmd) } val tableInstance = catalog.asTableCatalog.loadTable(identifier) @@ -371,34 +219,34 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { throw QueryCompilationErrors.cannotOperateOnHiveDataSourceFilesError("write") } - if (source == SOURCE_NAME_MEMORY) { - assertNotPartitioned(SOURCE_NAME_MEMORY) + if (source == DataStreamWriter.SOURCE_NAME_MEMORY) { + assertNotPartitioned(DataStreamWriter.SOURCE_NAME_MEMORY) if (extraOptions.get("queryName").isEmpty) { throw QueryCompilationErrors.queryNameNotSpecifiedForMemorySinkError() } val sink = new MemorySink() - val resultDf = Dataset.ofRows(df.sparkSession, - MemoryPlan(sink, DataTypeUtils.toAttributes(df.schema))) + val resultDf = Dataset.ofRows(ds.sparkSession, + MemoryPlan(sink, DataTypeUtils.toAttributes(ds.schema))) val recoverFromCheckpoint = outputMode == OutputMode.Complete() val query = startQuery(sink, extraOptions, recoverFromCheckpoint = recoverFromCheckpoint, catalogTable = catalogTable) resultDf.createOrReplaceTempView(query.name) query - } else if (source == SOURCE_NAME_FOREACH) { - assertNotPartitioned(SOURCE_NAME_FOREACH) + } else if (source == DataStreamWriter.SOURCE_NAME_FOREACH) { + assertNotPartitioned(DataStreamWriter.SOURCE_NAME_FOREACH) val sink = ForeachWriterTable[Any](foreachWriter, foreachWriterEncoder) startQuery(sink, extraOptions, catalogTable = catalogTable) - } else if (source == SOURCE_NAME_FOREACH_BATCH) { - assertNotPartitioned(SOURCE_NAME_FOREACH_BATCH) + } else if (source == DataStreamWriter.SOURCE_NAME_FOREACH_BATCH) { + assertNotPartitioned(DataStreamWriter.SOURCE_NAME_FOREACH_BATCH) if (trigger.isInstanceOf[ContinuousTrigger]) { throw QueryCompilationErrors.sourceNotSupportedWithContinuousTriggerError(source) } val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc) startQuery(sink, extraOptions, catalogTable = catalogTable) } else { - val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + val cls = DataSource.lookupDataSource(source, ds.sparkSession.sessionState.conf) val disabledSources = - Utils.stringToSeq(df.sparkSession.sessionState.conf.disabledV2StreamingWriters) + Utils.stringToSeq(ds.sparkSession.sessionState.conf.disabledV2StreamingWriters) val useV1Source = disabledSources.contains(cls.getCanonicalName) || // file source v2 does not support streaming yet. 
classOf[FileDataSourceV2].isAssignableFrom(cls) @@ -412,7 +260,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val sink = if (classOf[TableProvider].isAssignableFrom(cls) && !useV1Source) { val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - source = provider, conf = df.sparkSession.sessionState.conf) + source = provider, conf = ds.sparkSession.sessionState.conf) val finalOptions = sessionOptions.filter { case (k, _) => !optionsWithPath.contains(k) } ++ optionsWithPath.originalMap val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) @@ -420,7 +268,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { // to `getTable`. This is for avoiding schema inference, which can be very expensive. // If the query schema is not compatible with the existing data, the behavior is undefined. val outputSchema = if (provider.supportsExternalMetadata()) { - Some(df.schema) + Some(ds.schema) } else { None } @@ -450,12 +298,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { recoverFromCheckpoint: Boolean = true, catalogAndIdent: Option[(TableCatalog, Identifier)] = None, catalogTable: Option[CatalogTable] = None): StreamingQuery = { - val useTempCheckpointLocation = SOURCES_ALLOW_ONE_TIME_QUERY.contains(source) + val useTempCheckpointLocation = DataStreamWriter.SOURCES_ALLOW_ONE_TIME_QUERY.contains(source) - df.sparkSession.sessionState.streamingQueryManager.startQuery( + ds.sparkSession.sessionState.streamingQueryManager.startQuery( newOptions.get("queryName"), newOptions.get("checkpointLocation"), - df, + ds, newOptions.originalMap, sink, outputMode, @@ -480,26 +328,21 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { case None => optionsWithoutClusteringKey } val ds = DataSource( - df.sparkSession, + this.ds.sparkSession, className = source, options = optionsWithClusteringColumns, partitionColumns = normalizedParCols.getOrElse(Nil)) ds.createSink(outputMode) } - /** - * Sets the output of the streaming query to be processed using the provided writer object. - * object. See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and - * semantics. - * @since 2.0.0 - */ - def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { + /** @inheritdoc */ + def foreach(writer: ForeachWriter[T]): this.type = { foreachImplementation(writer.asInstanceOf[ForeachWriter[Any]]) } private[sql] def foreachImplementation(writer: ForeachWriter[Any], - encoder: Option[ExpressionEncoder[Any]] = None): DataStreamWriter[T] = { - this.source = SOURCE_NAME_FOREACH + encoder: Option[ExpressionEncoder[Any]] = None): this.type = { + this.source = DataStreamWriter.SOURCE_NAME_FOREACH this.foreachWriter = if (writer != null) { ds.sparkSession.sparkContext.clean(writer) } else { @@ -509,47 +352,15 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { this } - /** - * :: Experimental :: - * - * (Scala-specific) Sets the output of the streaming query to be processed using the provided - * function. This is supported only in the micro-batch execution modes (that is, when the - * trigger is not continuous). In every micro-batch, the provided function will be called in - * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. - * The batchId can be used to deduplicate and transactionally write the output - * (that is, the provided Dataset) to external systems. 
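The contract in the comment above (the same batchId is re-delivered with the same output Dataset when the query is deterministic) is what makes idempotent or transactional sinks possible. A minimal sketch, assuming a streaming Dataset `events` and illustrative paths, not part of this patch:

import org.apache.spark.sql.{Dataset, Row}

val batchQuery = events.writeStream
  .option("checkpointLocation", "/tmp/checkpoints/foreach-batch")
  .foreachBatch { (batch: Dataset[Row], batchId: Long) =>
    // Re-running a batch after a restart overwrites the same per-batch location,
    // so the write is idempotent with respect to batchId.
    batch.write.mode("overwrite").parquet(s"/tmp/output/batches/batch=$batchId")
  }
  .start()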
The output Dataset is guaranteed - * to be exactly the same for the same batchId (assuming all operations are deterministic - * in the query). - * - * @since 2.4.0 - */ + /** @inheritdoc */ @Evolving - def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = { - this.source = SOURCE_NAME_FOREACH_BATCH + def foreachBatch(function: (Dataset[T], Long) => Unit): this.type = { + this.source = DataStreamWriter.SOURCE_NAME_FOREACH_BATCH if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null") this.foreachBatchWriter = function this } - /** - * :: Experimental :: - * - * (Java-specific) Sets the output of the streaming query to be processed using the provided - * function. This is supported only in the micro-batch execution modes (that is, when the - * trigger is not continuous). In every micro-batch, the provided function will be called in - * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. - * The batchId can be used to deduplicate and transactionally write the output - * (that is, the provided Dataset) to external systems. The output Dataset is guaranteed - * to be exactly the same for the same batchId (assuming all operations are deterministic - * in the query). - * - * @since 2.4.0 - */ - @Evolving - def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): DataStreamWriter[T] = { - foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId)) - } - private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => cols.map(normalize(_, "Partition")) } @@ -564,8 +375,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * need to care about case sensitivity afterwards. */ private def normalize(columnName: String, columnType: String): String = { - val validColumnNames = df.logicalPlan.output.map(_.name) - validColumnNames.find(df.sparkSession.sessionState.analyzer.resolver(_, columnName)) + val validColumnNames = ds.logicalPlan.output.map(_.name) + validColumnNames.find(ds.sparkSession.sessionState.analyzer.resolver(_, columnName)) .getOrElse(throw QueryCompilationErrors.columnNotFoundInExistingColumnsError( columnType, columnName, validColumnNames)) } @@ -584,12 +395,28 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } /////////////////////////////////////////////////////////////////////////////////////// - // Builder pattern config options + // Covariant Overrides /////////////////////////////////////////////////////////////////////////////////////// - private var source: String = df.sparkSession.sessionState.conf.defaultDataSourceName + /** @inheritdoc */ + override def option(key: String, value: Boolean): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Long): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Double): this.type = super.option(key, value) + + /** @inheritdoc */ + @Evolving + override def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): this.type = + super.foreachBatch(function) + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// - private var tableName: String = null + private var source: String = ds.sparkSession.sessionState.conf.defaultDataSourceName private var outputMode: OutputMode = 
OutputMode.Append @@ -597,12 +424,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { private var extraOptions = CaseInsensitiveMap[String](Map.empty) - private var foreachWriter: ForeachWriter[Any] = null + private var foreachWriter: ForeachWriter[Any] = _ private var foreachWriterEncoder: ExpressionEncoder[Any] = ds.exprEnc.asInstanceOf[ExpressionEncoder[Any]] - private var foreachBatchWriter: (Dataset[T], Long) => Unit = null + private var foreachBatchWriter: (Dataset[T], Long) => Unit = _ private var partitioningColumns: Option[Seq[String]] = None @@ -610,14 +437,14 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } object DataStreamWriter { - val SOURCE_NAME_MEMORY = "memory" - val SOURCE_NAME_FOREACH = "foreach" - val SOURCE_NAME_FOREACH_BATCH = "foreachBatch" - val SOURCE_NAME_CONSOLE = "console" - val SOURCE_NAME_TABLE = "table" - val SOURCE_NAME_NOOP = "noop" + val SOURCE_NAME_MEMORY: String = "memory" + val SOURCE_NAME_FOREACH: String = "foreach" + val SOURCE_NAME_FOREACH_BATCH: String = "foreachBatch" + val SOURCE_NAME_CONSOLE: String = "console" + val SOURCE_NAME_TABLE: String = "table" + val SOURCE_NAME_NOOP: String = "noop" // these writer sources are also used for one-time query, hence allow temp checkpoint location - val SOURCES_ALLOW_ONE_TIME_QUERY = Seq(SOURCE_NAME_MEMORY, SOURCE_NAME_FOREACH, + val SOURCES_ALLOW_ONE_TIME_QUERY: Seq[String] = Seq(SOURCE_NAME_MEMORY, SOURCE_NAME_FOREACH, SOURCE_NAME_FOREACH_BATCH, SOURCE_NAME_CONSOLE, SOURCE_NAME_NOOP) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index 653e1df4af679..7cf92db59067c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -16,10 +16,10 @@ */ package org.apache.spark.sql.streaming -import org.apache.spark.sql.{api, Dataset, SparkSession} +import org.apache.spark.sql.{api, SparkSession} /** @inheritdoc */ -trait StreamingQuery extends api.StreamingQuery[Dataset] { +trait StreamingQuery extends api.StreamingQuery { /** @inheritdoc */ override def sparkSession: SparkSession } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 3ab6d02f6b515..42f6d04466b08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -27,7 +27,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Evolving import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{CLASS_NAME, QUERY_ID, RUN_ID} -import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.{api, Dataset, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.streaming.{WriteToStream, WriteToStreamStatement} import org.apache.spark.sql.connector.catalog.{Identifier, SupportsWrite, Table, TableCatalog} @@ -47,7 +47,9 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} @Evolving class StreamingQueryManager private[sql] ( sparkSession: SparkSession, - sqlConf: SQLConf) extends Logging { + sqlConf: SQLConf) + extends api.StreamingQueryManager + with Logging { private[sql] val stateStoreCoordinator = 
StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env) @@ -70,7 +72,7 @@ class StreamingQueryManager private[sql] ( * failed. The exception is the exception of the last failed query. */ @GuardedBy("awaitTerminationLock") - private var lastTerminatedQueryException: Option[StreamingQueryException] = null + private var lastTerminatedQueryException: Option[StreamingQueryException] = _ try { sparkSession.sparkContext.conf.get(STREAMING_QUERY_LISTENERS).foreach { classNames => @@ -90,51 +92,20 @@ class StreamingQueryManager private[sql] ( throw QueryExecutionErrors.registeringStreamingQueryListenerError(e) } - /** - * Returns a list of active queries associated with this SQLContext - * - * @since 2.0.0 - */ + /** @inheritdoc */ def active: Array[StreamingQuery] = activeQueriesSharedLock.synchronized { activeQueries.values.toArray } - /** - * Returns the query if there is an active query with the given id, or null. - * - * @since 2.1.0 - */ + /** @inheritdoc */ def get(id: UUID): StreamingQuery = activeQueriesSharedLock.synchronized { activeQueries.get(id).orNull } - /** - * Returns the query if there is an active query with the given id, or null. - * - * @since 2.1.0 - */ + /** @inheritdoc */ def get(id: String): StreamingQuery = get(UUID.fromString(id)) - /** - * Wait until any of the queries on the associated SQLContext has terminated since the - * creation of the context, or since `resetTerminated()` was called. If any query was terminated - * with an exception, then the exception will be thrown. - * - * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either - * return immediately (if the query was terminated by `query.stop()`), - * or throw the exception immediately (if the query was terminated with exception). Use - * `resetTerminated()` to clear past terminations and wait for new terminations. - * - * In the case where multiple queries have terminated since `resetTermination()` was called, - * if any query has terminated with exception, then `awaitAnyTermination()` will - * throw any of the exception. For correctly documenting exceptions across multiple queries, - * users need to stop all of them after any of them terminates with exception, and then check the - * `query.exception()` for each query. - * - * @throws StreamingQueryException if any query has terminated with an exception - * - * @since 2.0.0 - */ + /** @inheritdoc */ @throws[StreamingQueryException] def awaitAnyTermination(): Unit = { awaitTerminationLock.synchronized { @@ -147,27 +118,7 @@ class StreamingQueryManager private[sql] ( } } - /** - * Wait until any of the queries on the associated SQLContext has terminated since the - * creation of the context, or since `resetTerminated()` was called. Returns whether any query - * has terminated or not (multiple may have terminated). If any query has terminated with an - * exception, then the exception will be thrown. - * - * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either - * return `true` immediately (if the query was terminated by `query.stop()`), - * or throw the exception immediately (if the query was terminated with exception). Use - * `resetTerminated()` to clear past terminations and wait for new terminations. - * - * In the case where multiple queries have terminated since `resetTermination()` was called, - * if any query has terminated with exception, then `awaitAnyTermination()` will - * throw any of the exception. 
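A minimal sketch of the awaitAnyTermination() contract described above, assuming a SparkSession `spark` with one or more active queries (not part of this patch):

import org.apache.spark.sql.streaming.StreamingQueryException

val manager = spark.streams
try {
  // Blocks until any query started since the last resetTerminated() stops;
  // if that query failed, its exception is rethrown here.
  manager.awaitAnyTermination()
} catch {
  case _: StreamingQueryException =>
    // Per the comment above: stop the remaining queries, inspect each query.exception,
    // then clear past terminations before waiting again.
    manager.active.foreach(_.stop())
    manager.resetTerminated()
}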
For correctly documenting exceptions across multiple queries, - * users need to stop all of them after any of them terminates with exception, and then check the - * `query.exception()` for each query. - * - * @throws StreamingQueryException if any query has terminated with an exception - * - * @since 2.0.0 - */ + /** @inheritdoc */ @throws[StreamingQueryException] def awaitAnyTermination(timeoutMs: Long): Boolean = { @@ -187,42 +138,24 @@ class StreamingQueryManager private[sql] ( } } - /** - * Forget about past terminated queries so that `awaitAnyTermination()` can be used again to - * wait for new terminations. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def resetTerminated(): Unit = { awaitTerminationLock.synchronized { lastTerminatedQueryException = null } } - /** - * Register a [[StreamingQueryListener]] to receive up-calls for life cycle events of - * [[StreamingQuery]]. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def addListener(listener: StreamingQueryListener): Unit = { listenerBus.addListener(listener) } - /** - * Deregister a [[StreamingQueryListener]]. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def removeListener(listener: StreamingQueryListener): Unit = { listenerBus.removeListener(listener) } - /** - * List all [[StreamingQueryListener]]s attached to this [[StreamingQueryManager]]. - * - * @since 3.0.0 - */ + /** @inheritdoc */ def listListeners(): Array[StreamingQueryListener] = { listenerBus.listeners.asScala.toArray } @@ -241,7 +174,7 @@ class StreamingQueryManager private[sql] ( private def createQuery( userSpecifiedName: Option[String], userSpecifiedCheckpointLocation: Option[String], - df: DataFrame, + df: Dataset[_], extraOptions: Map[String, String], sink: Table, outputMode: OutputMode, @@ -322,7 +255,7 @@ class StreamingQueryManager private[sql] ( private[sql] def startQuery( userSpecifiedName: Option[String], userSpecifiedCheckpointLocation: Option[String], - df: DataFrame, + df: Dataset[_], extraOptions: Map[String, String], sink: Table, outputMode: OutputMode, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java index 9fbd1919a2668..9988d04220f0f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java @@ -85,7 +85,7 @@ public void isInCollectionCheckExceptionMessage() { Dataset df = spark.createDataFrame(rows, schema); AnalysisException e = Assertions.assertThrows(AnalysisException.class, () -> df.filter(df.col("a").isInCollection(Arrays.asList(new Column("b"))))); - Assertions.assertTrue(e.getErrorClass().equals("DATATYPE_MISMATCH.DATA_DIFF_TYPES")); + Assertions.assertTrue(e.getCondition().equals("DATATYPE_MISMATCH.DATA_DIFF_TYPES")); Map messageParameters = new HashMap<>(); messageParameters.put("functionName", "`in`"); messageParameters.put("dataType", "[\"INT\", \"ARRAY\"]"); diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 5ad1380e1fb82..79fd25aa3eb14 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -99,9 +99,9 @@ | org.apache.spark.sql.catalyst.expressions.Csc | csc | SELECT csc(1) | struct | | org.apache.spark.sql.catalyst.expressions.CsvToStructs | from_csv | SELECT from_csv('1, 0.8', 'a INT, b 
DOUBLE') | struct> | | org.apache.spark.sql.catalyst.expressions.CumeDist | cume_dist | SELECT a, b, cume_dist() OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct | -| org.apache.spark.sql.catalyst.expressions.CurDateExpressionBuilder | curdate | SELECT curdate() | struct | +| org.apache.spark.sql.catalyst.expressions.CurDateExpressionBuilder | curdate | SELECT curdate() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentCatalog | current_catalog | SELECT current_catalog() | struct | -| org.apache.spark.sql.catalyst.expressions.CurrentDatabase | current_database | SELECT current_database() | struct | +| org.apache.spark.sql.catalyst.expressions.CurrentDatabase | current_database | SELECT current_database() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentDatabase | current_schema | SELECT current_schema() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentDate | current_date | SELECT current_date() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentTimeZone | current_timezone | SELECT current_timezone() | struct | @@ -110,7 +110,7 @@ | org.apache.spark.sql.catalyst.expressions.CurrentUser | session_user | SELECT session_user() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentUser | user | SELECT user() | struct | | org.apache.spark.sql.catalyst.expressions.DateAdd | date_add | SELECT date_add('2016-07-30', 1) | struct | -| org.apache.spark.sql.catalyst.expressions.DateAdd | dateadd | SELECT dateadd('2016-07-30', 1) | struct | +| org.apache.spark.sql.catalyst.expressions.DateAdd | dateadd | SELECT dateadd('2016-07-30', 1) | struct | | org.apache.spark.sql.catalyst.expressions.DateDiff | date_diff | SELECT date_diff('2009-07-31', '2009-07-30') | struct | | org.apache.spark.sql.catalyst.expressions.DateDiff | datediff | SELECT datediff('2009-07-31', '2009-07-30') | struct | | org.apache.spark.sql.catalyst.expressions.DateFormatClass | date_format | SELECT date_format('2016-04-08', 'y') | struct | @@ -264,7 +264,7 @@ | org.apache.spark.sql.catalyst.expressions.RPadExpressionBuilder | rpad | SELECT rpad('hi', 5, '??') | struct | | org.apache.spark.sql.catalyst.expressions.RaiseErrorExpressionBuilder | raise_error | SELECT raise_error('custom error message') | struct | | org.apache.spark.sql.catalyst.expressions.Rand | rand | SELECT rand() | struct | -| org.apache.spark.sql.catalyst.expressions.Rand | random | SELECT random() | struct | +| org.apache.spark.sql.catalyst.expressions.Rand | random | SELECT random() | struct | | org.apache.spark.sql.catalyst.expressions.RandStr | randstr | SELECT randstr(3, 0) AS result | struct | | org.apache.spark.sql.catalyst.expressions.Randn | randn | SELECT randn() | struct | | org.apache.spark.sql.catalyst.expressions.Rank | rank | SELECT a, b, rank(b) OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct | @@ -340,7 +340,7 @@ | org.apache.spark.sql.catalyst.expressions.TimeWindow | window | SELECT a, window.start, window.end, count(*) as cnt FROM VALUES ('A1', '2021-01-01 00:00:00'), ('A1', '2021-01-01 00:04:30'), ('A1', '2021-01-01 00:06:00'), ('A2', '2021-01-01 00:01:00') AS tab(a, b) GROUP by a, window(b, '5 minutes') ORDER BY a, start | struct | | org.apache.spark.sql.catalyst.expressions.ToBinary | to_binary | SELECT to_binary('abc', 'utf-8') | struct | | org.apache.spark.sql.catalyst.expressions.ToCharacterBuilder | to_char | SELECT to_char(454, '999') | struct | -| 
org.apache.spark.sql.catalyst.expressions.ToCharacterBuilder | to_varchar | SELECT to_varchar(454, '999') | struct | +| org.apache.spark.sql.catalyst.expressions.ToCharacterBuilder | to_varchar | SELECT to_varchar(454, '999') | struct | | org.apache.spark.sql.catalyst.expressions.ToDegrees | degrees | SELECT degrees(3.141592653589793) | struct | | org.apache.spark.sql.catalyst.expressions.ToNumber | to_number | SELECT to_number('454', '999') | struct | | org.apache.spark.sql.catalyst.expressions.ToRadians | radians | SELECT radians(180) | struct | @@ -402,7 +402,7 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.BoolOr | any | SELECT any(col) FROM VALUES (true), (false), (false) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.BoolOr | bool_or | SELECT bool_or(col) FROM VALUES (true), (false), (false) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.BoolOr | some | SELECT some(col) FROM VALUES (true), (false), (false) AS tab(col) | struct | -| org.apache.spark.sql.catalyst.expressions.aggregate.CollectList | array_agg | SELECT array_agg(col) FROM VALUES (1), (2), (1) AS tab(col) | struct> | +| org.apache.spark.sql.catalyst.expressions.aggregate.CollectList | array_agg | SELECT array_agg(col) FROM VALUES (1), (2), (1) AS tab(col) | struct> | | org.apache.spark.sql.catalyst.expressions.aggregate.CollectList | collect_list | SELECT collect_list(col) FROM VALUES (1), (2), (1) AS tab(col) | struct> | | org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet | collect_set | SELECT collect_set(col) FROM VALUES (1), (2), (1) AS tab(col) | struct> | | org.apache.spark.sql.catalyst.expressions.aggregate.Corr | corr | SELECT corr(c1, c2) FROM VALUES (3, 2), (3, 3), (6, 4) as tab(c1, c2) | struct | diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/date.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/date.sql.out index fd927b99c6456..0e4d2d4e99e26 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/date.sql.out @@ -736,7 +736,7 @@ Project [to_date(26/October/2015, Some(dd/MMMMM/yyyy), Some(America/Los_Angeles) -- !query select from_json('{"d":"26/October/2015"}', 'd Date', map('dateFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"d":"26/October/2015"})#x] +Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"d":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out index 472c9b1df064a..b0d128c4cab69 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out @@ -2108,7 +2108,7 @@ SELECT to_csv(named_struct('a', interval 32 year, 'b', interval 10 month)), from_csv(to_csv(named_struct('a', interval 32 year, 'b', interval 10 month)), 'a interval year, b interval month') -- !query analysis -Project [from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles)) AS from_json({"a":"1 days"})#x, from_csv(StructField(a,IntegerType,true), 
StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None) AS from_csv(1, 1)#x, to_json(from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1 days"}))#x, to_csv(from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None), Some(America/Los_Angeles)) AS to_csv(from_csv(1, 1))#x, to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)) AS to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH))#x, from_csv(StructField(a,YearMonthIntervalType(0,0),true), StructField(b,YearMonthIntervalType(1,1),true), to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), None) AS from_csv(to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH)))#x] +Project [from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles), false) AS from_json({"a":"1 days"})#x, from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None) AS from_csv(1, 1)#x, to_json(from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1 days"}))#x, to_csv(from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None), Some(America/Los_Angeles)) AS to_csv(from_csv(1, 1))#x, to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)) AS to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH))#x, from_csv(StructField(a,YearMonthIntervalType(0,0),true), StructField(b,YearMonthIntervalType(1,1),true), to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), None) AS from_csv(to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH)))#x] +- OneRowRelation @@ -2119,7 +2119,7 @@ SELECT to_json(map('a', interval 100 day 130 minute)), from_json(to_json(map('a', interval 100 day 130 minute)), 'a interval day to minute') -- !query analysis -Project [from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE))#x, from_json(StructField(a,DayTimeIntervalType(0,2),true), to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS from_json(to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE)))#x] +Project [from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE))#x, from_json(StructField(a,DayTimeIntervalType(0,2),true), to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), 
Some(America/Los_Angeles)), Some(America/Los_Angeles), false) AS from_json(to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE)))#x] +- OneRowRelation @@ -2130,7 +2130,7 @@ SELECT to_json(map('a', interval 32 year 10 month)), from_json(to_json(map('a', interval 32 year 10 month)), 'a interval year to month') -- !query analysis -Project [from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '32-10' YEAR TO MONTH))#x, from_json(StructField(a,YearMonthIntervalType(0,1),true), to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS from_json(to_json(map(a, INTERVAL '32-10' YEAR TO MONTH)))#x] +Project [from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '32-10' YEAR TO MONTH))#x, from_json(StructField(a,YearMonthIntervalType(0,1),true), to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), false) AS from_json(to_json(map(a, INTERVAL '32-10' YEAR TO MONTH)))#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/parse-schema-string.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/parse-schema-string.sql.out index 45fc3bd03a782..ae8e47ed3665c 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/parse-schema-string.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/parse-schema-string.sql.out @@ -16,12 +16,12 @@ Project [from_csv(StructField(cube,IntegerType,true), 1, Some(America/Los_Angele -- !query select from_json('{"create":1}', 'create INT') -- !query analysis -Project [from_json(StructField(create,IntegerType,true), {"create":1}, Some(America/Los_Angeles)) AS from_json({"create":1})#x] +Project [from_json(StructField(create,IntegerType,true), {"create":1}, Some(America/Los_Angeles), false) AS from_json({"create":1})#x] +- OneRowRelation -- !query select from_json('{"cube":1}', 'cube INT') -- !query analysis -Project [from_json(StructField(cube,IntegerType,true), {"cube":1}, Some(America/Los_Angeles)) AS from_json({"cube":1})#x] +Project [from_json(StructField(cube,IntegerType,true), {"cube":1}, Some(America/Los_Angeles), false) AS from_json({"cube":1})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/timestamp.sql.out index bf34490d657e3..560974d28c545 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/timestamp.sql.out @@ -730,7 +730,7 @@ Project [unix_timestamp(22 05 2020 Friday, dd MM yyyy EEEEE, Some(America/Los_An -- !query select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project 
[from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"t":"26/October/2015"})#x] +Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"t":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out index 5c1417f7c0aae..d4bcb8f2ed042 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out @@ -263,6 +263,18 @@ desc formatted char_part DescribeTableCommand `spark_catalog`.`default`.`char_part`, true, [col_name#x, data_type#x, comment#x] +-- !query +alter table char_part change column c1 comment 'char comment' +-- !query analysis +AlterTableChangeColumnCommand `spark_catalog`.`default`.`char_part`, c1, StructField(c1,CharType(5),true) + + +-- !query +alter table char_part change column v1 comment 'varchar comment' +-- !query analysis +AlterTableChangeColumnCommand `spark_catalog`.`default`.`char_part`, v1, StructField(v1,VarcharType(6),true) + + -- !query alter table char_part add partition (v2='ke', c2='nt') location 'loc1' -- !query analysis @@ -710,19 +722,19 @@ Project [chr(cast(167 as bigint)) AS chr(167)#x, chr(cast(247 as bigint)) AS chr -- !query SELECT to_varchar(78.12, '$99.99') -- !query analysis -Project [to_char(78.12, $99.99) AS to_char(78.12, $99.99)#x] +Project [to_varchar(78.12, $99.99) AS to_varchar(78.12, $99.99)#x] +- OneRowRelation -- !query SELECT to_varchar(111.11, '99.9') -- !query analysis -Project [to_char(111.11, 99.9) AS to_char(111.11, 99.9)#x] +Project [to_varchar(111.11, 99.9) AS to_varchar(111.11, 99.9)#x] +- OneRowRelation -- !query SELECT to_varchar(12454.8, '99,999.9S') -- !query analysis -Project [to_char(12454.8, 99,999.9S) AS to_char(12454.8, 99,999.9S)#x] +Project [to_varchar(12454.8, 99,999.9S) AS to_varchar(12454.8, 99,999.9S)#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out index 83c9ebfef4b25..eed7fa73ab698 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out @@ -436,6 +436,30 @@ Project [str_to_map(collate(text#x, utf8_binary), collate(pairDelim#x, utf8_bina +- Relation spark_catalog.default.t4[text#x,pairDelim#x,keyValueDelim#x] parquet +-- !query +select str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, keyValueDelim collate unicode_ai) from t4 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(text, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"str_to_map(collate(text, unicode_ai), collate(pairDelim, unicode_ai), collate(keyValueDelim, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 106, + "fragment" : "str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, keyValueDelim collate unicode_ai)" + } ] +} 
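The new golden entries above record these string expressions rejecting STRING COLLATE UNICODE_AI inputs with DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE. A hedged reproduction sketch from the Scala side, assuming the t4 fixture (text, pairDelim, keyValueDelim string columns) used by collations.sql; not part of this patch:

import org.apache.spark.sql.AnalysisException

try {
  spark.sql(
    "select str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, " +
      "keyValueDelim collate unicode_ai) from t4")
} catch {
  case e: AnalysisException =>
    // Expected to match the analyzer output recorded above; getCondition is the
    // accessor used elsewhere in this patch in place of getErrorClass.
    assert(e.getCondition == "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
}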
+ + -- !query drop table t4 -- !query analysis @@ -820,6 +844,30 @@ Project [split_part(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, ut +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"split_part(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 83, + "fragment" : "split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2)" + } ] +} + + -- !query select split_part(utf8_binary, 'a', 3), split_part(utf8_lcase, 'a', 3) from t5 -- !query analysis @@ -883,6 +931,30 @@ Project [Contains(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"contains(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 78, + "fragment" : "contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select contains(utf8_binary, 'a'), contains(utf8_lcase, 'a') from t5 -- !query analysis @@ -946,6 +1018,30 @@ Project [substring_index(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase# +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"substring_index(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 88, + "fragment" : "substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2)" + } ] +} + + -- !query select substring_index(utf8_binary, 'a', 2), substring_index(utf8_lcase, 'a', 2) from t5 -- !query analysis @@ -1009,6 +1105,30 @@ Project [instr(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lc +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis 
+org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"instr(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select instr(utf8_binary, 'a'), instr(utf8_lcase, 'a') from t5 -- !query analysis @@ -1135,6 +1255,30 @@ Project [StartsWith(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, ut +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"startswith(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 80, + "fragment" : "startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select startswith(utf8_binary, 'aaAaaAaA'), startswith(utf8_lcase, 'aaAaaAaA') from t5 -- !query analysis @@ -1190,6 +1334,30 @@ Project [translate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(SQL +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"utf8_binary\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"translate(utf8_binary, collate(SQL, unicode_ai), collate(12345, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 83, + "fragment" : "translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai)" + } ] +} + + -- !query select translate(utf8_lcase, 'aaAaaAaA', '12345'), translate(utf8_binary, 'aaAaaAaA', '12345') from t5 -- !query analysis @@ -1253,6 +1421,30 @@ Project [replace(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc') from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"replace(collate(utf8_binary, unicode_ai), 
collate(utf8_lcase, unicode_ai), abc)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 84, + "fragment" : "replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc')" + } ] +} + + -- !query select replace(utf8_binary, 'aaAaaAaA', 'abc'), replace(utf8_lcase, 'aaAaaAaA', 'abc') from t5 -- !query analysis @@ -1316,6 +1508,30 @@ Project [EndsWith(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"endswith(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 78, + "fragment" : "endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select endswith(utf8_binary, 'aaAaaAaA'), endswith(utf8_lcase, 'aaAaaAaA') from t5 -- !query analysis @@ -2039,6 +2255,30 @@ Project [locate(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_l +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"locate(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 3)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 79, + "fragment" : "locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3)" + } ] +} + + -- !query select locate(utf8_binary, 'a'), locate(utf8_lcase, 'a') from t5 -- !query analysis @@ -2102,6 +2342,30 @@ Project [trim(collate(utf8_lcase#x, utf8_lcase), Some(collate(utf8_binary#x, utf +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(BOTH collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 74, + "fragment" : "TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select TRIM('ABc', utf8_binary), TRIM('ABc', utf8_lcase) from t5 -- !query analysis @@ -2165,6 +2429,30 @@ Project 
[btrim(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lc +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(BOTH collate(utf8_lcase, unicode_ai) FROM collate(utf8_binary, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select BTRIM('ABc', utf8_binary), BTRIM('ABc', utf8_lcase) from t5 -- !query analysis @@ -2228,6 +2516,30 @@ Project [ltrim(collate(utf8_lcase#x, utf8_lcase), Some(collate(utf8_binary#x, ut +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(LEADING collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select LTRIM('ABc', utf8_binary), LTRIM('ABc', utf8_lcase) from t5 -- !query analysis @@ -2291,6 +2603,30 @@ Project [rtrim(collate(utf8_lcase#x, utf8_lcase), Some(collate(utf8_binary#x, ut +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(TRAILING collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select RTRIM('ABc', utf8_binary), RTRIM('ABc', utf8_lcase) from t5 -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/current_database_catalog.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/current_database_catalog.sql.out index 1a71594f84932..2759f5e67507b 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/current_database_catalog.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/current_database_catalog.sql.out @@ -2,5 +2,5 @@ -- !query select 
current_database(), current_schema(), current_catalog() -- !query analysis -Project [current_schema() AS current_schema()#x, current_schema() AS current_schema()#x, current_catalog() AS current_catalog()#x] +Project [current_database() AS current_database()#x, current_schema() AS current_schema()#x, current_catalog() AS current_catalog()#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/date.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/date.sql.out index 48137e06467e8..88c7d7b4e7d72 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/date.sql.out @@ -811,7 +811,7 @@ Project [to_date(26/October/2015, Some(dd/MMMMM/yyyy), Some(America/Los_Angeles) -- !query select from_json('{"d":"26/October/2015"}', 'd Date', map('dateFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"d":"26/October/2015"})#x] +Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"d":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/datetime-legacy.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/datetime-legacy.sql.out index 1e49f4df8267a..4221db822d024 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/datetime-legacy.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/datetime-legacy.sql.out @@ -811,7 +811,7 @@ Project [to_date(26/October/2015, Some(dd/MMMMM/yyyy), Some(America/Los_Angeles) -- !query select from_json('{"d":"26/October/2015"}', 'd Date', map('dateFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"d":"26/October/2015"})#x] +Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"d":"26/October/2015"})#x] +- OneRowRelation @@ -1833,7 +1833,7 @@ Project [unix_timestamp(22 05 2020 Friday, dd MM yyyy EEEEE, Some(America/Los_An -- !query select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"t":"26/October/2015"})#x] +Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"t":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out index 8849aa4452252..6996eb913a21e 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out @@ -1133,7 +1133,7 @@ SELECT FROM VALUES (1), (2), (1) AS tab(col) -- !query analysis -Aggregate [collect_list(col#x, 0, 0) AS collect_list(col)#x, collect_list(col#x, 0, 0) AS collect_list(col)#x] +Aggregate [collect_list(col#x, 0, 0) AS collect_list(col)#x, array_agg(col#x, 0, 0) AS array_agg(col)#x] +- 
SubqueryAlias tab +- LocalRelation [col#x] @@ -1147,7 +1147,7 @@ FROM VALUES (1,4),(2,3),(1,4),(2,4) AS v(a,b) GROUP BY a -- !query analysis -Aggregate [a#x], [a#x, collect_list(b#x, 0, 0) AS collect_list(b)#x, collect_list(b#x, 0, 0) AS collect_list(b)#x] +Aggregate [a#x], [a#x, collect_list(b#x, 0, 0) AS collect_list(b)#x, array_agg(b#x, 0, 0) AS array_agg(b)#x] +- SubqueryAlias v +- LocalRelation [a#x, b#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out index f0bf8b883dd8b..20e6ca1e6a2ec 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out @@ -893,7 +893,7 @@ org.apache.spark.sql.catalyst.parser.ParseException "errorClass" : "INVALID_SQL_SYNTAX.MULTI_PART_NAME", "sqlState" : "42000", "messageParameters" : { - "funcName" : "`default`.`myDoubleAvg`", + "name" : "`default`.`myDoubleAvg`", "statement" : "DROP TEMPORARY FUNCTION" }, "queryContext" : [ { diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out index 3db38d482b26d..efa149509751d 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out @@ -2108,7 +2108,7 @@ SELECT to_csv(named_struct('a', interval 32 year, 'b', interval 10 month)), from_csv(to_csv(named_struct('a', interval 32 year, 'b', interval 10 month)), 'a interval year, b interval month') -- !query analysis -Project [from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles)) AS from_json({"a":"1 days"})#x, from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None) AS from_csv(1, 1)#x, to_json(from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1 days"}))#x, to_csv(from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None), Some(America/Los_Angeles)) AS to_csv(from_csv(1, 1))#x, to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)) AS to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH))#x, from_csv(StructField(a,YearMonthIntervalType(0,0),true), StructField(b,YearMonthIntervalType(1,1),true), to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), None) AS from_csv(to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH)))#x] +Project [from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles), false) AS from_json({"a":"1 days"})#x, from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None) AS from_csv(1, 1)#x, to_json(from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1 days"}))#x, to_csv(from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None), Some(America/Los_Angeles)) AS to_csv(from_csv(1, 1))#x, to_csv(named_struct(a, INTERVAL '32' 
YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)) AS to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH))#x, from_csv(StructField(a,YearMonthIntervalType(0,0),true), StructField(b,YearMonthIntervalType(1,1),true), to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), None) AS from_csv(to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH)))#x] +- OneRowRelation @@ -2119,7 +2119,7 @@ SELECT to_json(map('a', interval 100 day 130 minute)), from_json(to_json(map('a', interval 100 day 130 minute)), 'a interval day to minute') -- !query analysis -Project [from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE))#x, from_json(StructField(a,DayTimeIntervalType(0,2),true), to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS from_json(to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE)))#x] +Project [from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE))#x, from_json(StructField(a,DayTimeIntervalType(0,2),true), to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)), Some(America/Los_Angeles), false) AS from_json(to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE)))#x] +- OneRowRelation @@ -2130,7 +2130,7 @@ SELECT to_json(map('a', interval 32 year 10 month)), from_json(to_json(map('a', interval 32 year 10 month)), 'a interval year to month') -- !query analysis -Project [from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '32-10' YEAR TO MONTH))#x, from_json(StructField(a,YearMonthIntervalType(0,1),true), to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS from_json(to_json(map(a, INTERVAL '32-10' YEAR TO MONTH)))#x] +Project [from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '32-10' YEAR TO MONTH))#x, from_json(StructField(a,YearMonthIntervalType(0,1),true), to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), false) AS from_json(to_json(map(a, INTERVAL '32-10' YEAR TO MONTH)))#x] +- OneRowRelation diff 
--git a/sql/core/src/test/resources/sql-tests/analyzer-results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/json-functions.sql.out index 0d7c6b2056231..fef9d0c5b6250 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/json-functions.sql.out @@ -118,14 +118,14 @@ org.apache.spark.sql.AnalysisException -- !query select from_json('{"a":1}', 'a INT') -- !query analysis -Project [from_json(StructField(a,IntegerType,true), {"a":1}, Some(America/Los_Angeles)) AS from_json({"a":1})#x] +Project [from_json(StructField(a,IntegerType,true), {"a":1}, Some(America/Los_Angeles), false) AS from_json({"a":1})#x] +- OneRowRelation -- !query select from_json('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')) -- !query analysis -Project [from_json(StructField(time,TimestampType,true), (timestampFormat,dd/MM/yyyy), {"time":"26/08/2015"}, Some(America/Los_Angeles)) AS from_json({"time":"26/08/2015"})#x] +Project [from_json(StructField(time,TimestampType,true), (timestampFormat,dd/MM/yyyy), {"time":"26/08/2015"}, Some(America/Los_Angeles), false) AS from_json({"time":"26/08/2015"})#x] +- OneRowRelation @@ -279,14 +279,14 @@ DropTempViewCommand jsonTable -- !query select from_json('{"a":1, "b":2}', 'map') -- !query analysis -Project [from_json(MapType(StringType,IntegerType,true), {"a":1, "b":2}, Some(America/Los_Angeles)) AS entries#x] +Project [from_json(MapType(StringType,IntegerType,true), {"a":1, "b":2}, Some(America/Los_Angeles), false) AS entries#x] +- OneRowRelation -- !query select from_json('{"a":1, "b":"2"}', 'struct') -- !query analysis -Project [from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), {"a":1, "b":"2"}, Some(America/Los_Angeles)) AS from_json({"a":1, "b":"2"})#x] +Project [from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), {"a":1, "b":"2"}, Some(America/Los_Angeles), false) AS from_json({"a":1, "b":"2"})#x] +- OneRowRelation @@ -300,70 +300,70 @@ Project [schema_of_json({"c1":0, "c2":[1]}) AS schema_of_json({"c1":0, "c2":[1]} -- !query select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}')) -- !query analysis -Project [from_json(StructField(c1,ArrayType(LongType,true),true), {"c1":[1, 2, 3]}, Some(America/Los_Angeles)) AS from_json({"c1":[1, 2, 3]})#x] +Project [from_json(StructField(c1,ArrayType(LongType,true),true), {"c1":[1, 2, 3]}, Some(America/Los_Angeles), false) AS from_json({"c1":[1, 2, 3]})#x] +- OneRowRelation -- !query select from_json('[1, 2, 3]', 'array') -- !query analysis -Project [from_json(ArrayType(IntegerType,true), [1, 2, 3], Some(America/Los_Angeles)) AS from_json([1, 2, 3])#x] +Project [from_json(ArrayType(IntegerType,true), [1, 2, 3], Some(America/Los_Angeles), false) AS from_json([1, 2, 3])#x] +- OneRowRelation -- !query select from_json('[1, "2", 3]', 'array') -- !query analysis -Project [from_json(ArrayType(IntegerType,true), [1, "2", 3], Some(America/Los_Angeles)) AS from_json([1, "2", 3])#x] +Project [from_json(ArrayType(IntegerType,true), [1, "2", 3], Some(America/Los_Angeles), false) AS from_json([1, "2", 3])#x] +- OneRowRelation -- !query select from_json('[1, 2, null]', 'array') -- !query analysis -Project [from_json(ArrayType(IntegerType,true), [1, 2, null], Some(America/Los_Angeles)) AS from_json([1, 2, null])#x] +Project [from_json(ArrayType(IntegerType,true), [1, 2, null], Some(America/Los_Angeles), false) AS from_json([1, 2, null])#x] 
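The golden-file churn in these analyzer-results hunks is mechanical rather than a change at the SQL surface: every from_json call in the analyzed plans now prints one extra trailing boolean argument (the false at the end of each expression), and array_agg results are now labeled array_agg(...) instead of being displayed under the collect_list alias. The queries themselves are unchanged. As an illustrative sketch, one way to reproduce this kind of analyzed-plan output locally is EXPLAIN EXTENDED in spark-sql; both statements below reuse queries that appear verbatim in the hunks above.

    -- Illustrative sketch: both queries are taken from the test hunks above.
    -- EXPLAIN EXTENDED prints the analyzed logical plan (among other stages),
    -- which is the output these analyzer-results files record.
    EXPLAIN EXTENDED
    SELECT from_json('{"a":1}', 'a INT');

    EXPLAIN EXTENDED
    SELECT collect_list(col), array_agg(col)
    FROM VALUES (1), (2), (1) AS tab(col);
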
+- OneRowRelation -- !query select from_json('[{"a": 1}, {"a":2}]', 'array>') -- !query analysis -Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), [{"a": 1}, {"a":2}], Some(America/Los_Angeles)) AS from_json([{"a": 1}, {"a":2}])#x] +Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), [{"a": 1}, {"a":2}], Some(America/Los_Angeles), false) AS from_json([{"a": 1}, {"a":2}])#x] +- OneRowRelation -- !query select from_json('{"a": 1}', 'array>') -- !query analysis -Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), {"a": 1}, Some(America/Los_Angeles)) AS from_json({"a": 1})#x] +Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), {"a": 1}, Some(America/Los_Angeles), false) AS from_json({"a": 1})#x] +- OneRowRelation -- !query select from_json('[null, {"a":2}]', 'array>') -- !query analysis -Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), [null, {"a":2}], Some(America/Los_Angeles)) AS from_json([null, {"a":2}])#x] +Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), [null, {"a":2}], Some(America/Los_Angeles), false) AS from_json([null, {"a":2}])#x] +- OneRowRelation -- !query select from_json('[{"a": 1}, {"b":2}]', 'array>') -- !query analysis -Project [from_json(ArrayType(MapType(StringType,IntegerType,true),true), [{"a": 1}, {"b":2}], Some(America/Los_Angeles)) AS from_json([{"a": 1}, {"b":2}])#x] +Project [from_json(ArrayType(MapType(StringType,IntegerType,true),true), [{"a": 1}, {"b":2}], Some(America/Los_Angeles), false) AS from_json([{"a": 1}, {"b":2}])#x] +- OneRowRelation -- !query select from_json('[{"a": 1}, 2]', 'array>') -- !query analysis -Project [from_json(ArrayType(MapType(StringType,IntegerType,true),true), [{"a": 1}, 2], Some(America/Los_Angeles)) AS from_json([{"a": 1}, 2])#x] +Project [from_json(ArrayType(MapType(StringType,IntegerType,true),true), [{"a": 1}, 2], Some(America/Los_Angeles), false) AS from_json([{"a": 1}, 2])#x] +- OneRowRelation -- !query select from_json('{"d": "2012-12-15", "t": "2012-12-15 15:15:15"}', 'd date, t timestamp') -- !query analysis -Project [from_json(StructField(d,DateType,true), StructField(t,TimestampType,true), {"d": "2012-12-15", "t": "2012-12-15 15:15:15"}, Some(America/Los_Angeles)) AS from_json({"d": "2012-12-15", "t": "2012-12-15 15:15:15"})#x] +Project [from_json(StructField(d,DateType,true), StructField(t,TimestampType,true), {"d": "2012-12-15", "t": "2012-12-15 15:15:15"}, Some(America/Los_Angeles), false) AS from_json({"d": "2012-12-15", "t": "2012-12-15 15:15:15"})#x] +- OneRowRelation @@ -373,7 +373,7 @@ select from_json( 'd date, t timestamp', map('dateFormat', 'MM/dd yyyy', 'timestampFormat', 'MM/dd yyyy HH:mm:ss')) -- !query analysis -Project [from_json(StructField(d,DateType,true), StructField(t,TimestampType,true), (dateFormat,MM/dd yyyy), (timestampFormat,MM/dd yyyy HH:mm:ss), {"d": "12/15 2012", "t": "12/15 2012 15:15:15"}, Some(America/Los_Angeles)) AS from_json({"d": "12/15 2012", "t": "12/15 2012 15:15:15"})#x] +Project [from_json(StructField(d,DateType,true), StructField(t,TimestampType,true), (dateFormat,MM/dd yyyy), (timestampFormat,MM/dd yyyy HH:mm:ss), {"d": "12/15 2012", "t": "12/15 2012 15:15:15"}, Some(America/Los_Angeles), false) AS from_json({"d": "12/15 2012", "t": "12/15 2012 15:15:15"})#x] +- OneRowRelation @@ -383,7 +383,7 @@ select from_json( 'd date', map('dateFormat', 'MM-dd')) -- !query analysis -Project 
[from_json(StructField(d,DateType,true), (dateFormat,MM-dd), {"d": "02-29"}, Some(America/Los_Angeles)) AS from_json({"d": "02-29"})#x] +Project [from_json(StructField(d,DateType,true), (dateFormat,MM-dd), {"d": "02-29"}, Some(America/Los_Angeles), false) AS from_json({"d": "02-29"})#x] +- OneRowRelation @@ -393,7 +393,7 @@ select from_json( 't timestamp', map('timestampFormat', 'MM-dd')) -- !query analysis -Project [from_json(StructField(t,TimestampType,true), (timestampFormat,MM-dd), {"t": "02-29"}, Some(America/Los_Angeles)) AS from_json({"t": "02-29"})#x] +Project [from_json(StructField(t,TimestampType,true), (timestampFormat,MM-dd), {"t": "02-29"}, Some(America/Los_Angeles), false) AS from_json({"t": "02-29"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/null-handling.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/null-handling.sql.out index 26e9394932a17..37d84f6c5fc00 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/null-handling.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/null-handling.sql.out @@ -69,6 +69,24 @@ Project [a#x, (b#x + c#x) AS (b + c)#x] +- Relation spark_catalog.default.t1[a#x,b#x,c#x] parquet +-- !query +select b + 0 from t1 where a = 5 +-- !query analysis +Project [(b#x + 0) AS (b + 0)#x] ++- Filter (a#x = 5) + +- SubqueryAlias spark_catalog.default.t1 + +- Relation spark_catalog.default.t1[a#x,b#x,c#x] parquet + + +-- !query +select -100 + b + 100 from t1 where a = 5 +-- !query analysis +Project [((-100 + b#x) + 100) AS ((-100 + b) + 100)#x] ++- Filter (a#x = 5) + +- SubqueryAlias spark_catalog.default.t1 + +- Relation spark_catalog.default.t1[a#x,b#x,c#x] parquet + + -- !query select a+10, b*0 from t1 -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/parse-schema-string.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/parse-schema-string.sql.out index 45fc3bd03a782..ae8e47ed3665c 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/parse-schema-string.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/parse-schema-string.sql.out @@ -16,12 +16,12 @@ Project [from_csv(StructField(cube,IntegerType,true), 1, Some(America/Los_Angele -- !query select from_json('{"create":1}', 'create INT') -- !query analysis -Project [from_json(StructField(create,IntegerType,true), {"create":1}, Some(America/Los_Angeles)) AS from_json({"create":1})#x] +Project [from_json(StructField(create,IntegerType,true), {"create":1}, Some(America/Los_Angeles), false) AS from_json({"create":1})#x] +- OneRowRelation -- !query select from_json('{"cube":1}', 'cube INT') -- !query analysis -Project [from_json(StructField(cube,IntegerType,true), {"cube":1}, Some(America/Los_Angeles)) AS from_json({"cube":1})#x] +Project [from_json(StructField(cube,IntegerType,true), {"cube":1}, Some(America/Los_Angeles), false) AS from_json({"cube":1})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out index ab0635fef048b..7fa4ec0514ff0 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out @@ -62,6 +62,180 @@ InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_d +- LocalRelation [col1#x, col2#x] +-- !query +create temporary view courseSales as 
select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings) +-- !query analysis +CreateViewCommand `courseSales`, select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings), false, false, LocalTempView, UNSUPPORTED, true + +- Project [course#x, year#x, earnings#x] + +- SubqueryAlias courseSales + +- LocalRelation [course#x, year#x, earnings#x] + + +-- !query +create temporary view courseEarnings as select * from values + ("dotNET", 15000, 48000, 22500), + ("Java", 20000, 30000, NULL) + as courseEarnings(course, `2012`, `2013`, `2014`) +-- !query analysis +CreateViewCommand `courseEarnings`, select * from values + ("dotNET", 15000, 48000, 22500), + ("Java", 20000, 30000, NULL) + as courseEarnings(course, `2012`, `2013`, `2014`), false, false, LocalTempView, UNSUPPORTED, true + +- Project [course#x, 2012#x, 2013#x, 2014#x] + +- SubqueryAlias courseEarnings + +- LocalRelation [course#x, 2012#x, 2013#x, 2014#x] + + +-- !query +create temporary view courseEarningsAndSales as select * from values + ("dotNET", 15000, NULL, 48000, 1, 22500, 1), + ("Java", 20000, 1, 30000, 2, NULL, NULL) + as courseEarningsAndSales( + course, earnings2012, sales2012, earnings2013, sales2013, earnings2014, sales2014) +-- !query analysis +CreateViewCommand `courseEarningsAndSales`, select * from values + ("dotNET", 15000, NULL, 48000, 1, 22500, 1), + ("Java", 20000, 1, 30000, 2, NULL, NULL) + as courseEarningsAndSales( + course, earnings2012, sales2012, earnings2013, sales2013, earnings2014, sales2014), false, false, LocalTempView, UNSUPPORTED, true + +- Project [course#x, earnings2012#x, sales2012#x, earnings2013#x, sales2013#x, earnings2014#x, sales2014#x] + +- SubqueryAlias courseEarningsAndSales + +- LocalRelation [course#x, earnings2012#x, sales2012#x, earnings2013#x, sales2013#x, earnings2014#x, sales2014#x] + + +-- !query +create temporary view yearsWithComplexTypes as select * from values + (2012, array(1, 1), map('1', 1), struct(1, 'a')), + (2013, array(2, 2), map('2', 2), struct(2, 'b')) + as yearsWithComplexTypes(y, a, m, s) +-- !query analysis +CreateViewCommand `yearsWithComplexTypes`, select * from values + (2012, array(1, 1), map('1', 1), struct(1, 'a')), + (2013, array(2, 2), map('2', 2), struct(2, 'b')) + as yearsWithComplexTypes(y, a, m, s), false, false, LocalTempView, UNSUPPORTED, true + +- Project [y#x, a#x, m#x, s#x] + +- SubqueryAlias yearsWithComplexTypes + +- LocalRelation [y#x, a#x, m#x, s#x] + + +-- !query +create temporary view join_test_t1 as select * from values (1) as grouping(a) +-- !query analysis +CreateViewCommand `join_test_t1`, select * from values (1) as grouping(a), false, false, LocalTempView, UNSUPPORTED, true + +- Project [a#x] + +- SubqueryAlias grouping + +- LocalRelation [a#x] + + +-- !query +create temporary view join_test_t2 as select * from values (1) as grouping(a) +-- !query analysis +CreateViewCommand `join_test_t2`, select * from values (1) as grouping(a), false, false, LocalTempView, UNSUPPORTED, true + +- Project [a#x] + +- SubqueryAlias grouping + +- LocalRelation [a#x] + + +-- !query +create temporary view join_test_t3 as select * from values (1) as grouping(a) +-- !query analysis +CreateViewCommand `join_test_t3`, select * from values (1) as grouping(a), false, false, LocalTempView, UNSUPPORTED, true + 
+- Project [a#x] + +- SubqueryAlias grouping + +- LocalRelation [a#x] + + +-- !query +create temporary view join_test_empty_table as select a from join_test_t2 where false +-- !query analysis +CreateViewCommand `join_test_empty_table`, select a from join_test_t2 where false, false, false, LocalTempView, UNSUPPORTED, true + +- Project [a#x] + +- Filter false + +- SubqueryAlias join_test_t2 + +- View (`join_test_t2`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- SubqueryAlias grouping + +- LocalRelation [a#x] + + +-- !query +create temporary view lateral_test_t1(c1, c2) + as values (0, 1), (1, 2) +-- !query analysis +CreateViewCommand `lateral_test_t1`, [(c1,None), (c2,None)], values (0, 1), (1, 2), false, false, LocalTempView, UNSUPPORTED, true + +- LocalRelation [col1#x, col2#x] + + +-- !query +create temporary view lateral_test_t2(c1, c2) + as values (0, 2), (0, 3) +-- !query analysis +CreateViewCommand `lateral_test_t2`, [(c1,None), (c2,None)], values (0, 2), (0, 3), false, false, LocalTempView, UNSUPPORTED, true + +- LocalRelation [col1#x, col2#x] + + +-- !query +create temporary view lateral_test_t3(c1, c2) + as values (0, array(0, 1)), (1, array(2)), (2, array()), (null, array(4)) +-- !query analysis +CreateViewCommand `lateral_test_t3`, [(c1,None), (c2,None)], values (0, array(0, 1)), (1, array(2)), (2, array()), (null, array(4)), false, false, LocalTempView, UNSUPPORTED, true + +- LocalRelation [col1#x, col2#x] + + +-- !query +create temporary view lateral_test_t4(c1, c2) + as values (0, 1), (0, 2), (1, 1), (1, 3) +-- !query analysis +CreateViewCommand `lateral_test_t4`, [(c1,None), (c2,None)], values (0, 1), (0, 2), (1, 1), (1, 3), false, false, LocalTempView, UNSUPPORTED, true + +- LocalRelation [col1#x, col2#x] + + +-- !query +create temporary view natural_join_test_t1 as select * from values + ("one", 1), ("two", 2), ("three", 3) as natural_join_test_t1(k, v1) +-- !query analysis +CreateViewCommand `natural_join_test_t1`, select * from values + ("one", 1), ("two", 2), ("three", 3) as natural_join_test_t1(k, v1), false, false, LocalTempView, UNSUPPORTED, true + +- Project [k#x, v1#x] + +- SubqueryAlias natural_join_test_t1 + +- LocalRelation [k#x, v1#x] + + +-- !query +create temporary view natural_join_test_t2 as select * from values + ("one", 1), ("two", 22), ("one", 5) as natural_join_test_t2(k, v2) +-- !query analysis +CreateViewCommand `natural_join_test_t2`, select * from values + ("one", 1), ("two", 22), ("one", 5) as natural_join_test_t2(k, v2), false, false, LocalTempView, UNSUPPORTED, true + +- Project [k#x, v2#x] + +- SubqueryAlias natural_join_test_t2 + +- LocalRelation [k#x, v2#x] + + +-- !query +create temporary view natural_join_test_t3 as select * from values + ("one", 4), ("two", 5), ("one", 6) as natural_join_test_t3(k, v3) +-- !query analysis +CreateViewCommand `natural_join_test_t3`, select * from values + ("one", 4), ("two", 5), ("one", 6) as natural_join_test_t3(k, v3), false, false, LocalTempView, UNSUPPORTED, true + +- Project [k#x, v3#x] + +- SubqueryAlias natural_join_test_t3 + +- LocalRelation [k#x, v3#x] + + -- !query table t |> select 1 as x @@ -255,6 +429,55 @@ Distinct +- Relation spark_catalog.default.t[x#x,y#x] csv +-- !query +table t +|> select * +-- !query analysis +Project [x#x, y#x] ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select * except (y) +-- !query analysis +Project [x#x] ++- SubqueryAlias spark_catalog.default.t + +- Relation 
spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select /*+ repartition(3) */ * +-- !query analysis +Repartition 3, true ++- Project [x#x, y#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select /*+ repartition(3) */ distinct x +-- !query analysis +Repartition 3, true ++- Distinct + +- Project [x#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select /*+ repartition(3) */ all x +-- !query analysis +Repartition 3, true ++- Project [x#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + -- !query table t |> select sum(x) as result @@ -297,6 +520,1508 @@ org.apache.spark.sql.AnalysisException } +-- !query +table t +|> where true +-- !query analysis +Filter true ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where x + length(y) < 4 +-- !query analysis +Filter ((x#x + length(y#x)) < 4) ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where x + length(y) < 4 +|> where x + length(y) < 3 +-- !query analysis +Filter ((x#x + length(y#x)) < 3) ++- SubqueryAlias __auto_generated_subquery_name + +- Filter ((x#x + length(y#x)) < 4) + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +(select x, sum(length(y)) as sum_len from t group by x) +|> where x = 1 +-- !query analysis +Filter (x#x = 1) ++- SubqueryAlias __auto_generated_subquery_name + +- Aggregate [x#x], [x#x, sum(length(y#x)) AS sum_len#xL] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where t.x = 1 +-- !query analysis +Filter (x#x = 1) ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where spark_catalog.default.t.x = 1 +-- !query analysis +Filter (x#x = 1) ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +(select col from st) +|> where col.i1 = 1 +-- !query analysis +Filter (col#x.i1 = 1) ++- SubqueryAlias __auto_generated_subquery_name + +- Project [col#x] + +- SubqueryAlias spark_catalog.default.st + +- Relation spark_catalog.default.st[x#x,col#x] parquet + + +-- !query +table st +|> where st.col.i1 = 2 +-- !query analysis +Filter (col#x.i1 = 2) ++- SubqueryAlias spark_catalog.default.st + +- Relation spark_catalog.default.st[x#x,col#x] parquet + + +-- !query +table t +|> where exists (select a from other where x = a limit 1) +-- !query analysis +Filter exists#x [x#x] +: +- GlobalLimit 1 +: +- LocalLimit 1 +: +- Project [a#x] +: +- Filter (outer(x#x) = a#x) +: +- SubqueryAlias spark_catalog.default.other +: +- Relation spark_catalog.default.other[a#x,b#x] json ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where (select any_value(a) from other where x = a limit 1) = 1 +-- !query analysis +Filter (scalar-subquery#x [x#x] = 1) +: +- GlobalLimit 1 +: +- LocalLimit 1 +: +- Aggregate [any_value(a#x, false) AS any_value(a)#x] +: +- Filter (outer(x#x) = a#x) +: +- SubqueryAlias spark_catalog.default.other +: +- Relation spark_catalog.default.other[a#x,b#x] json ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- 
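The new pipe-operators cases above exercise the |> syntax end to end: a pipeline starts from TABLE t, a VALUES list, or a parenthesized query, and each |> step (select, where, and, further below, pivot, unpivot, tablesample, joins, and set operations) consumes the result of the previous step. The error cases also show that |> where behaves like a regular WHERE clause in rejecting aggregates and window functions. Below is a minimal sketch, assuming a Spark build where the pipe syntax is available; t_demo is a hypothetical stand-in for the suite's table t, which is created elsewhere in the test setup.

    -- Hypothetical stand-in for the suite's table `t` (two columns: int, string).
    CREATE OR REPLACE TEMPORARY VIEW t_demo(x, y) AS VALUES (0, 'abc'), (1, 'def');

    -- Each |> step operates on the output of the previous one; after the SELECT,
    -- only x and z remain in scope for later steps.
    TABLE t_demo
    |> WHERE x + LENGTH(y) < 4
    |> SELECT x, LENGTH(y) AS z
    |> WHERE z < 4;
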
!query +table t +|> where sum(x) = 1 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WHERE_CONDITION", + "sqlState" : "42903", + "messageParameters" : { + "condition" : "\"(sum(x) = 1)\"", + "expressionList" : "sum(spark_catalog.default.t.x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 27, + "fragment" : "table t\n|> where sum(x) = 1" + } ] +} + + +-- !query +table t +|> where y = 'abc' or length(y) + sum(x) = 1 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WHERE_CONDITION", + "sqlState" : "42903", + "messageParameters" : { + "condition" : "\"((y = abc) OR ((length(y) + sum(x)) = 1))\"", + "expressionList" : "sum(spark_catalog.default.t.x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 52, + "fragment" : "table t\n|> where y = 'abc' or length(y) + sum(x) = 1" + } ] +} + + +-- !query +table t +|> where first_value(x) over (partition by y) = 1 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1034", + "messageParameters" : { + "clauseName" : "WHERE" + } +} + + +-- !query +select * from t where first_value(x) over (partition by y) = 1 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1034", + "messageParameters" : { + "clauseName" : "WHERE" + } +} + + +-- !query +table t +|> select x, length(y) as z +|> where x + length(y) < 4 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`y`", + "proposal" : "`x`, `z`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 57, + "stopIndex" : 57, + "fragment" : "y" + } ] +} + + +-- !query +(select x, sum(length(y)) as sum_len from t group by x) +|> where sum(length(y)) = 3 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`y`", + "proposal" : "`x`, `sum_len`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 77, + "stopIndex" : 77, + "fragment" : "y" + } ] +} + + +-- !query +table courseSales +|> select `year`, course, earnings +|> pivot ( + sum(earnings) + for course in ('dotNET', 'Java') + ) +-- !query analysis +Project [year#x, __pivot_sum(coursesales.earnings) AS `sum(coursesales.earnings)`#x[0] AS dotNET#xL, __pivot_sum(coursesales.earnings) AS `sum(coursesales.earnings)`#x[1] AS Java#xL] ++- Aggregate [year#x], [year#x, pivotfirst(course#x, sum(coursesales.earnings)#xL, dotNET, Java, 0, 0) AS __pivot_sum(coursesales.earnings) AS `sum(coursesales.earnings)`#x] + +- Aggregate [year#x, course#x], [year#x, course#x, sum(earnings#x) AS sum(coursesales.earnings)#xL] + +- Project [year#x, course#x, earnings#x] + +- SubqueryAlias coursesales + +- View (`courseSales`, [course#x, year#x, earnings#x]) + +- Project [cast(course#x as string) AS course#x, cast(year#x as int) AS year#x, cast(earnings#x as int) AS earnings#x] + +- Project [course#x, year#x, earnings#x] + +- SubqueryAlias courseSales + +- LocalRelation [course#x, year#x, earnings#x] + + +-- !query +table courseSales +|> select `year` as y, course as c, earnings as e +|> pivot ( + sum(e) as s, avg(e) 
as a + for y in (2012 as firstYear, 2013 as secondYear) + ) +-- !query analysis +Project [c#x, __pivot_sum(e) AS s AS `sum(e) AS s`#x[0] AS firstYear_s#xL, __pivot_avg(e) AS a AS `avg(e) AS a`#x[0] AS firstYear_a#x, __pivot_sum(e) AS s AS `sum(e) AS s`#x[1] AS secondYear_s#xL, __pivot_avg(e) AS a AS `avg(e) AS a`#x[1] AS secondYear_a#x] ++- Aggregate [c#x], [c#x, pivotfirst(y#x, sum(e) AS s#xL, 2012, 2013, 0, 0) AS __pivot_sum(e) AS s AS `sum(e) AS s`#x, pivotfirst(y#x, avg(e) AS a#x, 2012, 2013, 0, 0) AS __pivot_avg(e) AS a AS `avg(e) AS a`#x] + +- Aggregate [c#x, y#x], [c#x, y#x, sum(e#x) AS sum(e) AS s#xL, avg(e#x) AS avg(e) AS a#x] + +- Project [pipeselect(year#x) AS y#x, pipeselect(course#x) AS c#x, pipeselect(earnings#x) AS e#x] + +- SubqueryAlias coursesales + +- View (`courseSales`, [course#x, year#x, earnings#x]) + +- Project [cast(course#x as string) AS course#x, cast(year#x as int) AS year#x, cast(earnings#x as int) AS earnings#x] + +- Project [course#x, year#x, earnings#x] + +- SubqueryAlias courseSales + +- LocalRelation [course#x, year#x, earnings#x] + + +-- !query +select course, `year`, y, a +from courseSales +join yearsWithComplexTypes on `year` = y +|> pivot ( + max(a) + for (y, course) in ((2012, 'dotNET'), (2013, 'Java')) + ) +-- !query analysis +Aggregate [year#x], [year#x, max(if ((named_struct(y, y#x, course, course#x) <=> cast(named_struct(col1, 2012, col2, dotNET) as struct))) a#x else cast(null as array)) AS {2012, dotNET}#x, max(if ((named_struct(y, y#x, course, course#x) <=> cast(named_struct(col1, 2013, col2, Java) as struct))) a#x else cast(null as array)) AS {2013, Java}#x] ++- Project [course#x, year#x, y#x, a#x] + +- Join Inner, (year#x = y#x) + :- SubqueryAlias coursesales + : +- View (`courseSales`, [course#x, year#x, earnings#x]) + : +- Project [cast(course#x as string) AS course#x, cast(year#x as int) AS year#x, cast(earnings#x as int) AS earnings#x] + : +- Project [course#x, year#x, earnings#x] + : +- SubqueryAlias courseSales + : +- LocalRelation [course#x, year#x, earnings#x] + +- SubqueryAlias yearswithcomplextypes + +- View (`yearsWithComplexTypes`, [y#x, a#x, m#x, s#x]) + +- Project [cast(y#x as int) AS y#x, cast(a#x as array) AS a#x, cast(m#x as map) AS m#x, cast(s#x as struct) AS s#x] + +- Project [y#x, a#x, m#x, s#x] + +- SubqueryAlias yearsWithComplexTypes + +- LocalRelation [y#x, a#x, m#x, s#x] + + +-- !query +select earnings, `year`, s +from courseSales +join yearsWithComplexTypes on `year` = y +|> pivot ( + sum(earnings) + for s in ((1, 'a'), (2, 'b')) + ) +-- !query analysis +Project [year#x, __pivot_sum(coursesales.earnings) AS `sum(coursesales.earnings)`#x[0] AS {1, a}#xL, __pivot_sum(coursesales.earnings) AS `sum(coursesales.earnings)`#x[1] AS {2, b}#xL] ++- Aggregate [year#x], [year#x, pivotfirst(s#x, sum(coursesales.earnings)#xL, [1,a], [2,b], 0, 0) AS __pivot_sum(coursesales.earnings) AS `sum(coursesales.earnings)`#x] + +- Aggregate [year#x, s#x], [year#x, s#x, sum(earnings#x) AS sum(coursesales.earnings)#xL] + +- Project [earnings#x, year#x, s#x] + +- Join Inner, (year#x = y#x) + :- SubqueryAlias coursesales + : +- View (`courseSales`, [course#x, year#x, earnings#x]) + : +- Project [cast(course#x as string) AS course#x, cast(year#x as int) AS year#x, cast(earnings#x as int) AS earnings#x] + : +- Project [course#x, year#x, earnings#x] + : +- SubqueryAlias courseSales + : +- LocalRelation [course#x, year#x, earnings#x] + +- SubqueryAlias yearswithcomplextypes + +- View (`yearsWithComplexTypes`, [y#x, a#x, m#x, s#x]) + +- Project 
[cast(y#x as int) AS y#x, cast(a#x as array) AS a#x, cast(m#x as map) AS m#x, cast(s#x as struct) AS s#x] + +- Project [y#x, a#x, m#x, s#x] + +- SubqueryAlias yearsWithComplexTypes + +- LocalRelation [y#x, a#x, m#x, s#x] + + +-- !query +table courseEarnings +|> unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) +-- !query analysis +Filter isnotnull(coalesce(earningsYear#x)) ++- Expand [[course#x, 2012, 2012#x], [course#x, 2013, 2013#x], [course#x, 2014, 2014#x]], [course#x, year#x, earningsYear#x] + +- SubqueryAlias courseearnings + +- View (`courseEarnings`, [course#x, 2012#x, 2013#x, 2014#x]) + +- Project [cast(course#x as string) AS course#x, cast(2012#x as int) AS 2012#x, cast(2013#x as int) AS 2013#x, cast(2014#x as int) AS 2014#x] + +- Project [course#x, 2012#x, 2013#x, 2014#x] + +- SubqueryAlias courseEarnings + +- LocalRelation [course#x, 2012#x, 2013#x, 2014#x] + + +-- !query +table courseEarnings +|> unpivot include nulls ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) +-- !query analysis +Expand [[course#x, 2012, 2012#x], [course#x, 2013, 2013#x], [course#x, 2014, 2014#x]], [course#x, year#x, earningsYear#x] ++- SubqueryAlias courseearnings + +- View (`courseEarnings`, [course#x, 2012#x, 2013#x, 2014#x]) + +- Project [cast(course#x as string) AS course#x, cast(2012#x as int) AS 2012#x, cast(2013#x as int) AS 2013#x, cast(2014#x as int) AS 2014#x] + +- Project [course#x, 2012#x, 2013#x, 2014#x] + +- SubqueryAlias courseEarnings + +- LocalRelation [course#x, 2012#x, 2013#x, 2014#x] + + +-- !query +table courseEarningsAndSales +|> unpivot include nulls ( + (earnings, sales) for `year` in ( + (earnings2012, sales2012) as `2012`, + (earnings2013, sales2013) as `2013`, + (earnings2014, sales2014) as `2014`) + ) +-- !query analysis +Expand [[course#x, 2012, earnings2012#x, sales2012#x], [course#x, 2013, earnings2013#x, sales2013#x], [course#x, 2014, earnings2014#x, sales2014#x]], [course#x, year#x, earnings#x, sales#x] ++- SubqueryAlias courseearningsandsales + +- View (`courseEarningsAndSales`, [course#x, earnings2012#x, sales2012#x, earnings2013#x, sales2013#x, earnings2014#x, sales2014#x]) + +- Project [cast(course#x as string) AS course#x, cast(earnings2012#x as int) AS earnings2012#x, cast(sales2012#x as int) AS sales2012#x, cast(earnings2013#x as int) AS earnings2013#x, cast(sales2013#x as int) AS sales2013#x, cast(earnings2014#x as int) AS earnings2014#x, cast(sales2014#x as int) AS sales2014#x] + +- Project [course#x, earnings2012#x, sales2012#x, earnings2013#x, sales2013#x, earnings2014#x, sales2014#x] + +- SubqueryAlias courseEarningsAndSales + +- LocalRelation [course#x, earnings2012#x, sales2012#x, earnings2013#x, sales2013#x, earnings2014#x, sales2014#x] + + +-- !query +table courseSales +|> select course, earnings +|> pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`year`", + "proposal" : "`course`, `earnings`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 49, + "stopIndex" : 111, + "fragment" : "pivot (\n sum(earnings)\n for `year` in (2012, 2013)\n )" + } ] +} + + +-- !query +table courseSales +|> pivot ( + sum(earnings) + for `year` in (course, 2013) + ) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "NON_LITERAL_PIVOT_VALUES", + "sqlState" : "42K08", + 
"messageParameters" : { + "expression" : "\"course\"" + } +} + + +-- !query +table courseSales +|> select course, earnings +|> pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) + unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "NOT_ALLOWED_IN_FROM.UNPIVOT_WITH_PIVOT", + "sqlState" : "42601", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 186, + "fragment" : "table courseSales\n|> select course, earnings\n|> pivot (\n sum(earnings)\n for `year` in (2012, 2013)\n )\n unpivot (\n earningsYear for `year` in (`2012`, `2013`, `2014`)\n )" + } ] +} + + +-- !query +table courseSales +|> select course, earnings +|> unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) + pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "NOT_ALLOWED_IN_FROM.UNPIVOT_WITH_PIVOT", + "sqlState" : "42601", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 186, + "fragment" : "table courseSales\n|> select course, earnings\n|> unpivot (\n earningsYear for `year` in (`2012`, `2013`, `2014`)\n )\n pivot (\n sum(earnings)\n for `year` in (2012, 2013)\n )" + } ] +} + + +-- !query +table courseSales +|> select course, earnings +|> pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) + pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'pivot'", + "hint" : "" + } +} + + +-- !query +table courseSales +|> select course, earnings +|> unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) + unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) + pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'unpivot'", + "hint" : "" + } +} + + +-- !query +table t +|> tablesample (100 percent) repeatable (0) +-- !query analysis +Sample 0.0, 1.0, false, 0 ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> tablesample (2 rows) repeatable (0) +-- !query analysis +GlobalLimit 2 ++- LocalLimit 2 + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> tablesample (bucket 1 out of 1) repeatable (0) +-- !query analysis +Sample 0.0, 1.0, false, 0 ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> tablesample (100 percent) repeatable (0) +|> tablesample (5 rows) repeatable (0) +|> tablesample (bucket 1 out of 1) repeatable (0) +-- !query analysis +Sample 0.0, 1.0, false, 0 ++- GlobalLimit 5 + +- LocalLimit 5 + +- Sample 0.0, 1.0, false, 0 + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> tablesample () +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0014", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 25, + "fragment" : "tablesample ()" + } 
] +} + + +-- !query +table t +|> tablesample (-100 percent) repeatable (0) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0064", + "messageParameters" : { + "msg" : "Sampling fraction (-1.0) must be on interval [0, 1]" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 52, + "fragment" : "tablesample (-100 percent) repeatable (0)" + } ] +} + + +-- !query +table t +|> tablesample (-5 rows) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_LIMIT_LIKE_EXPRESSION.IS_NEGATIVE", + "sqlState" : "42K0E", + "messageParameters" : { + "expr" : "\"-5\"", + "name" : "limit", + "v" : "-5" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 25, + "stopIndex" : 26, + "fragment" : "-5" + } ] +} + + +-- !query +table t +|> tablesample (x rows) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_LIMIT_LIKE_EXPRESSION.IS_UNFOLDABLE", + "sqlState" : "42K0E", + "messageParameters" : { + "expr" : "\"x\"", + "name" : "limit" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 25, + "stopIndex" : 25, + "fragment" : "x" + } ] +} + + +-- !query +table t +|> tablesample (bucket 2 out of 1) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0064", + "messageParameters" : { + "msg" : "Sampling fraction (2.0) must be on interval [0, 1]" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 42, + "fragment" : "tablesample (bucket 2 out of 1)" + } ] +} + + +-- !query +table t +|> tablesample (200b) repeatable (0) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0015", + "messageParameters" : { + "msg" : "byteLengthLiteral" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 44, + "fragment" : "tablesample (200b) repeatable (0)" + } ] +} + + +-- !query +table t +|> tablesample (200) repeatable (0) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0016", + "messageParameters" : { + "bytesStr" : "200" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 43, + "fragment" : "tablesample (200) repeatable (0)" + } ] +} + + +-- !query +table join_test_t1 +|> inner join join_test_empty_table +-- !query analysis +Join Inner +:- SubqueryAlias join_test_t1 +: +- View (`join_test_t1`, [a#x]) +: +- Project [cast(a#x as int) AS a#x] +: +- Project [a#x] +: +- SubqueryAlias grouping +: +- LocalRelation [a#x] ++- SubqueryAlias join_test_empty_table + +- View (`join_test_empty_table`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- Filter false + +- SubqueryAlias join_test_t2 + +- View (`join_test_t2`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- SubqueryAlias grouping + +- LocalRelation [a#x] + + +-- !query +table join_test_t1 +|> cross join join_test_empty_table +-- !query analysis +Join Cross +:- SubqueryAlias join_test_t1 +: +- View (`join_test_t1`, [a#x]) +: +- Project [cast(a#x as int) AS a#x] +: +- Project [a#x] +: +- SubqueryAlias grouping +: +- LocalRelation [a#x] ++- SubqueryAlias join_test_empty_table + +- View (`join_test_empty_table`, [a#x]) 
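The join steps in this part of the file accept the same join types as a regular FROM clause (inner, cross, left/right/full outer, left semi, left anti, and, further below, lateral and natural joins), with the pipeline result acting as the left input; when the pipeline starts directly from a table, its columns remain addressable by the table name in the join condition. A minimal sketch with hypothetical demo views in place of the suite's join_test_* views:

    -- Hypothetical demo relations standing in for the suite's join_test_* views.
    CREATE OR REPLACE TEMPORARY VIEW left_demo(a) AS VALUES (1), (2);
    CREATE OR REPLACE TEMPORARY VIEW right_demo(a) AS VALUES (1);

    -- The pipeline result is the left input of the join step.
    TABLE left_demo
    |> INNER JOIN right_demo USING (a)
    |> SELECT a;

    TABLE left_demo
    |> LEFT OUTER JOIN right_demo ON (left_demo.a = right_demo.a);
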
+ +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- Filter false + +- SubqueryAlias join_test_t2 + +- View (`join_test_t2`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- SubqueryAlias grouping + +- LocalRelation [a#x] + + +-- !query +table join_test_t1 +|> left outer join join_test_empty_table +-- !query analysis +Join LeftOuter +:- SubqueryAlias join_test_t1 +: +- View (`join_test_t1`, [a#x]) +: +- Project [cast(a#x as int) AS a#x] +: +- Project [a#x] +: +- SubqueryAlias grouping +: +- LocalRelation [a#x] ++- SubqueryAlias join_test_empty_table + +- View (`join_test_empty_table`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- Filter false + +- SubqueryAlias join_test_t2 + +- View (`join_test_t2`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- SubqueryAlias grouping + +- LocalRelation [a#x] + + +-- !query +table join_test_t1 +|> right outer join join_test_empty_table +-- !query analysis +Join RightOuter +:- SubqueryAlias join_test_t1 +: +- View (`join_test_t1`, [a#x]) +: +- Project [cast(a#x as int) AS a#x] +: +- Project [a#x] +: +- SubqueryAlias grouping +: +- LocalRelation [a#x] ++- SubqueryAlias join_test_empty_table + +- View (`join_test_empty_table`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- Filter false + +- SubqueryAlias join_test_t2 + +- View (`join_test_t2`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- SubqueryAlias grouping + +- LocalRelation [a#x] + + +-- !query +table join_test_t1 +|> full outer join join_test_empty_table using (a) +-- !query analysis +Project [coalesce(a#x, a#x) AS a#x] ++- Join FullOuter, (a#x = a#x) + :- SubqueryAlias join_test_t1 + : +- View (`join_test_t1`, [a#x]) + : +- Project [cast(a#x as int) AS a#x] + : +- Project [a#x] + : +- SubqueryAlias grouping + : +- LocalRelation [a#x] + +- SubqueryAlias join_test_empty_table + +- View (`join_test_empty_table`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- Filter false + +- SubqueryAlias join_test_t2 + +- View (`join_test_t2`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- SubqueryAlias grouping + +- LocalRelation [a#x] + + +-- !query +table join_test_t1 +|> full outer join join_test_empty_table on (join_test_t1.a = join_test_empty_table.a) +-- !query analysis +Join FullOuter, (a#x = a#x) +:- SubqueryAlias join_test_t1 +: +- View (`join_test_t1`, [a#x]) +: +- Project [cast(a#x as int) AS a#x] +: +- Project [a#x] +: +- SubqueryAlias grouping +: +- LocalRelation [a#x] ++- SubqueryAlias join_test_empty_table + +- View (`join_test_empty_table`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- Filter false + +- SubqueryAlias join_test_t2 + +- View (`join_test_t2`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- SubqueryAlias grouping + +- LocalRelation [a#x] + + +-- !query +table join_test_t1 +|> left semi join join_test_empty_table +-- !query analysis +Join LeftSemi +:- SubqueryAlias join_test_t1 +: +- View (`join_test_t1`, [a#x]) +: +- Project [cast(a#x as int) AS a#x] +: +- Project [a#x] +: +- SubqueryAlias grouping +: +- LocalRelation [a#x] ++- SubqueryAlias join_test_empty_table + +- View (`join_test_empty_table`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- Filter false + +- SubqueryAlias join_test_t2 + +- View (`join_test_t2`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- SubqueryAlias grouping + +- LocalRelation [a#x] + + +-- !query +table 
join_test_t1 +|> left anti join join_test_empty_table +-- !query analysis +Join LeftAnti +:- SubqueryAlias join_test_t1 +: +- View (`join_test_t1`, [a#x]) +: +- Project [cast(a#x as int) AS a#x] +: +- Project [a#x] +: +- SubqueryAlias grouping +: +- LocalRelation [a#x] ++- SubqueryAlias join_test_empty_table + +- View (`join_test_empty_table`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- Filter false + +- SubqueryAlias join_test_t2 + +- View (`join_test_t2`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- SubqueryAlias grouping + +- LocalRelation [a#x] + + +-- !query +select * from join_test_t1 where true +|> inner join join_test_empty_table +-- !query analysis +Join Inner +:- Project [a#x] +: +- Filter true +: +- SubqueryAlias join_test_t1 +: +- View (`join_test_t1`, [a#x]) +: +- Project [cast(a#x as int) AS a#x] +: +- Project [a#x] +: +- SubqueryAlias grouping +: +- LocalRelation [a#x] ++- SubqueryAlias join_test_empty_table + +- View (`join_test_empty_table`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- Filter false + +- SubqueryAlias join_test_t2 + +- View (`join_test_t2`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- SubqueryAlias grouping + +- LocalRelation [a#x] + + +-- !query +select 1 as x, 2 as y +|> inner join (select 1 as x, 4 as y) using (x) +-- !query analysis +Project [x#x, y#x, y#x] ++- Join Inner, (x#x = x#x) + :- Project [1 AS x#x, 2 AS y#x] + : +- OneRowRelation + +- SubqueryAlias __auto_generated_subquery_name + +- Project [1 AS x#x, 4 AS y#x] + +- OneRowRelation + + +-- !query +table join_test_t1 +|> inner join (join_test_t2 jt2 inner join join_test_t3 jt3 using (a)) using (a) +|> select a, join_test_t1.a, jt2.a, jt3.a +-- !query analysis +Project [a#x, a#x, a#x, a#x] ++- Project [a#x, a#x, a#x] + +- Join Inner, (a#x = a#x) + :- SubqueryAlias join_test_t1 + : +- View (`join_test_t1`, [a#x]) + : +- Project [cast(a#x as int) AS a#x] + : +- Project [a#x] + : +- SubqueryAlias grouping + : +- LocalRelation [a#x] + +- Project [a#x, a#x] + +- Join Inner, (a#x = a#x) + :- SubqueryAlias jt2 + : +- SubqueryAlias join_test_t2 + : +- View (`join_test_t2`, [a#x]) + : +- Project [cast(a#x as int) AS a#x] + : +- Project [a#x] + : +- SubqueryAlias grouping + : +- LocalRelation [a#x] + +- SubqueryAlias jt3 + +- SubqueryAlias join_test_t3 + +- View (`join_test_t3`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- SubqueryAlias grouping + +- LocalRelation [a#x] + + +-- !query +table join_test_t1 +|> inner join join_test_t2 tablesample (100 percent) repeatable (0) jt2 using (a) +-- !query analysis +Project [a#x] ++- Join Inner, (a#x = a#x) + :- SubqueryAlias join_test_t1 + : +- View (`join_test_t1`, [a#x]) + : +- Project [cast(a#x as int) AS a#x] + : +- Project [a#x] + : +- SubqueryAlias grouping + : +- LocalRelation [a#x] + +- Sample 0.0, 1.0, false, 0 + +- SubqueryAlias jt2 + +- SubqueryAlias join_test_t2 + +- View (`join_test_t2`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- SubqueryAlias grouping + +- LocalRelation [a#x] + + +-- !query +table join_test_t1 +|> inner join (select 1 as a) tablesample (100 percent) repeatable (0) jt2 using (a) +-- !query analysis +Project [a#x] ++- Join Inner, (a#x = a#x) + :- SubqueryAlias join_test_t1 + : +- View (`join_test_t1`, [a#x]) + : +- Project [cast(a#x as int) AS a#x] + : +- Project [a#x] + : +- SubqueryAlias grouping + : +- LocalRelation [a#x] + +- SubqueryAlias jt2 + +- Sample 0.0, 1.0, false, 0 + +- 
Project [1 AS a#x] + +- OneRowRelation + + +-- !query +table join_test_t1 +|> join join_test_t1 using (a) +-- !query analysis +Project [a#x] ++- Join Inner, (a#x = a#x) + :- SubqueryAlias join_test_t1 + : +- View (`join_test_t1`, [a#x]) + : +- Project [cast(a#x as int) AS a#x] + : +- Project [a#x] + : +- SubqueryAlias grouping + : +- LocalRelation [a#x] + +- SubqueryAlias join_test_t1 + +- View (`join_test_t1`, [a#x]) + +- Project [cast(a#x as int) AS a#x] + +- Project [a#x] + +- SubqueryAlias grouping + +- LocalRelation [a#x] + + +-- !query +table lateral_test_t1 +|> join lateral (select c1) +-- !query analysis +LateralJoin lateral-subquery#x [c1#x], Inner +: +- SubqueryAlias __auto_generated_subquery_name +: +- Project [outer(c1#x) AS c1#x] +: +- OneRowRelation ++- SubqueryAlias lateral_test_t1 + +- View (`lateral_test_t1`, [c1#x, c2#x]) + +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +table lateral_test_t1 +|> join lateral (select c1 from lateral_test_t2) +-- !query analysis +LateralJoin lateral-subquery#x [], Inner +: +- SubqueryAlias __auto_generated_subquery_name +: +- Project [c1#x] +: +- SubqueryAlias lateral_test_t2 +: +- View (`lateral_test_t2`, [c1#x, c2#x]) +: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias lateral_test_t1 + +- View (`lateral_test_t1`, [c1#x, c2#x]) + +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +table lateral_test_t1 +|> join lateral (select lateral_test_t1.c1 from lateral_test_t2) +-- !query analysis +LateralJoin lateral-subquery#x [c1#x], Inner +: +- SubqueryAlias __auto_generated_subquery_name +: +- Project [outer(c1#x) AS c1#x] +: +- SubqueryAlias lateral_test_t2 +: +- View (`lateral_test_t2`, [c1#x, c2#x]) +: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias lateral_test_t1 + +- View (`lateral_test_t1`, [c1#x, c2#x]) + +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +table lateral_test_t1 +|> join lateral (select lateral_test_t1.c1 + t2.c1 from lateral_test_t2 t2) +-- !query analysis +LateralJoin lateral-subquery#x [c1#x], Inner +: +- SubqueryAlias __auto_generated_subquery_name +: +- Project [(outer(c1#x) + c1#x) AS (outer(lateral_test_t1.c1) + c1)#x] +: +- SubqueryAlias t2 +: +- SubqueryAlias lateral_test_t2 +: +- View (`lateral_test_t2`, [c1#x, c2#x]) +: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias lateral_test_t1 + +- View (`lateral_test_t1`, [c1#x, c2#x]) + +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +table lateral_test_t1 +|> join lateral (select *) +-- !query analysis +LateralJoin lateral-subquery#x [], Inner +: +- SubqueryAlias __auto_generated_subquery_name +: +- Project +: +- OneRowRelation ++- SubqueryAlias lateral_test_t1 + +- View (`lateral_test_t1`, [c1#x, c2#x]) + +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +table lateral_test_t1 +|> join lateral (select * from lateral_test_t2) +-- !query analysis +LateralJoin lateral-subquery#x [], Inner +: +- SubqueryAlias __auto_generated_subquery_name +: +- Project [c1#x, c2#x] +: +- SubqueryAlias 
lateral_test_t2 +: +- View (`lateral_test_t2`, [c1#x, c2#x]) +: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias lateral_test_t1 + +- View (`lateral_test_t1`, [c1#x, c2#x]) + +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +table lateral_test_t1 +|> join lateral (select lateral_test_t1.* from lateral_test_t2) +-- !query analysis +LateralJoin lateral-subquery#x [c1#x && c2#x], Inner +: +- SubqueryAlias __auto_generated_subquery_name +: +- Project [outer(c1#x) AS c1#x, outer(c2#x) AS c2#x] +: +- SubqueryAlias lateral_test_t2 +: +- View (`lateral_test_t2`, [c1#x, c2#x]) +: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias lateral_test_t1 + +- View (`lateral_test_t1`, [c1#x, c2#x]) + +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +table lateral_test_t1 +|> join lateral (select lateral_test_t1.*, t2.* from lateral_test_t2 t2) +-- !query analysis +LateralJoin lateral-subquery#x [c1#x && c2#x], Inner +: +- SubqueryAlias __auto_generated_subquery_name +: +- Project [outer(c1#x) AS c1#x, outer(c2#x) AS c2#x, c1#x, c2#x] +: +- SubqueryAlias t2 +: +- SubqueryAlias lateral_test_t2 +: +- View (`lateral_test_t2`, [c1#x, c2#x]) +: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias lateral_test_t1 + +- View (`lateral_test_t1`, [c1#x, c2#x]) + +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +table lateral_test_t1 +|> join lateral_test_t2 +|> join lateral (select lateral_test_t1.c2 + lateral_test_t2.c2) +-- !query analysis +LateralJoin lateral-subquery#x [c2#x && c2#x], Inner +: +- SubqueryAlias __auto_generated_subquery_name +: +- Project [(outer(c2#x) + outer(c2#x)) AS (outer(lateral_test_t1.c2) + outer(lateral_test_t2.c2))#x] +: +- OneRowRelation ++- Join Inner + :- SubqueryAlias lateral_test_t1 + : +- View (`lateral_test_t1`, [c1#x, c2#x]) + : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias lateral_test_t2 + +- View (`lateral_test_t2`, [c1#x, c2#x]) + +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +table natural_join_test_t1 +|> natural join natural_join_test_t2 +|> where k = "one" +-- !query analysis +Filter (k#x = one) ++- SubqueryAlias __auto_generated_subquery_name + +- Project [k#x, v1#x, v2#x] + +- Join Inner, (k#x = k#x) + :- SubqueryAlias natural_join_test_t1 + : +- View (`natural_join_test_t1`, [k#x, v1#x]) + : +- Project [cast(k#x as string) AS k#x, cast(v1#x as int) AS v1#x] + : +- Project [k#x, v1#x] + : +- SubqueryAlias natural_join_test_t1 + : +- LocalRelation [k#x, v1#x] + +- SubqueryAlias natural_join_test_t2 + +- View (`natural_join_test_t2`, [k#x, v2#x]) + +- Project [cast(k#x as string) AS k#x, cast(v2#x as int) AS v2#x] + +- Project [k#x, v2#x] + +- SubqueryAlias natural_join_test_t2 + +- LocalRelation [k#x, v2#x] + + +-- !query +table natural_join_test_t1 +|> natural join natural_join_test_t2 nt2 +|> select natural_join_test_t1.* +-- !query analysis +Project [k#x, v1#x] ++- Project [k#x, v1#x, v2#x] + +- Join Inner, (k#x = k#x) + :- SubqueryAlias natural_join_test_t1 + : +- View 
(`natural_join_test_t1`, [k#x, v1#x]) + : +- Project [cast(k#x as string) AS k#x, cast(v1#x as int) AS v1#x] + : +- Project [k#x, v1#x] + : +- SubqueryAlias natural_join_test_t1 + : +- LocalRelation [k#x, v1#x] + +- SubqueryAlias nt2 + +- SubqueryAlias natural_join_test_t2 + +- View (`natural_join_test_t2`, [k#x, v2#x]) + +- Project [cast(k#x as string) AS k#x, cast(v2#x as int) AS v2#x] + +- Project [k#x, v2#x] + +- SubqueryAlias natural_join_test_t2 + +- LocalRelation [k#x, v2#x] + + +-- !query +table natural_join_test_t1 +|> natural join natural_join_test_t2 nt2 +|> natural join natural_join_test_t3 nt3 +|> select natural_join_test_t1.*, nt2.*, nt3.* +-- !query analysis +Project [k#x, v1#x, k#x, v2#x, k#x, v3#x] ++- Project [k#x, v1#x, v2#x, v3#x, k#x, k#x] + +- Join Inner, (k#x = k#x) + :- Project [k#x, v1#x, v2#x, k#x] + : +- Join Inner, (k#x = k#x) + : :- SubqueryAlias natural_join_test_t1 + : : +- View (`natural_join_test_t1`, [k#x, v1#x]) + : : +- Project [cast(k#x as string) AS k#x, cast(v1#x as int) AS v1#x] + : : +- Project [k#x, v1#x] + : : +- SubqueryAlias natural_join_test_t1 + : : +- LocalRelation [k#x, v1#x] + : +- SubqueryAlias nt2 + : +- SubqueryAlias natural_join_test_t2 + : +- View (`natural_join_test_t2`, [k#x, v2#x]) + : +- Project [cast(k#x as string) AS k#x, cast(v2#x as int) AS v2#x] + : +- Project [k#x, v2#x] + : +- SubqueryAlias natural_join_test_t2 + : +- LocalRelation [k#x, v2#x] + +- SubqueryAlias nt3 + +- SubqueryAlias natural_join_test_t3 + +- View (`natural_join_test_t3`, [k#x, v3#x]) + +- Project [cast(k#x as string) AS k#x, cast(v3#x as int) AS v3#x] + +- Project [k#x, v3#x] + +- SubqueryAlias natural_join_test_t3 + +- LocalRelation [k#x, v3#x] + + +-- !query +table join_test_t1 +|> inner join join_test_empty_table + inner join join_test_empty_table +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'inner'", + "hint" : "" + } +} + + +-- !query +table join_test_t1 +|> select 1 + 2 as result +|> full outer join join_test_empty_table on (join_test_t1.a = join_test_empty_table.a) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`join_test_t1`.`a`", + "proposal" : "`result`, `join_test_empty_table`.`a`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 91, + "stopIndex" : 104, + "fragment" : "join_test_t1.a" + } ] +} + + +-- !query +table join_test_t1 jt +|> cross join (select * from jt) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'jt'", + "hint" : "" + } +} + + +-- !query +table t +|> union all table t +-- !query analysis +Union false, false +:- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> union table t +-- !query analysis +Distinct ++- Union false, false + :- SubqueryAlias spark_catalog.default.t + : +- Relation spark_catalog.default.t[x#x,y#x] csv + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +(select * from t) +|> union all table t +-- !query analysis +Union false, false +:- Project [x#x, 
y#x] +: +- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +(select * from t) +|> union table t +-- !query analysis +Distinct ++- Union false, false + :- Project [x#x, y#x] + : +- SubqueryAlias spark_catalog.default.t + : +- Relation spark_catalog.default.t[x#x,y#x] csv + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +values (0, 'abc') tab(x, y) +|> union all table t +-- !query analysis +Union false, false +:- SubqueryAlias tab +: +- LocalRelation [x#x, y#x] ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +values (0, 1) tab(x, y) +|> union table t +-- !query analysis +Distinct ++- Union false, false + :- Project [x#x, cast(y#x as string) AS y#x] + : +- SubqueryAlias tab + : +- LocalRelation [x#x, y#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +(select * from t) +|> union all (select * from t) +-- !query analysis +Union false, false +:- Project [x#x, y#x] +: +- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- Project [x#x, y#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> except all table t +-- !query analysis +Except All true +:- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> except table t +-- !query analysis +Except false +:- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> intersect all table t +-- !query analysis +Intersect All true +:- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> intersect table t +-- !query analysis +Intersect false +:- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> minus all table t +-- !query analysis +Except All true +:- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> minus table t +-- !query analysis +Except false +:- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select x +|> union all table t +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "NUM_COLUMNS_MISMATCH", + "sqlState" : "42826", + "messageParameters" : { + "firstNumColumns" : "1", + "invalidNumColumns" : "2", + "invalidOrdinalNum" : "second", + "operator" : "UNION" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 40, + "fragment" : "table t\n|> select x\n|> union all table t" + } ] +} 
+ + +-- !query +table t +|> union all table st +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INCOMPATIBLE_COLUMN_TYPE", + "sqlState" : "42825", + "messageParameters" : { + "columnOrdinalNumber" : "second", + "dataType1" : "\"STRUCT\"", + "dataType2" : "\"STRING\"", + "hint" : "", + "operator" : "UNION", + "tableOrdinalNumber" : "second" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 29, + "fragment" : "table t\n|> union all table st" + } ] +} + + -- !query drop table t -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out index a4e40f08b4463..8c10d78405751 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out @@ -776,7 +776,7 @@ Project [NULL AS Expected#x, variablereference(system.session.var1=CAST(NULL AS -- !query DECLARE OR REPLACE VARIABLE var1 STRING DEFAULT CURRENT_DATABASE() -- !query analysis -CreateVariable defaultvalueexpression(cast(current_schema() as string), CURRENT_DATABASE()), true +CreateVariable defaultvalueexpression(cast(current_database() as string), CURRENT_DATABASE()), true +- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 @@ -2147,7 +2147,7 @@ CreateVariable defaultvalueexpression(cast(a INT as string), 'a INT'), true -- !query SELECT from_json('{"a": 1}', var1) -- !query analysis -Project [from_json(StructField(a,IntegerType,true), {"a": 1}, Some(America/Los_Angeles)) AS from_json({"a": 1})#x] +Project [from_json(StructField(a,IntegerType,true), {"a": 1}, Some(America/Los_Angeles), false) AS from_json({"a": 1})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subexp-elimination.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subexp-elimination.sql.out index 94073f2751b3e..754b05bfa6fed 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subexp-elimination.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subexp-elimination.sql.out @@ -15,7 +15,7 @@ AS testData(a, b), false, true, LocalTempView, UNSUPPORTED, true -- !query SELECT from_json(a, 'struct').a, from_json(a, 'struct').b, from_json(b, 'array>')[0].a, from_json(b, 'array>')[0].b FROM testData -- !query analysis -Project [from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a AS from_json(a).a#x, from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).b AS from_json(a).b#x, from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].a AS from_json(b)[0].a#x, from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].b AS from_json(b)[0].b#x] +Project [from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a AS from_json(a).a#x, from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).b AS from_json(a).b#x, from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), 
b#x, Some(America/Los_Angeles), false)[0].a AS from_json(b)[0].a#x, from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].b AS from_json(b)[0].b#x] +- SubqueryAlias testdata +- View (`testData`, [a#x, b#x]) +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] @@ -27,7 +27,7 @@ Project [from_json(StructField(a,IntegerType,true), StructField(b,StringType,tru -- !query SELECT if(from_json(a, 'struct').a > 1, from_json(b, 'array>')[0].a, from_json(b, 'array>')[0].a + 1) FROM testData -- !query analysis -Project [if ((from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a > 1)) from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].a else (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].a + 1) AS (IF((from_json(a).a > 1), from_json(b)[0].a, (from_json(b)[0].a + 1)))#x] +Project [if ((from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a > 1)) from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].a else (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].a + 1) AS (IF((from_json(a).a > 1), from_json(b)[0].a, (from_json(b)[0].a + 1)))#x] +- SubqueryAlias testdata +- View (`testData`, [a#x, b#x]) +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] @@ -39,7 +39,7 @@ Project [if ((from_json(StructField(a,IntegerType,true), StructField(b,StringTyp -- !query SELECT if(isnull(from_json(a, 'struct').a), from_json(b, 'array>')[0].b + 1, from_json(b, 'array>')[0].b) FROM testData -- !query analysis -Project [if (isnull(from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a)) (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].b + 1) else from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].b AS (IF((from_json(a).a IS NULL), (from_json(b)[0].b + 1), from_json(b)[0].b))#x] +Project [if (isnull(from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a)) (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].b + 1) else from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].b AS (IF((from_json(a).a IS NULL), (from_json(b)[0].b + 1), from_json(b)[0].b))#x] +- SubqueryAlias testdata +- View (`testData`, [a#x, b#x]) +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] @@ -51,7 +51,7 @@ Project [if (isnull(from_json(StructField(a,IntegerType,true), StructField(b,Str -- !query SELECT case when from_json(a, 'struct').a > 5 then from_json(a, 'struct').b when from_json(a, 'struct').a > 4 then from_json(a, 'struct').b + 1 else from_json(a, 'struct').b + 2 end FROM testData -- !query analysis -Project [CASE WHEN (from_json(StructField(a,IntegerType,true), 
StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a > 5) THEN from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).b WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a > 4) THEN cast((cast(from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).b as double) + cast(1 as double)) as string) ELSE cast((cast(from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).b as double) + cast(2 as double)) as string) END AS CASE WHEN (from_json(a).a > 5) THEN from_json(a).b WHEN (from_json(a).a > 4) THEN (from_json(a).b + 1) ELSE (from_json(a).b + 2) END#x] +Project [CASE WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a > 5) THEN from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).b WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a > 4) THEN cast((cast(from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).b as double) + cast(1 as double)) as string) ELSE cast((cast(from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).b as double) + cast(2 as double)) as string) END AS CASE WHEN (from_json(a).a > 5) THEN from_json(a).b WHEN (from_json(a).a > 4) THEN (from_json(a).b + 1) ELSE (from_json(a).b + 2) END#x] +- SubqueryAlias testdata +- View (`testData`, [a#x, b#x]) +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] @@ -63,7 +63,7 @@ Project [CASE WHEN (from_json(StructField(a,IntegerType,true), StructField(b,Str -- !query SELECT case when from_json(a, 'struct').a > 5 then from_json(b, 'array>')[0].b when from_json(a, 'struct').a > 4 then from_json(b, 'array>')[0].b + 1 else from_json(b, 'array>')[0].b + 2 end FROM testData -- !query analysis -Project [CASE WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a > 5) THEN from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].b WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a > 4) THEN (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].b + 1) ELSE (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].b + 2) END AS CASE WHEN (from_json(a).a > 5) THEN from_json(b)[0].b WHEN (from_json(a).a > 4) THEN (from_json(b)[0].b + 1) ELSE (from_json(b)[0].b + 2) END#x] +Project [CASE WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a > 5) THEN from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].b WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a > 4) THEN (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].b 
+ 1) ELSE (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].b + 2) END AS CASE WHEN (from_json(a).a > 5) THEN from_json(b)[0].b WHEN (from_json(a).a > 4) THEN (from_json(b)[0].b + 1) ELSE (from_json(b)[0].b + 2) END#x] +- SubqueryAlias testdata +- View (`testData`, [a#x, b#x]) +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out index bea91e09b0053..01de7beda551d 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out @@ -142,6 +142,12 @@ Project [x1#x, x2#x, scalar-subquery#x [x1#x && x2#x] AS scalarsubquery(x1, x2)# +- LocalRelation [col1#x, col2#x] +-- !query +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false +-- !query analysis +SetCommand (spark.sql.optimizer.scalarSubqueryUseSingleJoin,Some(false)) + + -- !query select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 -- !query analysis @@ -202,24 +208,83 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = true +-- !query analysis +SetCommand (spark.sql.optimizer.scalarSubqueryUseSingleJoin,Some(true)) + + +-- !query +select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 +-- !query analysis +Project [x1#x, x2#x] ++- Filter (scalar-subquery#x [x1#x] = cast(1 as bigint)) + : +- Aggregate [y1#x], [count(1) AS count(1)#xL] + : +- Filter (y1#x > outer(x1#x)) + : +- SubqueryAlias y + : +- View (`y`, [y1#x, y2#x]) + : +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias x + +- View (`x`, [x1#x, x2#x]) + +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x +-- !query analysis +Project [x1#x, x2#x, scalar-subquery#x [x1#x] AS scalarsubquery(x1)#xL] +: +- Aggregate [y1#x], [count(1) AS count(1)#xL] +: +- Filter ((y1#x + y2#x) = outer(x1#x)) +: +- SubqueryAlias y +: +- View (`y`, [y1#x, y2#x]) +: +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias x + +- View (`x`, [x1#x, x2#x]) + +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select *, (select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2) from x +-- !query analysis +Project [x1#x, x2#x, scalar-subquery#x [x1#x && x1#x] AS scalarsubquery(x1, x1)#xL] +: +- Aggregate [y2#x], [count(1) AS count(1)#xL] +: +- Filter ((outer(x1#x) = y1#x) AND ((y2#x + 10) = (outer(x1#x) + 1))) +: +- SubqueryAlias y +: +- View (`y`, [y1#x, y2#x]) +: +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias x + +- View (`x`, [x1#x, x2#x]) + +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + +- LocalRelation [col1#x, col2#x] + + -- !query select *, (select count(*) from (select * from y where y1 = x1 
union all select * from y) sub group by y1) from x -- !query analysis -org.apache.spark.sql.catalyst.ExtendedAnalysisException -{ - "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", - "sqlState" : "0A000", - "messageParameters" : { - "value" : "y1" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 11, - "stopIndex" : 106, - "fragment" : "(select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1)" - } ] -} +Project [x1#x, x2#x, scalar-subquery#x [x1#x] AS scalarsubquery(x1)#xL] +: +- Aggregate [y1#x], [count(1) AS count(1)#xL] +: +- SubqueryAlias sub +: +- Union false, false +: :- Project [y1#x, y2#x] +: : +- Filter (y1#x = outer(x1#x)) +: : +- SubqueryAlias y +: : +- View (`y`, [y1#x, y2#x]) +: : +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] +: : +- LocalRelation [col1#x, col2#x] +: +- Project [y1#x, y2#x] +: +- SubqueryAlias y +: +- View (`y`, [y1#x, y2#x]) +: +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias x + +- View (`x`, [x1#x, x2#x]) + +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + +- LocalRelation [col1#x, col2#x] -- !query @@ -227,17 +292,17 @@ select *, (select count(*) from y left join (select * from z where z1 = x1) sub -- !query analysis org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", + "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED", "sqlState" : "0A000", "messageParameters" : { - "value" : "z1" + "treeNode" : "Filter (z1#x = outer(x1#x))\n+- SubqueryAlias z\n +- View (`z`, [z1#x, z2#x])\n +- Project [cast(col1#x as int) AS z1#x, cast(col2#x as int) AS z2#x]\n +- LocalRelation [col1#x, col2#x]\n" }, "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 11, - "stopIndex" : 103, - "fragment" : "(select count(*) from y left join (select * from z where z1 = x1) sub on y2 = z2 group by z1)" + "startIndex" : 46, + "stopIndex" : 74, + "fragment" : "select * from z where z1 = x1" } ] } @@ -248,6 +313,12 @@ set spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate = SetCommand (spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate,Some(true)) +-- !query +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false +-- !query analysis +SetCommand (spark.sql.optimizer.scalarSubqueryUseSingleJoin,Some(false)) + + -- !query select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out index e3ce85fe5d209..4ff0222d6e965 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out @@ -1748,3 +1748,21 @@ Project [t1a#x, t1b#x, t1c#x] +- View (`t1`, [t1a#x, t1b#x, t1c#x]) +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x as int) AS t1b#x, cast(col3#x as int) AS t1c#x] +- LocalRelation [col1#x, col2#x, col3#x] + + +-- !query +SELECT * FROM t0 WHERE t0a = (SELECT distinct(t1c) 
FROM t1 WHERE t1a = t0a) +-- !query analysis +Project [t0a#x, t0b#x] ++- Filter (t0a#x = scalar-subquery#x [t0a#x]) + : +- Distinct + : +- Project [t1c#x] + : +- Filter (t1a#x = outer(t0a#x)) + : +- SubqueryAlias t1 + : +- View (`t1`, [t1a#x, t1b#x, t1c#x]) + : +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x as int) AS t1b#x, cast(col3#x as int) AS t1c#x] + : +- LocalRelation [col1#x, col2#x, col3#x] + +- SubqueryAlias t0 + +- View (`t0`, [t0a#x, t0b#x]) + +- Project [cast(col1#x as int) AS t0a#x, cast(col2#x as int) AS t0b#x] + +- LocalRelation [col1#x, col2#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/timestamp.sql.out index 6ca35b8b141dc..dcfd783b648f8 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/timestamp.sql.out @@ -802,7 +802,7 @@ Project [unix_timestamp(22 05 2020 Friday, dd MM yyyy EEEEE, Some(America/Los_An -- !query select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"t":"26/October/2015"})#x] +Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"t":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp-ansi.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp-ansi.sql.out index e50c860270563..ec227afc87fe1 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp-ansi.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp-ansi.sql.out @@ -745,7 +745,7 @@ Project [unix_timestamp(22 05 2020 Friday, dd MM yyyy EEEEE, Some(America/Los_An -- !query select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(t,TimestampNTZType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"t":"26/October/2015"})#x] +Project [from_json(StructField(t,TimestampNTZType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"t":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp.sql.out index 098abfb3852cf..7475f837250d5 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp.sql.out @@ -805,7 +805,7 @@ Project [unix_timestamp(22 05 2020 Friday, dd MM yyyy EEEEE, Some(America/Los_An -- !query select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(t,TimestampNTZType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"t":"26/October/2015"})#x] +Project [from_json(StructField(t,TimestampNTZType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, 
Some(America/Los_Angeles), false) AS from_json({"t":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/timezone.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/timezone.sql.out index 9059f37f3607b..5b55a0c218934 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/timezone.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/timezone.sql.out @@ -64,7 +64,11 @@ SET TIME ZONE INTERVAL 3 DAYS -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0044", + "errorClass" : "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + "sqlState" : "22006", + "messageParameters" : { + "input" : "3" + }, "queryContext" : [ { "objectType" : "", "objectName" : "", @@ -80,7 +84,11 @@ SET TIME ZONE INTERVAL 24 HOURS -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0044", + "errorClass" : "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + "sqlState" : "22006", + "messageParameters" : { + "input" : "24" + }, "queryContext" : [ { "objectType" : "", "objectName" : "", @@ -96,7 +104,11 @@ SET TIME ZONE INTERVAL '19:40:32' HOUR TO SECOND -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0044", + "errorClass" : "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + "sqlState" : "22006", + "messageParameters" : { + "input" : "19" + }, "queryContext" : [ { "objectType" : "", "objectName" : "", @@ -128,7 +140,11 @@ SET TIME ZONE INTERVAL 10 HOURS 1 MILLISECOND -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0044", + "errorClass" : "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + "sqlState" : "22006", + "messageParameters" : { + "input" : "36000" + }, "queryContext" : [ { "objectType" : "", "objectName" : "", diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/typeCoercion/native/stringCastAndExpressions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/typeCoercion/native/stringCastAndExpressions.sql.out index 009e91f7ffacf..22e60d0606382 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/typeCoercion/native/stringCastAndExpressions.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/typeCoercion/native/stringCastAndExpressions.sql.out @@ -370,7 +370,7 @@ Project [c0#x] -- !query select from_json(a, 'a INT') from t -- !query analysis -Project [from_json(StructField(a,IntegerType,true), a#x, Some(America/Los_Angeles)) AS from_json(a)#x] +Project [from_json(StructField(a,IntegerType,true), a#x, Some(America/Los_Angeles), false) AS from_json(a)#x] +- SubqueryAlias t +- View (`t`, [a#x]) +- Project [cast(a#x as string) AS a#x] diff --git a/sql/core/src/test/resources/sql-tests/inputs/charvarchar.sql b/sql/core/src/test/resources/sql-tests/inputs/charvarchar.sql index 8117dec53f4ab..be038e1083cd8 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/charvarchar.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/charvarchar.sql @@ -49,6 +49,8 @@ desc formatted char_tbl1; create table char_part(c1 char(5), c2 char(2), v1 varchar(6), v2 varchar(2)) using parquet partitioned by (v2, c2); desc formatted char_part; +alter table char_part change column c1 comment 'char comment'; +alter table char_part change column v1 comment 'varchar comment'; alter table char_part add partition (v2='ke', 
c2='nt') location 'loc1'; desc formatted char_part; diff --git a/sql/core/src/test/resources/sql-tests/inputs/collations.sql b/sql/core/src/test/resources/sql-tests/inputs/collations.sql index 183577b83971b..f3a42fd3e1f12 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/collations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/collations.sql @@ -99,6 +99,7 @@ insert into t4 values('a:1,b:2,c:3', ',', ':'); select str_to_map(text, pairDelim, keyValueDelim) from t4; select str_to_map(text collate utf8_binary, pairDelim collate utf8_lcase, keyValueDelim collate utf8_binary) from t4; select str_to_map(text collate utf8_binary, pairDelim collate utf8_binary, keyValueDelim collate utf8_binary) from t4; +select str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, keyValueDelim collate unicode_ai) from t4; drop table t4; @@ -159,6 +160,7 @@ select split_part(s, utf8_binary, 1) from t5; select split_part(utf8_binary collate utf8_binary, s collate utf8_lcase, 1) from t5; select split_part(utf8_binary, utf8_lcase collate utf8_binary, 2) from t5; select split_part(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 2) from t5; +select split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5; select split_part(utf8_binary, 'a', 3), split_part(utf8_lcase, 'a', 3) from t5; select split_part(utf8_binary, 'a' collate utf8_lcase, 3), split_part(utf8_lcase, 'a' collate utf8_binary, 3) from t5; @@ -168,6 +170,7 @@ select contains(s, utf8_binary) from t5; select contains(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select contains(utf8_binary, utf8_lcase collate utf8_binary) from t5; select contains(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select contains(utf8_binary, 'a'), contains(utf8_lcase, 'a') from t5; select contains(utf8_binary, 'AaAA' collate utf8_lcase), contains(utf8_lcase, 'AAa' collate utf8_binary) from t5; @@ -177,6 +180,7 @@ select substring_index(s, utf8_binary,1) from t5; select substring_index(utf8_binary collate utf8_binary, s collate utf8_lcase, 3) from t5; select substring_index(utf8_binary, utf8_lcase collate utf8_binary, 2) from t5; select substring_index(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 2) from t5; +select substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5; select substring_index(utf8_binary, 'a', 2), substring_index(utf8_lcase, 'a', 2) from t5; select substring_index(utf8_binary, 'AaAA' collate utf8_lcase, 2), substring_index(utf8_lcase, 'AAa' collate utf8_binary, 2) from t5; @@ -186,6 +190,7 @@ select instr(s, utf8_binary) from t5; select instr(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select instr(utf8_binary, utf8_lcase collate utf8_binary) from t5; select instr(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select instr(utf8_binary, 'a'), instr(utf8_lcase, 'a') from t5; select instr(utf8_binary, 'AaAA' collate utf8_lcase), instr(utf8_lcase, 'AAa' collate utf8_binary) from t5; @@ -204,6 +209,7 @@ select startswith(s, utf8_binary) from t5; select startswith(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select startswith(utf8_binary, utf8_lcase collate utf8_binary) from t5; select startswith(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; 
+select startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select startswith(utf8_binary, 'aaAaaAaA'), startswith(utf8_lcase, 'aaAaaAaA') from t5; select startswith(utf8_binary, 'aaAaaAaA' collate utf8_lcase), startswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5; @@ -212,6 +218,7 @@ select translate(utf8_lcase, utf8_lcase, '12345') from t5; select translate(utf8_binary, utf8_lcase, '12345') from t5; select translate(utf8_binary, 'aBc' collate utf8_lcase, '12345' collate utf8_binary) from t5; select translate(utf8_binary, 'SQL' collate utf8_lcase, '12345' collate utf8_lcase) from t5; +select translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai) from t5; select translate(utf8_lcase, 'aaAaaAaA', '12345'), translate(utf8_binary, 'aaAaaAaA', '12345') from t5; select translate(utf8_lcase, 'aBc' collate utf8_binary, '12345'), translate(utf8_binary, 'aBc' collate utf8_lcase, '12345') from t5; @@ -221,6 +228,7 @@ select replace(s, utf8_binary, 'abc') from t5; select replace(utf8_binary collate utf8_binary, s collate utf8_lcase, 'abc') from t5; select replace(utf8_binary, utf8_lcase collate utf8_binary, 'abc') from t5; select replace(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 'abc') from t5; +select replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc') from t5; select replace(utf8_binary, 'aaAaaAaA', 'abc'), replace(utf8_lcase, 'aaAaaAaA', 'abc') from t5; select replace(utf8_binary, 'aaAaaAaA' collate utf8_lcase, 'abc'), replace(utf8_lcase, 'aaAaaAaA' collate utf8_binary, 'abc') from t5; @@ -230,6 +238,7 @@ select endswith(s, utf8_binary) from t5; select endswith(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select endswith(utf8_binary, utf8_lcase collate utf8_binary) from t5; select endswith(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select endswith(utf8_binary, 'aaAaaAaA'), endswith(utf8_lcase, 'aaAaaAaA') from t5; select endswith(utf8_binary, 'aaAaaAaA' collate utf8_lcase), endswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5; @@ -364,6 +373,7 @@ select locate(s, utf8_binary) from t5; select locate(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select locate(utf8_binary, utf8_lcase collate utf8_binary) from t5; select locate(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 3) from t5; +select locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3) from t5; select locate(utf8_binary, 'a'), locate(utf8_lcase, 'a') from t5; select locate(utf8_binary, 'AaAA' collate utf8_lcase, 4), locate(utf8_lcase, 'AAa' collate utf8_binary, 4) from t5; @@ -373,6 +383,7 @@ select TRIM(s, utf8_binary) from t5; select TRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select TRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; select TRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select TRIM('ABc', utf8_binary), TRIM('ABc', utf8_lcase) from t5; select TRIM('ABc' collate utf8_lcase, utf8_binary), TRIM('AAa' collate utf8_binary, utf8_lcase) from t5; -- StringTrimBoth @@ -381,6 +392,7 @@ select BTRIM(s, utf8_binary) from t5; select BTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select BTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; select 
BTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select BTRIM('ABc', utf8_binary), BTRIM('ABc', utf8_lcase) from t5; select BTRIM('ABc' collate utf8_lcase, utf8_binary), BTRIM('AAa' collate utf8_binary, utf8_lcase) from t5; -- StringTrimLeft @@ -389,6 +401,7 @@ select LTRIM(s, utf8_binary) from t5; select LTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select LTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; select LTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select LTRIM('ABc', utf8_binary), LTRIM('ABc', utf8_lcase) from t5; select LTRIM('ABc' collate utf8_lcase, utf8_binary), LTRIM('AAa' collate utf8_binary, utf8_lcase) from t5; -- StringTrimRight @@ -397,6 +410,7 @@ select RTRIM(s, utf8_binary) from t5; select RTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select RTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; select RTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select RTRIM('ABc', utf8_binary), RTRIM('ABc', utf8_lcase) from t5; select RTRIM('ABc' collate utf8_lcase, utf8_binary), RTRIM('AAa' collate utf8_binary, utf8_lcase) from t5; diff --git a/sql/core/src/test/resources/sql-tests/inputs/null-handling.sql b/sql/core/src/test/resources/sql-tests/inputs/null-handling.sql index 040be00503442..dcdf241df73d9 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/null-handling.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/null-handling.sql @@ -10,6 +10,8 @@ insert into t1 values(7,null,null); -- Adding anything to null gives null select a, b+c from t1; +select b + 0 from t1 where a = 5; +select -100 + b + 100 from t1 where a = 5; -- Multiplying null by zero gives null select a+10, b*0 from t1; diff --git a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql index 7d0966e7f2095..61890f5cb146d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql @@ -12,7 +12,54 @@ drop table if exists st; create table st(x int, col struct) using parquet; insert into st values (1, (2, 3)); --- Selection operators: positive tests. 
+create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings); + +create temporary view courseEarnings as select * from values + ("dotNET", 15000, 48000, 22500), + ("Java", 20000, 30000, NULL) + as courseEarnings(course, `2012`, `2013`, `2014`); + +create temporary view courseEarningsAndSales as select * from values + ("dotNET", 15000, NULL, 48000, 1, 22500, 1), + ("Java", 20000, 1, 30000, 2, NULL, NULL) + as courseEarningsAndSales( + course, earnings2012, sales2012, earnings2013, sales2013, earnings2014, sales2014); + +create temporary view yearsWithComplexTypes as select * from values + (2012, array(1, 1), map('1', 1), struct(1, 'a')), + (2013, array(2, 2), map('2', 2), struct(2, 'b')) + as yearsWithComplexTypes(y, a, m, s); + +create temporary view join_test_t1 as select * from values (1) as grouping(a); +create temporary view join_test_t2 as select * from values (1) as grouping(a); +create temporary view join_test_t3 as select * from values (1) as grouping(a); +create temporary view join_test_empty_table as select a from join_test_t2 where false; + +create temporary view lateral_test_t1(c1, c2) + as values (0, 1), (1, 2); +create temporary view lateral_test_t2(c1, c2) + as values (0, 2), (0, 3); +create temporary view lateral_test_t3(c1, c2) + as values (0, array(0, 1)), (1, array(2)), (2, array()), (null, array(4)); +create temporary view lateral_test_t4(c1, c2) + as values (0, 1), (0, 2), (1, 1), (1, 3); + +create temporary view natural_join_test_t1 as select * from values + ("one", 1), ("two", 2), ("three", 3) as natural_join_test_t1(k, v1); + +create temporary view natural_join_test_t2 as select * from values + ("one", 1), ("two", 22), ("one", 5) as natural_join_test_t2(k, v2); + +create temporary view natural_join_test_t3 as select * from values + ("one", 4), ("two", 5), ("one", 6) as natural_join_test_t3(k, v3); + +-- SELECT operators: positive tests. --------------------------------------- -- Selecting a constant. @@ -85,7 +132,24 @@ table t table t |> select distinct x, y; --- Selection operators: negative tests. +-- SELECT * is supported. +table t +|> select *; + +table t +|> select * except (y); + +-- Hints are supported. +table t +|> select /*+ repartition(3) */ *; + +table t +|> select /*+ repartition(3) */ distinct x; + +table t +|> select /*+ repartition(3) */ all x; + +-- SELECT operators: negative tests. --------------------------------------- -- Aggregate functions are not allowed in the pipe operator SELECT list. @@ -95,6 +159,418 @@ table t table t |> select y, length(y) + sum(x) as result; +-- WHERE operators: positive tests. +----------------------------------- + +-- Filtering with a constant predicate. +table t +|> where true; + +-- Filtering with a predicate based on attributes from the input relation. +table t +|> where x + length(y) < 4; + +-- Two consecutive filters are allowed. +table t +|> where x + length(y) < 4 +|> where x + length(y) < 3; + +-- It is possible to use the WHERE operator instead of the HAVING clause when processing the result +-- of aggregations. For example, this WHERE operator is equivalent to the normal SQL "HAVING x = 1". +(select x, sum(length(y)) as sum_len from t group by x) +|> where x = 1; + +-- Filtering by referring to the table or table subquery alias. 
+table t +|> where t.x = 1; + +table t +|> where spark_catalog.default.t.x = 1; + +-- Filtering using struct fields. +(select col from st) +|> where col.i1 = 1; + +table st +|> where st.col.i1 = 2; + +-- Expression subqueries in the WHERE clause. +table t +|> where exists (select a from other where x = a limit 1); + +-- Aggregations are allowed within expression subqueries in the pipe operator WHERE clause as long +-- no aggregate functions exist in the top-level expression predicate. +table t +|> where (select any_value(a) from other where x = a limit 1) = 1; + +-- WHERE operators: negative tests. +----------------------------------- + +-- Aggregate functions are not allowed in the top-level WHERE predicate. +-- (Note: to implement this behavior, perform the aggregation first separately and then add a +-- pipe-operator WHERE clause referring to the result of aggregate expression(s) therein). +table t +|> where sum(x) = 1; + +table t +|> where y = 'abc' or length(y) + sum(x) = 1; + +-- Window functions are not allowed in the WHERE clause (pipe operators or otherwise). +table t +|> where first_value(x) over (partition by y) = 1; + +select * from t where first_value(x) over (partition by y) = 1; + +-- Pipe operators may only refer to attributes produced as output from the directly-preceding +-- pipe operator, not from earlier ones. +table t +|> select x, length(y) as z +|> where x + length(y) < 4; + +-- If the WHERE clause wants to filter rows produced by an aggregation, it is not valid to try to +-- refer to the aggregate functions directly; it is necessary to use aliases instead. +(select x, sum(length(y)) as sum_len from t group by x) +|> where sum(length(y)) = 3; + +-- Pivot and unpivot operators: positive tests. +----------------------------------------------- + +table courseSales +|> select `year`, course, earnings +|> pivot ( + sum(earnings) + for course in ('dotNET', 'Java') + ); + +table courseSales +|> select `year` as y, course as c, earnings as e +|> pivot ( + sum(e) as s, avg(e) as a + for y in (2012 as firstYear, 2013 as secondYear) + ); + +-- Pivot on multiple pivot columns with aggregate columns of complex data types. +select course, `year`, y, a +from courseSales +join yearsWithComplexTypes on `year` = y +|> pivot ( + max(a) + for (y, course) in ((2012, 'dotNET'), (2013, 'Java')) + ); + +-- Pivot on pivot column of struct type. +select earnings, `year`, s +from courseSales +join yearsWithComplexTypes on `year` = y +|> pivot ( + sum(earnings) + for s in ((1, 'a'), (2, 'b')) + ); + +table courseEarnings +|> unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ); + +table courseEarnings +|> unpivot include nulls ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ); + +table courseEarningsAndSales +|> unpivot include nulls ( + (earnings, sales) for `year` in ( + (earnings2012, sales2012) as `2012`, + (earnings2013, sales2013) as `2013`, + (earnings2014, sales2014) as `2014`) + ); + +-- Pivot and unpivot operators: negative tests. +----------------------------------------------- + +-- The PIVOT operator refers to a column 'year' is not available in the input relation. +table courseSales +|> select course, earnings +|> pivot ( + sum(earnings) + for `year` in (2012, 2013) + ); + +-- Non-literal PIVOT values are not supported. +table courseSales +|> pivot ( + sum(earnings) + for `year` in (course, 2013) + ); + +-- The PIVOT and UNPIVOT clauses are mutually exclusive. 
+table courseSales +|> select course, earnings +|> pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) + unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ); + +table courseSales +|> select course, earnings +|> unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) + pivot ( + sum(earnings) + for `year` in (2012, 2013) + ); + +-- Multiple PIVOT and/or UNPIVOT clauses are not supported in the same pipe operator. +table courseSales +|> select course, earnings +|> pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) + pivot ( + sum(earnings) + for `year` in (2012, 2013) + ); + +table courseSales +|> select course, earnings +|> unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) + unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) + pivot ( + sum(earnings) + for `year` in (2012, 2013) + ); + +-- Sampling operators: positive tests. +-------------------------------------- + +-- We will use the REPEATABLE clause and/or adjust the sampling options to either remove no rows or +-- all rows to help keep the tests deterministic. +table t +|> tablesample (100 percent) repeatable (0); + +table t +|> tablesample (2 rows) repeatable (0); + +table t +|> tablesample (bucket 1 out of 1) repeatable (0); + +table t +|> tablesample (100 percent) repeatable (0) +|> tablesample (5 rows) repeatable (0) +|> tablesample (bucket 1 out of 1) repeatable (0); + +-- Sampling operators: negative tests. +-------------------------------------- + +-- The sampling method is required. +table t +|> tablesample (); + +-- Negative sampling options are not supported. +table t +|> tablesample (-100 percent) repeatable (0); + +table t +|> tablesample (-5 rows); + +-- The sampling method may not refer to attribute names from the input relation. +table t +|> tablesample (x rows); + +-- The bucket number is invalid. +table t +|> tablesample (bucket 2 out of 1); + +-- Byte literals are not supported. +table t +|> tablesample (200b) repeatable (0); + +-- Invalid byte literal syntax. +table t +|> tablesample (200) repeatable (0); + +-- JOIN operators: positive tests. +---------------------------------- + +table join_test_t1 +|> inner join join_test_empty_table; + +table join_test_t1 +|> cross join join_test_empty_table; + +table join_test_t1 +|> left outer join join_test_empty_table; + +table join_test_t1 +|> right outer join join_test_empty_table; + +table join_test_t1 +|> full outer join join_test_empty_table using (a); + +table join_test_t1 +|> full outer join join_test_empty_table on (join_test_t1.a = join_test_empty_table.a); + +table join_test_t1 +|> left semi join join_test_empty_table; + +table join_test_t1 +|> left anti join join_test_empty_table; + +select * from join_test_t1 where true +|> inner join join_test_empty_table; + +select 1 as x, 2 as y +|> inner join (select 1 as x, 4 as y) using (x); + +table join_test_t1 +|> inner join (join_test_t2 jt2 inner join join_test_t3 jt3 using (a)) using (a) +|> select a, join_test_t1.a, jt2.a, jt3.a; + +table join_test_t1 +|> inner join join_test_t2 tablesample (100 percent) repeatable (0) jt2 using (a); + +table join_test_t1 +|> inner join (select 1 as a) tablesample (100 percent) repeatable (0) jt2 using (a); + +table join_test_t1 +|> join join_test_t1 using (a); + +-- Lateral joins. 
+table lateral_test_t1 +|> join lateral (select c1); + +table lateral_test_t1 +|> join lateral (select c1 from lateral_test_t2); + +table lateral_test_t1 +|> join lateral (select lateral_test_t1.c1 from lateral_test_t2); + +table lateral_test_t1 +|> join lateral (select lateral_test_t1.c1 + t2.c1 from lateral_test_t2 t2); + +table lateral_test_t1 +|> join lateral (select *); + +table lateral_test_t1 +|> join lateral (select * from lateral_test_t2); + +table lateral_test_t1 +|> join lateral (select lateral_test_t1.* from lateral_test_t2); + +table lateral_test_t1 +|> join lateral (select lateral_test_t1.*, t2.* from lateral_test_t2 t2); + +table lateral_test_t1 +|> join lateral_test_t2 +|> join lateral (select lateral_test_t1.c2 + lateral_test_t2.c2); + +-- Natural joins. +table natural_join_test_t1 +|> natural join natural_join_test_t2 +|> where k = "one"; + +table natural_join_test_t1 +|> natural join natural_join_test_t2 nt2 +|> select natural_join_test_t1.*; + +table natural_join_test_t1 +|> natural join natural_join_test_t2 nt2 +|> natural join natural_join_test_t3 nt3 +|> select natural_join_test_t1.*, nt2.*, nt3.*; + +-- JOIN operators: negative tests. +---------------------------------- + +-- Multiple joins within the same pipe operator are not supported without parentheses. +table join_test_t1 +|> inner join join_test_empty_table + inner join join_test_empty_table; + +-- The join pipe operator can only refer to column names from the previous relation. +table join_test_t1 +|> select 1 + 2 as result +|> full outer join join_test_empty_table on (join_test_t1.a = join_test_empty_table.a); + +-- The table from the pipe input is not visible as a table name in the right side. +table join_test_t1 jt +|> cross join (select * from jt); + +-- Set operations: positive tests. +----------------------------------- + +-- Union all. +table t +|> union all table t; + +-- Union distinct. +table t +|> union table t; + +-- Union all with a table subquery. +(select * from t) +|> union all table t; + +-- Union distinct with a table subquery. +(select * from t) +|> union table t; + +-- Union all with a VALUES list. +values (0, 'abc') tab(x, y) +|> union all table t; + +-- Union distinct with a VALUES list. +values (0, 1) tab(x, y) +|> union table t; + +-- Union all with a table subquery on both the source and target sides. +(select * from t) +|> union all (select * from t); + +-- Except all. +table t +|> except all table t; + +-- Except distinct. +table t +|> except table t; + +-- Intersect all. +table t +|> intersect all table t; + +-- Intersect distinct. +table t +|> intersect table t; + +-- Minus all. +table t +|> minus all table t; + +-- Minus distinct. +table t +|> minus table t; + +-- Set operations: negative tests. +----------------------------------- + +-- The UNION operator requires the same number of columns in the input relations. +table t +|> select x +|> union all table t; + +-- The UNION operator requires the column types to be compatible. +table t +|> union all table st; + -- Cleanup. 
----------- drop table t; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql index db7cdc97614cb..a23083e9e0e4d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql @@ -22,16 +22,25 @@ select *, (select count(*) from y where x1 = y1 and cast(y2 as double) = x1 + 1 select *, (select count(*) from y where y2 + 1 = x1 + x2 group by y2 + 1) from x; --- Illegal queries +-- Illegal queries (single join disabled) +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false; select * from x where (select count(*) from y where y1 > x1 group by y1) = 1; select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x; select *, (select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2) from x; +-- Same queries, with LeftSingle join +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = true; +select * from x where (select count(*) from y where y1 > x1 group by y1) = 1; +select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x; +select *, (select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2) from x; + + -- Certain other operators like OUTER JOIN or UNION between the correlating filter and the group-by also can cause the scalar subquery to return multiple values and hence make the query illegal. select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x; select *, (select count(*) from y left join (select * from z where z1 = x1) sub on y2 = z2 group by z1) from x; -- The correlation below the join is unsupported in Spark anyway, but when we do support it this query should still be disallowed. 
-- Test legacy behavior conf set spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate = true; +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false; select * from x where (select count(*) from y where y1 > x1 group by y1) = 1; reset spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql index 2823888e6e438..81e0c5f98d82b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql @@ -529,3 +529,6 @@ FROM t1 WHERE (SELECT max(t2c) FROM t2 WHERE t1b = t2b ) between 1 and 2; + + +SELECT * FROM t0 WHERE t0a = (SELECT distinct(t1c) FROM t1 WHERE t1a = t0a); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out index 6497a46c68ccd..d9d266e8a674a 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out @@ -32,6 +32,7 @@ BUCKETS false BY false BYTE false CACHE false +CALL true CALLED false CASCADE false CASE true @@ -186,6 +187,7 @@ LOCK false LOCKS false LOGICAL false LONG false +LOOP false MACRO false MAP false MATCHED false @@ -378,6 +380,7 @@ ANY AS AUTHORIZATION BOTH +CALL CASE CAST CHECK diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out index cf1bce3c0e504..706673606625b 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out @@ -842,7 +842,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "WINDOWS-1252", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`encode`", "parameter" : "`charset`" } @@ -860,7 +860,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "WINDOWS-1252", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`encode`", "parameter" : "`charset`" } @@ -878,7 +878,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "Windows-xxx", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`encode`", "parameter" : "`charset`" } @@ -896,7 +896,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "Windows-xxx", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`encode`", "parameter" : "`charset`" } @@ -1140,7 +1140,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { 
"charset" : "Windows-xxx", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`decode`", "parameter" : "`charset`" } @@ -1158,7 +1158,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "Windows-xxx", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`decode`", "parameter" : "`charset`" } @@ -1208,7 +1208,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "WINDOWS-1252", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`decode`", "parameter" : "`charset`" } @@ -1226,7 +1226,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "WINDOWS-1252", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`decode`", "parameter" : "`charset`" } diff --git a/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out b/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out index 568c9f3b29e87..2960c4ca4f4d4 100644 --- a/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out @@ -556,6 +556,22 @@ Location [not included in comparison]/{warehouse_dir}/char_part Partition Provider Catalog +-- !query +alter table char_part change column c1 comment 'char comment' +-- !query schema +struct<> +-- !query output + + + +-- !query +alter table char_part change column v1 comment 'varchar comment' +-- !query schema +struct<> +-- !query output + + + -- !query alter table char_part add partition (v2='ke', c2='nt') location 'loc1' -- !query schema @@ -569,8 +585,8 @@ desc formatted char_part -- !query schema struct -- !query output -c1 char(5) -v1 varchar(6) +c1 char(5) char comment +v1 varchar(6) varchar comment v2 varchar(2) c2 char(2) # Partition Information @@ -612,8 +628,8 @@ desc formatted char_part -- !query schema struct -- !query output -c1 char(5) -v1 varchar(6) +c1 char(5) char comment +v1 varchar(6) varchar comment v2 varchar(2) c2 char(2) # Partition Information @@ -647,8 +663,8 @@ desc formatted char_part -- !query schema struct -- !query output -c1 char(5) -v1 varchar(6) +c1 char(5) char comment +v1 varchar(6) varchar comment v2 varchar(2) c2 char(2) # Partition Information @@ -682,8 +698,8 @@ desc formatted char_part -- !query schema struct -- !query output -c1 char(5) -v1 varchar(6) +c1 char(5) char comment +v1 varchar(6) varchar comment v2 varchar(2) c2 char(2) # Partition Information @@ -1219,7 +1235,7 @@ struct -- !query SELECT to_varchar(78.12, '$99.99') -- !query schema -struct +struct -- !query output $78.12 @@ -1227,7 +1243,7 @@ $78.12 -- !query SELECT to_varchar(111.11, '99.9') -- !query schema -struct +struct -- !query output ##.# @@ -1235,6 +1251,6 @@ struct -- !query SELECT to_varchar(12454.8, '99,999.9S') -- !query schema -struct +struct -- !query output 12,454.8+ diff --git a/sql/core/src/test/resources/sql-tests/results/collations.sql.out b/sql/core/src/test/resources/sql-tests/results/collations.sql.out 
index ea5564aafe96f..9d29a46e5a0ef 100644 --- a/sql/core/src/test/resources/sql-tests/results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/collations.sql.out @@ -480,6 +480,32 @@ struct +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(text, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"str_to_map(collate(text, unicode_ai), collate(pairDelim, unicode_ai), collate(keyValueDelim, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 106, + "fragment" : "str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, keyValueDelim collate unicode_ai)" + } ] +} + + -- !query drop table t4 -- !query schema @@ -1021,6 +1047,32 @@ struct +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"split_part(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 83, + "fragment" : "split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2)" + } ] +} + + -- !query select split_part(utf8_binary, 'a', 3), split_part(utf8_lcase, 'a', 3) from t5 -- !query schema @@ -1148,6 +1200,32 @@ true true +-- !query +select contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"contains(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 78, + "fragment" : "contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select contains(utf8_binary, 'a'), contains(utf8_lcase, 'a') from t5 -- !query schema @@ -1275,6 +1353,32 @@ kitten İo +-- !query +select substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"substring_index(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 88, + "fragment" : "substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2)" + } ] +} + + -- !query select 
substring_index(utf8_binary, 'a', 2), substring_index(utf8_lcase, 'a', 2) from t5 -- !query schema @@ -1402,6 +1506,32 @@ struct +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"instr(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select instr(utf8_binary, 'a'), instr(utf8_lcase, 'a') from t5 -- !query schema @@ -1656,6 +1786,32 @@ true true +-- !query +select startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"startswith(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 80, + "fragment" : "startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select startswith(utf8_binary, 'aaAaaAaA'), startswith(utf8_lcase, 'aaAaaAaA') from t5 -- !query schema @@ -1763,6 +1919,32 @@ kitten İo +-- !query +select translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"utf8_binary\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"translate(utf8_binary, collate(SQL, unicode_ai), collate(12345, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 83, + "fragment" : "translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai)" + } ] +} + + -- !query select translate(utf8_lcase, 'aaAaaAaA', '12345'), translate(utf8_binary, 'aaAaaAaA', '12345') from t5 -- !query schema @@ -1890,6 +2072,32 @@ bbabcbabcabcbabc kitten +-- !query +select replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc') from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"replace(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), abc)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 84, + "fragment" : "replace(utf8_binary collate unicode_ai, 
utf8_lcase collate unicode_ai, 'abc')" + } ] +} + + -- !query select replace(utf8_binary, 'aaAaaAaA', 'abc'), replace(utf8_lcase, 'aaAaaAaA', 'abc') from t5 -- !query schema @@ -2005,8 +2213,8 @@ struct +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"endswith(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 78, + "fragment" : "endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select endswith(utf8_binary, 'aaAaaAaA'), endswith(utf8_lcase, 'aaAaaAaA') from t5 -- !query schema @@ -3570,6 +3804,32 @@ struct +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"locate(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 3)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 79, + "fragment" : "locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3)" + } ] +} + + -- !query select locate(utf8_binary, 'a'), locate(utf8_lcase, 'a') from t5 -- !query schema @@ -3685,6 +3945,32 @@ QL sitTing +-- !query +select TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(BOTH collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 74, + "fragment" : "TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select TRIM('ABc', utf8_binary), TRIM('ABc', utf8_lcase) from t5 -- !query schema @@ -3812,6 +4098,32 @@ park İ +-- !query +select BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(BOTH collate(utf8_lcase, unicode_ai) FROM collate(utf8_binary, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select BTRIM('ABc', utf8_binary), BTRIM('ABc', utf8_lcase) from t5 
-- !query schema @@ -3927,6 +4239,32 @@ QL sitTing +-- !query +select LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(LEADING collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select LTRIM('ABc', utf8_binary), LTRIM('ABc', utf8_lcase) from t5 -- !query schema @@ -4042,6 +4380,32 @@ SQL sitTing +-- !query +select RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(TRAILING collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select RTRIM('ABc', utf8_binary), RTRIM('ABc', utf8_lcase) from t5 -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/current_database_catalog.sql.out b/sql/core/src/test/resources/sql-tests/results/current_database_catalog.sql.out index 67db0adee7f07..7fbe2dfff4db1 100644 --- a/sql/core/src/test/resources/sql-tests/results/current_database_catalog.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/current_database_catalog.sql.out @@ -2,6 +2,6 @@ -- !query select current_database(), current_schema(), current_catalog() -- !query schema -struct +struct -- !query output default default spark_catalog diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index d8a9f4c2e11f5..5d220fc12b78e 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1066,7 +1066,7 @@ SELECT FROM VALUES (1), (2), (1) AS tab(col) -- !query schema -struct,collect_list(col):array> +struct,array_agg(col):array> -- !query output [1,2,1] [1,2,1] @@ -1080,7 +1080,7 @@ FROM VALUES (1,4),(2,3),(1,4),(2,4) AS v(a,b) GROUP BY a -- !query schema -struct,collect_list(b):array> +struct,array_agg(b):array> -- !query output 1 [4,4] [4,4] 2 [3,4] [3,4] diff --git a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out index 952fb8fdc2bd2..596745b4ba5d8 100644 --- a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out @@ -1024,7 +1024,7 @@ org.apache.spark.sql.catalyst.parser.ParseException "errorClass" : 
"INVALID_SQL_SYNTAX.MULTI_PART_NAME", "sqlState" : "42000", "messageParameters" : { - "funcName" : "`default`.`myDoubleAvg`", + "name" : "`default`.`myDoubleAvg`", "statement" : "DROP TEMPORARY FUNCTION" }, "queryContext" : [ { diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out index 0dfd62599afa6..cd93a811d64f5 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out @@ -32,6 +32,7 @@ BUCKETS false BY false BYTE false CACHE false +CALL false CALLED false CASCADE false CASE false @@ -186,6 +187,7 @@ LOCK false LOCKS false LOGICAL false LONG false +LOOP false MACRO false MAP false MATCHED false diff --git a/sql/core/src/test/resources/sql-tests/results/null-handling.sql.out b/sql/core/src/test/resources/sql-tests/results/null-handling.sql.out index ece6dbef1605d..fb96be8317a5b 100644 --- a/sql/core/src/test/resources/sql-tests/results/null-handling.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/null-handling.sql.out @@ -77,6 +77,22 @@ struct 7 NULL +-- !query +select b + 0 from t1 where a = 5 +-- !query schema +struct<(b + 0):int> +-- !query output +NULL + + +-- !query +select -100 + b + 100 from t1 where a = 5 +-- !query schema +struct<((-100 + b) + 100):int> +-- !query output +NULL + + -- !query select a+10, b*0 from t1 -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out index 7e0b7912105c2..8cbc5357d78b6 100644 --- a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out @@ -71,6 +71,149 @@ struct<> +-- !query +create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temporary view courseEarnings as select * from values + ("dotNET", 15000, 48000, 22500), + ("Java", 20000, 30000, NULL) + as courseEarnings(course, `2012`, `2013`, `2014`) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temporary view courseEarningsAndSales as select * from values + ("dotNET", 15000, NULL, 48000, 1, 22500, 1), + ("Java", 20000, 1, 30000, 2, NULL, NULL) + as courseEarningsAndSales( + course, earnings2012, sales2012, earnings2013, sales2013, earnings2014, sales2014) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temporary view yearsWithComplexTypes as select * from values + (2012, array(1, 1), map('1', 1), struct(1, 'a')), + (2013, array(2, 2), map('2', 2), struct(2, 'b')) + as yearsWithComplexTypes(y, a, m, s) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temporary view join_test_t1 as select * from values (1) as grouping(a) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temporary view join_test_t2 as select * from values (1) as grouping(a) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temporary view join_test_t3 as select * from values (1) as grouping(a) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temporary view join_test_empty_table as select a from join_test_t2 where false +-- !query schema +struct<> +-- !query output + + + +-- 
!query +create temporary view lateral_test_t1(c1, c2) + as values (0, 1), (1, 2) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temporary view lateral_test_t2(c1, c2) + as values (0, 2), (0, 3) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temporary view lateral_test_t3(c1, c2) + as values (0, array(0, 1)), (1, array(2)), (2, array()), (null, array(4)) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temporary view lateral_test_t4(c1, c2) + as values (0, 1), (0, 2), (1, 1), (1, 3) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temporary view natural_join_test_t1 as select * from values + ("one", 1), ("two", 2), ("three", 3) as natural_join_test_t1(k, v1) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temporary view natural_join_test_t2 as select * from values + ("one", 1), ("two", 22), ("one", 5) as natural_join_test_t2(k, v2) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temporary view natural_join_test_t3 as select * from values + ("one", 4), ("two", 5), ("one", 6) as natural_join_test_t3(k, v3) +-- !query schema +struct<> +-- !query output + + + -- !query table t |> select 1 as x @@ -238,6 +381,56 @@ struct 1 def +-- !query +table t +|> select * +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> select * except (y) +-- !query schema +struct +-- !query output +0 +1 + + +-- !query +table t +|> select /*+ repartition(3) */ * +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> select /*+ repartition(3) */ distinct x +-- !query schema +struct +-- !query output +0 +1 + + +-- !query +table t +|> select /*+ repartition(3) */ all x +-- !query schema +struct +-- !query output +0 +1 + + -- !query table t |> select sum(x) as result @@ -284,6 +477,1202 @@ org.apache.spark.sql.AnalysisException } +-- !query +table t +|> where true +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> where x + length(y) < 4 +-- !query schema +struct +-- !query output +0 abc + + +-- !query +table t +|> where x + length(y) < 4 +|> where x + length(y) < 3 +-- !query schema +struct +-- !query output + + + +-- !query +(select x, sum(length(y)) as sum_len from t group by x) +|> where x = 1 +-- !query schema +struct +-- !query output +1 3 + + +-- !query +table t +|> where t.x = 1 +-- !query schema +struct +-- !query output +1 def + + +-- !query +table t +|> where spark_catalog.default.t.x = 1 +-- !query schema +struct +-- !query output +1 def + + +-- !query +(select col from st) +|> where col.i1 = 1 +-- !query schema +struct> +-- !query output + + + +-- !query +table st +|> where st.col.i1 = 2 +-- !query schema +struct> +-- !query output +1 {"i1":2,"i2":3} + + +-- !query +table t +|> where exists (select a from other where x = a limit 1) +-- !query schema +struct +-- !query output +1 def + + +-- !query +table t +|> where (select any_value(a) from other where x = a limit 1) = 1 +-- !query schema +struct +-- !query output +1 def + + +-- !query +table t +|> where sum(x) = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WHERE_CONDITION", + "sqlState" : "42903", + "messageParameters" : { + "condition" : "\"(sum(x) = 1)\"", + "expressionList" : "sum(spark_catalog.default.t.x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 27, + "fragment" : 
"table t\n|> where sum(x) = 1" + } ] +} + + +-- !query +table t +|> where y = 'abc' or length(y) + sum(x) = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WHERE_CONDITION", + "sqlState" : "42903", + "messageParameters" : { + "condition" : "\"((y = abc) OR ((length(y) + sum(x)) = 1))\"", + "expressionList" : "sum(spark_catalog.default.t.x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 52, + "fragment" : "table t\n|> where y = 'abc' or length(y) + sum(x) = 1" + } ] +} + + +-- !query +table t +|> where first_value(x) over (partition by y) = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1034", + "messageParameters" : { + "clauseName" : "WHERE" + } +} + + +-- !query +select * from t where first_value(x) over (partition by y) = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1034", + "messageParameters" : { + "clauseName" : "WHERE" + } +} + + +-- !query +table t +|> select x, length(y) as z +|> where x + length(y) < 4 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`y`", + "proposal" : "`x`, `z`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 57, + "stopIndex" : 57, + "fragment" : "y" + } ] +} + + +-- !query +(select x, sum(length(y)) as sum_len from t group by x) +|> where sum(length(y)) = 3 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`y`", + "proposal" : "`x`, `sum_len`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 77, + "stopIndex" : 77, + "fragment" : "y" + } ] +} + + +-- !query +table courseSales +|> select `year`, course, earnings +|> pivot ( + sum(earnings) + for course in ('dotNET', 'Java') + ) +-- !query schema +struct +-- !query output +2012 15000 20000 +2013 48000 30000 + + +-- !query +table courseSales +|> select `year` as y, course as c, earnings as e +|> pivot ( + sum(e) as s, avg(e) as a + for y in (2012 as firstYear, 2013 as secondYear) + ) +-- !query schema +struct +-- !query output +Java 20000 20000.0 30000 30000.0 +dotNET 15000 7500.0 48000 48000.0 + + +-- !query +select course, `year`, y, a +from courseSales +join yearsWithComplexTypes on `year` = y +|> pivot ( + max(a) + for (y, course) in ((2012, 'dotNET'), (2013, 'Java')) + ) +-- !query schema +struct,{2013, Java}:array> +-- !query output +2012 [1,1] NULL +2013 NULL [2,2] + + +-- !query +select earnings, `year`, s +from courseSales +join yearsWithComplexTypes on `year` = y +|> pivot ( + sum(earnings) + for s in ((1, 'a'), (2, 'b')) + ) +-- !query schema +struct +-- !query output +2012 35000 NULL +2013 NULL 78000 + + +-- !query +table courseEarnings +|> unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) +-- !query schema +struct +-- !query output +Java 2012 20000 +Java 2013 30000 +dotNET 2012 15000 +dotNET 2013 48000 +dotNET 2014 22500 + + +-- !query +table courseEarnings +|> unpivot include nulls ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) +-- 
!query schema +struct +-- !query output +Java 2012 20000 +Java 2013 30000 +Java 2014 NULL +dotNET 2012 15000 +dotNET 2013 48000 +dotNET 2014 22500 + + +-- !query +table courseEarningsAndSales +|> unpivot include nulls ( + (earnings, sales) for `year` in ( + (earnings2012, sales2012) as `2012`, + (earnings2013, sales2013) as `2013`, + (earnings2014, sales2014) as `2014`) + ) +-- !query schema +struct +-- !query output +Java 2012 20000 1 +Java 2013 30000 2 +Java 2014 NULL NULL +dotNET 2012 15000 NULL +dotNET 2013 48000 1 +dotNET 2014 22500 1 + + +-- !query +table courseSales +|> select course, earnings +|> pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`year`", + "proposal" : "`course`, `earnings`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 49, + "stopIndex" : 111, + "fragment" : "pivot (\n sum(earnings)\n for `year` in (2012, 2013)\n )" + } ] +} + + +-- !query +table courseSales +|> pivot ( + sum(earnings) + for `year` in (course, 2013) + ) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "NON_LITERAL_PIVOT_VALUES", + "sqlState" : "42K08", + "messageParameters" : { + "expression" : "\"course\"" + } +} + + +-- !query +table courseSales +|> select course, earnings +|> pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) + unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "NOT_ALLOWED_IN_FROM.UNPIVOT_WITH_PIVOT", + "sqlState" : "42601", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 186, + "fragment" : "table courseSales\n|> select course, earnings\n|> pivot (\n sum(earnings)\n for `year` in (2012, 2013)\n )\n unpivot (\n earningsYear for `year` in (`2012`, `2013`, `2014`)\n )" + } ] +} + + +-- !query +table courseSales +|> select course, earnings +|> unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) + pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "NOT_ALLOWED_IN_FROM.UNPIVOT_WITH_PIVOT", + "sqlState" : "42601", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 186, + "fragment" : "table courseSales\n|> select course, earnings\n|> unpivot (\n earningsYear for `year` in (`2012`, `2013`, `2014`)\n )\n pivot (\n sum(earnings)\n for `year` in (2012, 2013)\n )" + } ] +} + + +-- !query +table courseSales +|> select course, earnings +|> pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) + pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'pivot'", + "hint" : "" + } +} + + +-- !query +table courseSales +|> select course, earnings +|> unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) + unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) + pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) +-- !query schema +struct<> +-- !query 
output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'unpivot'", + "hint" : "" + } +} + + +-- !query +table t +|> tablesample (100 percent) repeatable (0) +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> tablesample (2 rows) repeatable (0) +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> tablesample (bucket 1 out of 1) repeatable (0) +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> tablesample (100 percent) repeatable (0) +|> tablesample (5 rows) repeatable (0) +|> tablesample (bucket 1 out of 1) repeatable (0) +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> tablesample () +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0014", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 25, + "fragment" : "tablesample ()" + } ] +} + + +-- !query +table t +|> tablesample (-100 percent) repeatable (0) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0064", + "messageParameters" : { + "msg" : "Sampling fraction (-1.0) must be on interval [0, 1]" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 52, + "fragment" : "tablesample (-100 percent) repeatable (0)" + } ] +} + + +-- !query +table t +|> tablesample (-5 rows) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_LIMIT_LIKE_EXPRESSION.IS_NEGATIVE", + "sqlState" : "42K0E", + "messageParameters" : { + "expr" : "\"-5\"", + "name" : "limit", + "v" : "-5" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 25, + "stopIndex" : 26, + "fragment" : "-5" + } ] +} + + +-- !query +table t +|> tablesample (x rows) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_LIMIT_LIKE_EXPRESSION.IS_UNFOLDABLE", + "sqlState" : "42K0E", + "messageParameters" : { + "expr" : "\"x\"", + "name" : "limit" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 25, + "stopIndex" : 25, + "fragment" : "x" + } ] +} + + +-- !query +table t +|> tablesample (bucket 2 out of 1) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0064", + "messageParameters" : { + "msg" : "Sampling fraction (2.0) must be on interval [0, 1]" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 42, + "fragment" : "tablesample (bucket 2 out of 1)" + } ] +} + + +-- !query +table t +|> tablesample (200b) repeatable (0) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0015", + "messageParameters" : { + "msg" : "byteLengthLiteral" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 44, + "fragment" : "tablesample (200b) repeatable (0)" + } ] +} + + +-- !query +table t +|> tablesample (200) repeatable (0) +-- !query schema +struct<> +-- !query output 
+org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0016", + "messageParameters" : { + "bytesStr" : "200" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 43, + "fragment" : "tablesample (200) repeatable (0)" + } ] +} + + +-- !query +table join_test_t1 +|> inner join join_test_empty_table +-- !query schema +struct +-- !query output + + + +-- !query +table join_test_t1 +|> cross join join_test_empty_table +-- !query schema +struct +-- !query output + + + +-- !query +table join_test_t1 +|> left outer join join_test_empty_table +-- !query schema +struct +-- !query output +1 NULL + + +-- !query +table join_test_t1 +|> right outer join join_test_empty_table +-- !query schema +struct +-- !query output + + + +-- !query +table join_test_t1 +|> full outer join join_test_empty_table using (a) +-- !query schema +struct +-- !query output +1 + + +-- !query +table join_test_t1 +|> full outer join join_test_empty_table on (join_test_t1.a = join_test_empty_table.a) +-- !query schema +struct +-- !query output +1 NULL + + +-- !query +table join_test_t1 +|> left semi join join_test_empty_table +-- !query schema +struct +-- !query output + + + +-- !query +table join_test_t1 +|> left anti join join_test_empty_table +-- !query schema +struct +-- !query output +1 + + +-- !query +select * from join_test_t1 where true +|> inner join join_test_empty_table +-- !query schema +struct +-- !query output + + + +-- !query +select 1 as x, 2 as y +|> inner join (select 1 as x, 4 as y) using (x) +-- !query schema +struct +-- !query output +1 2 4 + + +-- !query +table join_test_t1 +|> inner join (join_test_t2 jt2 inner join join_test_t3 jt3 using (a)) using (a) +|> select a, join_test_t1.a, jt2.a, jt3.a +-- !query schema +struct +-- !query output +1 1 1 1 + + +-- !query +table join_test_t1 +|> inner join join_test_t2 tablesample (100 percent) repeatable (0) jt2 using (a) +-- !query schema +struct +-- !query output +1 + + +-- !query +table join_test_t1 +|> inner join (select 1 as a) tablesample (100 percent) repeatable (0) jt2 using (a) +-- !query schema +struct +-- !query output +1 + + +-- !query +table join_test_t1 +|> join join_test_t1 using (a) +-- !query schema +struct +-- !query output +1 + + +-- !query +table lateral_test_t1 +|> join lateral (select c1) +-- !query schema +struct +-- !query output +0 1 0 +1 2 1 + + +-- !query +table lateral_test_t1 +|> join lateral (select c1 from lateral_test_t2) +-- !query schema +struct +-- !query output +0 1 0 +0 1 0 +1 2 0 +1 2 0 + + +-- !query +table lateral_test_t1 +|> join lateral (select lateral_test_t1.c1 from lateral_test_t2) +-- !query schema +struct +-- !query output +0 1 0 +0 1 0 +1 2 1 +1 2 1 + + +-- !query +table lateral_test_t1 +|> join lateral (select lateral_test_t1.c1 + t2.c1 from lateral_test_t2 t2) +-- !query schema +struct +-- !query output +0 1 0 +0 1 0 +1 2 1 +1 2 1 + + +-- !query +table lateral_test_t1 +|> join lateral (select *) +-- !query schema +struct +-- !query output +0 1 +1 2 + + +-- !query +table lateral_test_t1 +|> join lateral (select * from lateral_test_t2) +-- !query schema +struct +-- !query output +0 1 0 2 +0 1 0 3 +1 2 0 2 +1 2 0 3 + + +-- !query +table lateral_test_t1 +|> join lateral (select lateral_test_t1.* from lateral_test_t2) +-- !query schema +struct +-- !query output +0 1 0 1 +0 1 0 1 +1 2 1 2 +1 2 1 2 + + +-- !query +table lateral_test_t1 +|> join lateral (select lateral_test_t1.*, t2.* from lateral_test_t2 t2) +-- !query schema +struct +-- 
!query output +0 1 0 1 0 2 +0 1 0 1 0 3 +1 2 1 2 0 2 +1 2 1 2 0 3 + + +-- !query +table lateral_test_t1 +|> join lateral_test_t2 +|> join lateral (select lateral_test_t1.c2 + lateral_test_t2.c2) +-- !query schema +struct +-- !query output +0 1 0 2 3 +0 1 0 3 4 +1 2 0 2 4 +1 2 0 3 5 + + +-- !query +table natural_join_test_t1 +|> natural join natural_join_test_t2 +|> where k = "one" +-- !query schema +struct +-- !query output +one 1 1 +one 1 5 + + +-- !query +table natural_join_test_t1 +|> natural join natural_join_test_t2 nt2 +|> select natural_join_test_t1.* +-- !query schema +struct +-- !query output +one 1 +one 1 +two 2 + + +-- !query +table natural_join_test_t1 +|> natural join natural_join_test_t2 nt2 +|> natural join natural_join_test_t3 nt3 +|> select natural_join_test_t1.*, nt2.*, nt3.* +-- !query schema +struct +-- !query output +one 1 one 1 one 4 +one 1 one 1 one 6 +one 1 one 5 one 4 +one 1 one 5 one 6 +two 2 two 22 two 5 + + +-- !query +table join_test_t1 +|> inner join join_test_empty_table + inner join join_test_empty_table +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'inner'", + "hint" : "" + } +} + + +-- !query +table join_test_t1 +|> select 1 + 2 as result +|> full outer join join_test_empty_table on (join_test_t1.a = join_test_empty_table.a) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`join_test_t1`.`a`", + "proposal" : "`result`, `join_test_empty_table`.`a`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 91, + "stopIndex" : 104, + "fragment" : "join_test_t1.a" + } ] +} + + +-- !query +table join_test_t1 jt +|> cross join (select * from jt) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'jt'", + "hint" : "" + } +} + + +-- !query +table t +|> union all table t +-- !query schema +struct +-- !query output +0 abc +0 abc +1 def +1 def + + +-- !query +table t +|> union table t +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +(select * from t) +|> union all table t +-- !query schema +struct +-- !query output +0 abc +0 abc +1 def +1 def + + +-- !query +(select * from t) +|> union table t +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +values (0, 'abc') tab(x, y) +|> union all table t +-- !query schema +struct +-- !query output +0 abc +0 abc +1 def + + +-- !query +values (0, 1) tab(x, y) +|> union table t +-- !query schema +struct +-- !query output +0 1 +0 abc +1 def + + +-- !query +(select * from t) +|> union all (select * from t) +-- !query schema +struct +-- !query output +0 abc +0 abc +1 def +1 def + + +-- !query +table t +|> except all table t +-- !query schema +struct +-- !query output + + + +-- !query +table t +|> except table t +-- !query schema +struct +-- !query output + + + +-- !query +table t +|> intersect all table t +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> intersect table t +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> minus all table t +-- !query schema +struct +-- !query output + + + +-- !query +table t 
+|> minus table t +-- !query schema +struct +-- !query output + + + +-- !query +table t +|> select x +|> union all table t +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "NUM_COLUMNS_MISMATCH", + "sqlState" : "42826", + "messageParameters" : { + "firstNumColumns" : "1", + "invalidNumColumns" : "2", + "invalidOrdinalNum" : "second", + "operator" : "UNION" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 40, + "fragment" : "table t\n|> select x\n|> union all table t" + } ] +} + + +-- !query +table t +|> union all table st +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INCOMPATIBLE_COLUMN_TYPE", + "sqlState" : "42825", + "messageParameters" : { + "columnOrdinalNumber" : "second", + "dataType1" : "\"STRUCT\"", + "dataType2" : "\"STRING\"", + "hint" : "", + "operator" : "UNION", + "tableOrdinalNumber" : "second" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 29, + "fragment" : "table t\n|> union all table st" + } ] +} + + -- !query drop table t -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 14d7b31f8c63f..3f9f24f817f2c 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -778,7 +778,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "WINDOWS-1252", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`encode`", "parameter" : "`charset`" } @@ -796,7 +796,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "WINDOWS-1252", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`encode`", "parameter" : "`charset`" } @@ -814,7 +814,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "Windows-xxx", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`encode`", "parameter" : "`charset`" } @@ -832,7 +832,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "Windows-xxx", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`encode`", "parameter" : "`charset`" } @@ -1076,7 +1076,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "Windows-xxx", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`decode`", "parameter" : "`charset`" } @@ -1094,7 +1094,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "Windows-xxx", - 
"charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`decode`", "parameter" : "`charset`" } @@ -1144,7 +1144,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "WINDOWS-1252", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`decode`", "parameter" : "`charset`" } @@ -1162,7 +1162,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "WINDOWS-1252", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`decode`", "parameter" : "`charset`" } diff --git a/sql/core/src/test/resources/sql-tests/results/subexp-elimination.sql.out b/sql/core/src/test/resources/sql-tests/results/subexp-elimination.sql.out index 0f7ff3f107567..28457c0579e95 100644 --- a/sql/core/src/test/resources/sql-tests/results/subexp-elimination.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subexp-elimination.sql.out @@ -72,7 +72,7 @@ NULL -- !query SELECT from_json(a, 'struct').a + random() > 2, from_json(a, 'struct').b, from_json(b, 'array>')[0].a, from_json(b, 'array>')[0].b + + random() > 2 FROM testData -- !query schema -struct<((from_json(a).a + rand()) > 2):boolean,from_json(a).b:string,from_json(b)[0].a:int,((from_json(b)[0].b + (+ rand())) > 2):boolean> +struct<((from_json(a).a + random()) > 2):boolean,from_json(a).b:string,from_json(b)[0].a:int,((from_json(b)[0].b + (+ random())) > 2):boolean> -- !query output NULL NULL 1 true false 2 1 true @@ -84,7 +84,7 @@ true 6 6 true -- !query SELECT if(from_json(a, 'struct').a + random() > 5, from_json(b, 'array>')[0].a, from_json(b, 'array>')[0].a + 1) FROM testData -- !query schema -struct<(IF(((from_json(a).a + rand()) > 5), from_json(b)[0].a, (from_json(b)[0].a + 1))):int> +struct<(IF(((from_json(a).a + random()) > 5), from_json(b)[0].a, (from_json(b)[0].a + 1))):int> -- !query output 2 2 @@ -96,7 +96,7 @@ NULL -- !query SELECT case when from_json(a, 'struct').a > 5 then from_json(a, 'struct').b + random() > 5 when from_json(a, 'struct').a > 4 then from_json(a, 'struct').b + 1 + random() > 2 else from_json(a, 'struct').b + 2 + random() > 5 end FROM testData -- !query schema -struct 5) THEN ((from_json(a).b + rand()) > 5) WHEN (from_json(a).a > 4) THEN (((from_json(a).b + 1) + rand()) > 2) ELSE (((from_json(a).b + 2) + rand()) > 5) END:boolean> +struct 5) THEN ((from_json(a).b + random()) > 5) WHEN (from_json(a).a > 4) THEN (((from_json(a).b + 1) + random()) > 2) ELSE (((from_json(a).b + 2) + random()) > 5) END:boolean> -- !query output NULL false diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out index 41cba1f43745f..56932edd4e545 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out @@ -112,6 +112,14 @@ struct 2 2 NULL +-- !query +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false +-- !query schema +struct +-- !query 
output +spark.sql.optimizer.scalarSubqueryUseSingleJoin false + + -- !query select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 -- !query schema @@ -178,25 +186,56 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = true +-- !query schema +struct +-- !query output +spark.sql.optimizer.scalarSubqueryUseSingleJoin true + + +-- !query +select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkRuntimeException +{ + "errorClass" : "SCALAR_SUBQUERY_TOO_MANY_ROWS", + "sqlState" : "21000" +} + + +-- !query +select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkRuntimeException +{ + "errorClass" : "SCALAR_SUBQUERY_TOO_MANY_ROWS", + "sqlState" : "21000" +} + + +-- !query +select *, (select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2) from x +-- !query schema +struct +-- !query output +1 1 NULL +2 2 NULL + + -- !query select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x -- !query schema struct<> -- !query output -org.apache.spark.sql.catalyst.ExtendedAnalysisException +org.apache.spark.SparkRuntimeException { - "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", - "sqlState" : "0A000", - "messageParameters" : { - "value" : "y1" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 11, - "stopIndex" : 106, - "fragment" : "(select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1)" - } ] + "errorClass" : "SCALAR_SUBQUERY_TOO_MANY_ROWS", + "sqlState" : "21000" } @@ -207,17 +246,17 @@ struct<> -- !query output org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", + "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED", "sqlState" : "0A000", "messageParameters" : { - "value" : "z1" + "treeNode" : "Filter (z1#x = outer(x1#x))\n+- SubqueryAlias z\n +- View (`z`, [z1#x, z2#x])\n +- Project [cast(col1#x as int) AS z1#x, cast(col2#x as int) AS z2#x]\n +- LocalRelation [col1#x, col2#x]\n" }, "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 11, - "stopIndex" : 103, - "fragment" : "(select count(*) from y left join (select * from z where z1 = x1) sub on y2 = z2 group by z1)" + "startIndex" : 46, + "stopIndex" : 74, + "fragment" : "select * from z where z1 = x1" } ] } @@ -230,6 +269,14 @@ struct spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate true +-- !query +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false +-- !query schema +struct +-- !query output +spark.sql.optimizer.scalarSubqueryUseSingleJoin false + + -- !query select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out index a02f0c70be6da..2460c2452ea56 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out @@ -906,3 +906,11 @@ WHERE (SELECT max(t2c) struct -- !query output + + +-- !query +SELECT * FROM t0 WHERE t0a = (SELECT distinct(t1c) FROM t1 WHERE t1a = t0a) +-- !query schema +struct +-- !query output + diff --git a/sql/core/src/test/resources/sql-tests/results/timezone.sql.out b/sql/core/src/test/resources/sql-tests/results/timezone.sql.out index d34599a49c5ff..5f0fdef50e3db 100644 --- a/sql/core/src/test/resources/sql-tests/results/timezone.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/timezone.sql.out @@ -80,7 +80,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0044", + "errorClass" : "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + "sqlState" : "22006", + "messageParameters" : { + "input" : "3" + }, "queryContext" : [ { "objectType" : "", "objectName" : "", @@ -98,7 +102,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0044", + "errorClass" : "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + "sqlState" : "22006", + "messageParameters" : { + "input" : "24" + }, "queryContext" : [ { "objectType" : "", "objectName" : "", @@ -116,7 +124,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0044", + "errorClass" : "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + "sqlState" : "22006", + "messageParameters" : { + "input" : "19" + }, "queryContext" : [ { "objectType" : "", "objectName" : "", @@ -152,7 +164,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0044", + "errorClass" : "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + "sqlState" : "22006", + "messageParameters" : { + "input" : "36000" + }, "queryContext" : [ { "objectType" : "", "objectName" : "", diff --git a/sql/core/src/test/resources/test-data/more-columns.csv b/sql/core/src/test/resources/test-data/more-columns.csv new file mode 100644 index 0000000000000..06db38f0a145a --- /dev/null +++ b/sql/core/src/test/resources/test-data/more-columns.csv @@ -0,0 +1 @@ +1,3.14,string,5,7 \ No newline at end of file diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/explain.txt index 96bed479d2e06..4bf7de791b279 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/explain.txt @@ -175,125 +175,125 @@ Input [6]: [i_product_name#12, i_brand#9, i_class#10, i_category#11, sum#21, cou Keys [4]: [i_product_name#12, i_brand#9, i_class#10, i_category#11] Functions [1]: [avg(qoh#18)] Aggregate Attributes [1]: [avg(qoh#18)#23] -Results [5]: [i_product_name#12, i_brand#9, i_class#10, i_category#11, avg(qoh#18)#23 AS qoh#24] +Results [5]: [i_product_name#12 AS i_product_name#24, i_brand#9 AS i_brand#25, i_class#10 AS i_class#26, i_category#11 AS i_category#27, avg(qoh#18)#23 AS qoh#28] (27) ReusedExchange [Reuses operator id: 23] -Output [6]: [i_product_name#25, i_brand#26, i_class#27, i_category#28, sum#29, count#30] +Output [6]: [i_product_name#29, i_brand#30, i_class#31, i_category#32, sum#33, count#34] (28) HashAggregate [codegen id 
: 16] -Input [6]: [i_product_name#25, i_brand#26, i_class#27, i_category#28, sum#29, count#30] -Keys [4]: [i_product_name#25, i_brand#26, i_class#27, i_category#28] -Functions [1]: [avg(inv_quantity_on_hand#31)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#31)#17] -Results [4]: [i_product_name#25, i_brand#26, i_class#27, avg(inv_quantity_on_hand#31)#17 AS qoh#32] +Input [6]: [i_product_name#29, i_brand#30, i_class#31, i_category#32, sum#33, count#34] +Keys [4]: [i_product_name#29, i_brand#30, i_class#31, i_category#32] +Functions [1]: [avg(inv_quantity_on_hand#35)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#35)#17] +Results [4]: [i_product_name#29, i_brand#30, i_class#31, avg(inv_quantity_on_hand#35)#17 AS qoh#36] (29) HashAggregate [codegen id : 16] -Input [4]: [i_product_name#25, i_brand#26, i_class#27, qoh#32] -Keys [3]: [i_product_name#25, i_brand#26, i_class#27] -Functions [1]: [partial_avg(qoh#32)] -Aggregate Attributes [2]: [sum#33, count#34] -Results [5]: [i_product_name#25, i_brand#26, i_class#27, sum#35, count#36] +Input [4]: [i_product_name#29, i_brand#30, i_class#31, qoh#36] +Keys [3]: [i_product_name#29, i_brand#30, i_class#31] +Functions [1]: [partial_avg(qoh#36)] +Aggregate Attributes [2]: [sum#37, count#38] +Results [5]: [i_product_name#29, i_brand#30, i_class#31, sum#39, count#40] (30) Exchange -Input [5]: [i_product_name#25, i_brand#26, i_class#27, sum#35, count#36] -Arguments: hashpartitioning(i_product_name#25, i_brand#26, i_class#27, 5), ENSURE_REQUIREMENTS, [plan_id=5] +Input [5]: [i_product_name#29, i_brand#30, i_class#31, sum#39, count#40] +Arguments: hashpartitioning(i_product_name#29, i_brand#30, i_class#31, 5), ENSURE_REQUIREMENTS, [plan_id=5] (31) HashAggregate [codegen id : 17] -Input [5]: [i_product_name#25, i_brand#26, i_class#27, sum#35, count#36] -Keys [3]: [i_product_name#25, i_brand#26, i_class#27] -Functions [1]: [avg(qoh#32)] -Aggregate Attributes [1]: [avg(qoh#32)#37] -Results [5]: [i_product_name#25, i_brand#26, i_class#27, null AS i_category#38, avg(qoh#32)#37 AS qoh#39] +Input [5]: [i_product_name#29, i_brand#30, i_class#31, sum#39, count#40] +Keys [3]: [i_product_name#29, i_brand#30, i_class#31] +Functions [1]: [avg(qoh#36)] +Aggregate Attributes [1]: [avg(qoh#36)#41] +Results [5]: [i_product_name#29, i_brand#30, i_class#31, null AS i_category#42, avg(qoh#36)#41 AS qoh#43] (32) ReusedExchange [Reuses operator id: 23] -Output [6]: [i_product_name#40, i_brand#41, i_class#42, i_category#43, sum#44, count#45] +Output [6]: [i_product_name#44, i_brand#45, i_class#46, i_category#47, sum#48, count#49] (33) HashAggregate [codegen id : 25] -Input [6]: [i_product_name#40, i_brand#41, i_class#42, i_category#43, sum#44, count#45] -Keys [4]: [i_product_name#40, i_brand#41, i_class#42, i_category#43] -Functions [1]: [avg(inv_quantity_on_hand#46)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#46)#17] -Results [3]: [i_product_name#40, i_brand#41, avg(inv_quantity_on_hand#46)#17 AS qoh#47] +Input [6]: [i_product_name#44, i_brand#45, i_class#46, i_category#47, sum#48, count#49] +Keys [4]: [i_product_name#44, i_brand#45, i_class#46, i_category#47] +Functions [1]: [avg(inv_quantity_on_hand#50)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#50)#17] +Results [3]: [i_product_name#44, i_brand#45, avg(inv_quantity_on_hand#50)#17 AS qoh#51] (34) HashAggregate [codegen id : 25] -Input [3]: [i_product_name#40, i_brand#41, qoh#47] -Keys [2]: [i_product_name#40, i_brand#41] -Functions [1]: [partial_avg(qoh#47)] -Aggregate Attributes [2]: [sum#48, 
count#49] -Results [4]: [i_product_name#40, i_brand#41, sum#50, count#51] +Input [3]: [i_product_name#44, i_brand#45, qoh#51] +Keys [2]: [i_product_name#44, i_brand#45] +Functions [1]: [partial_avg(qoh#51)] +Aggregate Attributes [2]: [sum#52, count#53] +Results [4]: [i_product_name#44, i_brand#45, sum#54, count#55] (35) Exchange -Input [4]: [i_product_name#40, i_brand#41, sum#50, count#51] -Arguments: hashpartitioning(i_product_name#40, i_brand#41, 5), ENSURE_REQUIREMENTS, [plan_id=6] +Input [4]: [i_product_name#44, i_brand#45, sum#54, count#55] +Arguments: hashpartitioning(i_product_name#44, i_brand#45, 5), ENSURE_REQUIREMENTS, [plan_id=6] (36) HashAggregate [codegen id : 26] -Input [4]: [i_product_name#40, i_brand#41, sum#50, count#51] -Keys [2]: [i_product_name#40, i_brand#41] -Functions [1]: [avg(qoh#47)] -Aggregate Attributes [1]: [avg(qoh#47)#52] -Results [5]: [i_product_name#40, i_brand#41, null AS i_class#53, null AS i_category#54, avg(qoh#47)#52 AS qoh#55] +Input [4]: [i_product_name#44, i_brand#45, sum#54, count#55] +Keys [2]: [i_product_name#44, i_brand#45] +Functions [1]: [avg(qoh#51)] +Aggregate Attributes [1]: [avg(qoh#51)#56] +Results [5]: [i_product_name#44, i_brand#45, null AS i_class#57, null AS i_category#58, avg(qoh#51)#56 AS qoh#59] (37) ReusedExchange [Reuses operator id: 23] -Output [6]: [i_product_name#56, i_brand#57, i_class#58, i_category#59, sum#60, count#61] +Output [6]: [i_product_name#60, i_brand#61, i_class#62, i_category#63, sum#64, count#65] (38) HashAggregate [codegen id : 34] -Input [6]: [i_product_name#56, i_brand#57, i_class#58, i_category#59, sum#60, count#61] -Keys [4]: [i_product_name#56, i_brand#57, i_class#58, i_category#59] -Functions [1]: [avg(inv_quantity_on_hand#62)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#62)#17] -Results [2]: [i_product_name#56, avg(inv_quantity_on_hand#62)#17 AS qoh#63] +Input [6]: [i_product_name#60, i_brand#61, i_class#62, i_category#63, sum#64, count#65] +Keys [4]: [i_product_name#60, i_brand#61, i_class#62, i_category#63] +Functions [1]: [avg(inv_quantity_on_hand#66)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#66)#17] +Results [2]: [i_product_name#60, avg(inv_quantity_on_hand#66)#17 AS qoh#67] (39) HashAggregate [codegen id : 34] -Input [2]: [i_product_name#56, qoh#63] -Keys [1]: [i_product_name#56] -Functions [1]: [partial_avg(qoh#63)] -Aggregate Attributes [2]: [sum#64, count#65] -Results [3]: [i_product_name#56, sum#66, count#67] +Input [2]: [i_product_name#60, qoh#67] +Keys [1]: [i_product_name#60] +Functions [1]: [partial_avg(qoh#67)] +Aggregate Attributes [2]: [sum#68, count#69] +Results [3]: [i_product_name#60, sum#70, count#71] (40) Exchange -Input [3]: [i_product_name#56, sum#66, count#67] -Arguments: hashpartitioning(i_product_name#56, 5), ENSURE_REQUIREMENTS, [plan_id=7] +Input [3]: [i_product_name#60, sum#70, count#71] +Arguments: hashpartitioning(i_product_name#60, 5), ENSURE_REQUIREMENTS, [plan_id=7] (41) HashAggregate [codegen id : 35] -Input [3]: [i_product_name#56, sum#66, count#67] -Keys [1]: [i_product_name#56] -Functions [1]: [avg(qoh#63)] -Aggregate Attributes [1]: [avg(qoh#63)#68] -Results [5]: [i_product_name#56, null AS i_brand#69, null AS i_class#70, null AS i_category#71, avg(qoh#63)#68 AS qoh#72] +Input [3]: [i_product_name#60, sum#70, count#71] +Keys [1]: [i_product_name#60] +Functions [1]: [avg(qoh#67)] +Aggregate Attributes [1]: [avg(qoh#67)#72] +Results [5]: [i_product_name#60, null AS i_brand#73, null AS i_class#74, null AS i_category#75, avg(qoh#67)#72 AS qoh#76] 
(42) ReusedExchange [Reuses operator id: 23] -Output [6]: [i_product_name#73, i_brand#74, i_class#75, i_category#76, sum#77, count#78] +Output [6]: [i_product_name#77, i_brand#78, i_class#79, i_category#80, sum#81, count#82] (43) HashAggregate [codegen id : 43] -Input [6]: [i_product_name#73, i_brand#74, i_class#75, i_category#76, sum#77, count#78] -Keys [4]: [i_product_name#73, i_brand#74, i_class#75, i_category#76] -Functions [1]: [avg(inv_quantity_on_hand#79)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#79)#17] -Results [1]: [avg(inv_quantity_on_hand#79)#17 AS qoh#80] +Input [6]: [i_product_name#77, i_brand#78, i_class#79, i_category#80, sum#81, count#82] +Keys [4]: [i_product_name#77, i_brand#78, i_class#79, i_category#80] +Functions [1]: [avg(inv_quantity_on_hand#83)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#83)#17] +Results [1]: [avg(inv_quantity_on_hand#83)#17 AS qoh#84] (44) HashAggregate [codegen id : 43] -Input [1]: [qoh#80] +Input [1]: [qoh#84] Keys: [] -Functions [1]: [partial_avg(qoh#80)] -Aggregate Attributes [2]: [sum#81, count#82] -Results [2]: [sum#83, count#84] +Functions [1]: [partial_avg(qoh#84)] +Aggregate Attributes [2]: [sum#85, count#86] +Results [2]: [sum#87, count#88] (45) Exchange -Input [2]: [sum#83, count#84] +Input [2]: [sum#87, count#88] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [plan_id=8] (46) HashAggregate [codegen id : 44] -Input [2]: [sum#83, count#84] +Input [2]: [sum#87, count#88] Keys: [] -Functions [1]: [avg(qoh#80)] -Aggregate Attributes [1]: [avg(qoh#80)#85] -Results [5]: [null AS i_product_name#86, null AS i_brand#87, null AS i_class#88, null AS i_category#89, avg(qoh#80)#85 AS qoh#90] +Functions [1]: [avg(qoh#84)] +Aggregate Attributes [1]: [avg(qoh#84)#89] +Results [5]: [null AS i_product_name#90, null AS i_brand#91, null AS i_class#92, null AS i_category#93, avg(qoh#84)#89 AS qoh#94] (47) Union (48) TakeOrderedAndProject -Input [5]: [i_product_name#12, i_brand#9, i_class#10, i_category#11, qoh#24] -Arguments: 100, [qoh#24 ASC NULLS FIRST, i_product_name#12 ASC NULLS FIRST, i_brand#9 ASC NULLS FIRST, i_class#10 ASC NULLS FIRST, i_category#11 ASC NULLS FIRST], [i_product_name#12, i_brand#9, i_class#10, i_category#11, qoh#24] +Input [5]: [i_product_name#24, i_brand#25, i_class#26, i_category#27, qoh#28] +Arguments: 100, [qoh#28 ASC NULLS FIRST, i_product_name#24 ASC NULLS FIRST, i_brand#25 ASC NULLS FIRST, i_class#26 ASC NULLS FIRST, i_category#27 ASC NULLS FIRST], [i_product_name#24, i_brand#25, i_class#26, i_category#27, qoh#28] ===== Subqueries ===== @@ -306,22 +306,22 @@ BroadcastExchange (53) (49) Scan parquet spark_catalog.default.date_dim -Output [2]: [d_date_sk#7, d_month_seq#91] +Output [2]: [d_date_sk#7, d_month_seq#95] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,1212), LessThanOrEqual(d_month_seq,1223), IsNotNull(d_date_sk)] ReadSchema: struct (50) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#7, d_month_seq#91] +Input [2]: [d_date_sk#7, d_month_seq#95] (51) Filter [codegen id : 1] -Input [2]: [d_date_sk#7, d_month_seq#91] -Condition : (((isnotnull(d_month_seq#91) AND (d_month_seq#91 >= 1212)) AND (d_month_seq#91 <= 1223)) AND isnotnull(d_date_sk#7)) +Input [2]: [d_date_sk#7, d_month_seq#95] +Condition : (((isnotnull(d_month_seq#95) AND (d_month_seq#95 >= 1212)) AND (d_month_seq#95 <= 1223)) AND isnotnull(d_date_sk#7)) (52) Project [codegen id : 1] Output [1]: [d_date_sk#7] -Input [2]: 
[d_date_sk#7, d_month_seq#91] +Input [2]: [d_date_sk#7, d_month_seq#95] (53) BroadcastExchange Input [1]: [d_date_sk#7] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/simplified.txt index 0c4267b3ca513..042f946b8fca4 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/simplified.txt @@ -1,7 +1,7 @@ TakeOrderedAndProject [qoh,i_product_name,i_brand,i_class,i_category] Union WholeStageCodegen (8) - HashAggregate [i_product_name,i_brand,i_class,i_category,sum,count] [avg(qoh),qoh,sum,count] + HashAggregate [i_product_name,i_brand,i_class,i_category,sum,count] [avg(qoh),i_product_name,i_brand,i_class,i_category,qoh,sum,count] HashAggregate [i_product_name,i_brand,i_class,i_category,qoh] [sum,count,sum,count] HashAggregate [i_product_name,i_brand,i_class,i_category,sum,count] [avg(inv_quantity_on_hand),qoh,sum,count] InputAdapter diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/explain.txt index 4b8993f370f4d..8aab8e91acfc8 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/explain.txt @@ -160,125 +160,125 @@ Input [6]: [i_product_name#11, i_brand#8, i_class#9, i_category#10, sum#21, coun Keys [4]: [i_product_name#11, i_brand#8, i_class#9, i_category#10] Functions [1]: [avg(qoh#18)] Aggregate Attributes [1]: [avg(qoh#18)#23] -Results [5]: [i_product_name#11, i_brand#8, i_class#9, i_category#10, avg(qoh#18)#23 AS qoh#24] +Results [5]: [i_product_name#11 AS i_product_name#24, i_brand#8 AS i_brand#25, i_class#9 AS i_class#26, i_category#10 AS i_category#27, avg(qoh#18)#23 AS qoh#28] (24) ReusedExchange [Reuses operator id: 20] -Output [6]: [i_product_name#25, i_brand#26, i_class#27, i_category#28, sum#29, count#30] +Output [6]: [i_product_name#29, i_brand#30, i_class#31, i_category#32, sum#33, count#34] (25) HashAggregate [codegen id : 10] -Input [6]: [i_product_name#25, i_brand#26, i_class#27, i_category#28, sum#29, count#30] -Keys [4]: [i_product_name#25, i_brand#26, i_class#27, i_category#28] -Functions [1]: [avg(inv_quantity_on_hand#31)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#31)#17] -Results [4]: [i_product_name#25, i_brand#26, i_class#27, avg(inv_quantity_on_hand#31)#17 AS qoh#32] +Input [6]: [i_product_name#29, i_brand#30, i_class#31, i_category#32, sum#33, count#34] +Keys [4]: [i_product_name#29, i_brand#30, i_class#31, i_category#32] +Functions [1]: [avg(inv_quantity_on_hand#35)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#35)#17] +Results [4]: [i_product_name#29, i_brand#30, i_class#31, avg(inv_quantity_on_hand#35)#17 AS qoh#36] (26) HashAggregate [codegen id : 10] -Input [4]: [i_product_name#25, i_brand#26, i_class#27, qoh#32] -Keys [3]: [i_product_name#25, i_brand#26, i_class#27] -Functions [1]: [partial_avg(qoh#32)] -Aggregate Attributes [2]: [sum#33, count#34] -Results [5]: [i_product_name#25, i_brand#26, i_class#27, sum#35, count#36] +Input [4]: [i_product_name#29, i_brand#30, i_class#31, qoh#36] +Keys [3]: [i_product_name#29, i_brand#30, i_class#31] +Functions [1]: [partial_avg(qoh#36)] +Aggregate Attributes [2]: [sum#37, 
count#38] +Results [5]: [i_product_name#29, i_brand#30, i_class#31, sum#39, count#40] (27) Exchange -Input [5]: [i_product_name#25, i_brand#26, i_class#27, sum#35, count#36] -Arguments: hashpartitioning(i_product_name#25, i_brand#26, i_class#27, 5), ENSURE_REQUIREMENTS, [plan_id=4] +Input [5]: [i_product_name#29, i_brand#30, i_class#31, sum#39, count#40] +Arguments: hashpartitioning(i_product_name#29, i_brand#30, i_class#31, 5), ENSURE_REQUIREMENTS, [plan_id=4] (28) HashAggregate [codegen id : 11] -Input [5]: [i_product_name#25, i_brand#26, i_class#27, sum#35, count#36] -Keys [3]: [i_product_name#25, i_brand#26, i_class#27] -Functions [1]: [avg(qoh#32)] -Aggregate Attributes [1]: [avg(qoh#32)#37] -Results [5]: [i_product_name#25, i_brand#26, i_class#27, null AS i_category#38, avg(qoh#32)#37 AS qoh#39] +Input [5]: [i_product_name#29, i_brand#30, i_class#31, sum#39, count#40] +Keys [3]: [i_product_name#29, i_brand#30, i_class#31] +Functions [1]: [avg(qoh#36)] +Aggregate Attributes [1]: [avg(qoh#36)#41] +Results [5]: [i_product_name#29, i_brand#30, i_class#31, null AS i_category#42, avg(qoh#36)#41 AS qoh#43] (29) ReusedExchange [Reuses operator id: 20] -Output [6]: [i_product_name#40, i_brand#41, i_class#42, i_category#43, sum#44, count#45] +Output [6]: [i_product_name#44, i_brand#45, i_class#46, i_category#47, sum#48, count#49] (30) HashAggregate [codegen id : 16] -Input [6]: [i_product_name#40, i_brand#41, i_class#42, i_category#43, sum#44, count#45] -Keys [4]: [i_product_name#40, i_brand#41, i_class#42, i_category#43] -Functions [1]: [avg(inv_quantity_on_hand#46)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#46)#17] -Results [3]: [i_product_name#40, i_brand#41, avg(inv_quantity_on_hand#46)#17 AS qoh#47] +Input [6]: [i_product_name#44, i_brand#45, i_class#46, i_category#47, sum#48, count#49] +Keys [4]: [i_product_name#44, i_brand#45, i_class#46, i_category#47] +Functions [1]: [avg(inv_quantity_on_hand#50)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#50)#17] +Results [3]: [i_product_name#44, i_brand#45, avg(inv_quantity_on_hand#50)#17 AS qoh#51] (31) HashAggregate [codegen id : 16] -Input [3]: [i_product_name#40, i_brand#41, qoh#47] -Keys [2]: [i_product_name#40, i_brand#41] -Functions [1]: [partial_avg(qoh#47)] -Aggregate Attributes [2]: [sum#48, count#49] -Results [4]: [i_product_name#40, i_brand#41, sum#50, count#51] +Input [3]: [i_product_name#44, i_brand#45, qoh#51] +Keys [2]: [i_product_name#44, i_brand#45] +Functions [1]: [partial_avg(qoh#51)] +Aggregate Attributes [2]: [sum#52, count#53] +Results [4]: [i_product_name#44, i_brand#45, sum#54, count#55] (32) Exchange -Input [4]: [i_product_name#40, i_brand#41, sum#50, count#51] -Arguments: hashpartitioning(i_product_name#40, i_brand#41, 5), ENSURE_REQUIREMENTS, [plan_id=5] +Input [4]: [i_product_name#44, i_brand#45, sum#54, count#55] +Arguments: hashpartitioning(i_product_name#44, i_brand#45, 5), ENSURE_REQUIREMENTS, [plan_id=5] (33) HashAggregate [codegen id : 17] -Input [4]: [i_product_name#40, i_brand#41, sum#50, count#51] -Keys [2]: [i_product_name#40, i_brand#41] -Functions [1]: [avg(qoh#47)] -Aggregate Attributes [1]: [avg(qoh#47)#52] -Results [5]: [i_product_name#40, i_brand#41, null AS i_class#53, null AS i_category#54, avg(qoh#47)#52 AS qoh#55] +Input [4]: [i_product_name#44, i_brand#45, sum#54, count#55] +Keys [2]: [i_product_name#44, i_brand#45] +Functions [1]: [avg(qoh#51)] +Aggregate Attributes [1]: [avg(qoh#51)#56] +Results [5]: [i_product_name#44, i_brand#45, null AS i_class#57, null AS i_category#58, 
avg(qoh#51)#56 AS qoh#59] (34) ReusedExchange [Reuses operator id: 20] -Output [6]: [i_product_name#56, i_brand#57, i_class#58, i_category#59, sum#60, count#61] +Output [6]: [i_product_name#60, i_brand#61, i_class#62, i_category#63, sum#64, count#65] (35) HashAggregate [codegen id : 22] -Input [6]: [i_product_name#56, i_brand#57, i_class#58, i_category#59, sum#60, count#61] -Keys [4]: [i_product_name#56, i_brand#57, i_class#58, i_category#59] -Functions [1]: [avg(inv_quantity_on_hand#62)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#62)#17] -Results [2]: [i_product_name#56, avg(inv_quantity_on_hand#62)#17 AS qoh#63] +Input [6]: [i_product_name#60, i_brand#61, i_class#62, i_category#63, sum#64, count#65] +Keys [4]: [i_product_name#60, i_brand#61, i_class#62, i_category#63] +Functions [1]: [avg(inv_quantity_on_hand#66)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#66)#17] +Results [2]: [i_product_name#60, avg(inv_quantity_on_hand#66)#17 AS qoh#67] (36) HashAggregate [codegen id : 22] -Input [2]: [i_product_name#56, qoh#63] -Keys [1]: [i_product_name#56] -Functions [1]: [partial_avg(qoh#63)] -Aggregate Attributes [2]: [sum#64, count#65] -Results [3]: [i_product_name#56, sum#66, count#67] +Input [2]: [i_product_name#60, qoh#67] +Keys [1]: [i_product_name#60] +Functions [1]: [partial_avg(qoh#67)] +Aggregate Attributes [2]: [sum#68, count#69] +Results [3]: [i_product_name#60, sum#70, count#71] (37) Exchange -Input [3]: [i_product_name#56, sum#66, count#67] -Arguments: hashpartitioning(i_product_name#56, 5), ENSURE_REQUIREMENTS, [plan_id=6] +Input [3]: [i_product_name#60, sum#70, count#71] +Arguments: hashpartitioning(i_product_name#60, 5), ENSURE_REQUIREMENTS, [plan_id=6] (38) HashAggregate [codegen id : 23] -Input [3]: [i_product_name#56, sum#66, count#67] -Keys [1]: [i_product_name#56] -Functions [1]: [avg(qoh#63)] -Aggregate Attributes [1]: [avg(qoh#63)#68] -Results [5]: [i_product_name#56, null AS i_brand#69, null AS i_class#70, null AS i_category#71, avg(qoh#63)#68 AS qoh#72] +Input [3]: [i_product_name#60, sum#70, count#71] +Keys [1]: [i_product_name#60] +Functions [1]: [avg(qoh#67)] +Aggregate Attributes [1]: [avg(qoh#67)#72] +Results [5]: [i_product_name#60, null AS i_brand#73, null AS i_class#74, null AS i_category#75, avg(qoh#67)#72 AS qoh#76] (39) ReusedExchange [Reuses operator id: 20] -Output [6]: [i_product_name#73, i_brand#74, i_class#75, i_category#76, sum#77, count#78] +Output [6]: [i_product_name#77, i_brand#78, i_class#79, i_category#80, sum#81, count#82] (40) HashAggregate [codegen id : 28] -Input [6]: [i_product_name#73, i_brand#74, i_class#75, i_category#76, sum#77, count#78] -Keys [4]: [i_product_name#73, i_brand#74, i_class#75, i_category#76] -Functions [1]: [avg(inv_quantity_on_hand#79)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#79)#17] -Results [1]: [avg(inv_quantity_on_hand#79)#17 AS qoh#80] +Input [6]: [i_product_name#77, i_brand#78, i_class#79, i_category#80, sum#81, count#82] +Keys [4]: [i_product_name#77, i_brand#78, i_class#79, i_category#80] +Functions [1]: [avg(inv_quantity_on_hand#83)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#83)#17] +Results [1]: [avg(inv_quantity_on_hand#83)#17 AS qoh#84] (41) HashAggregate [codegen id : 28] -Input [1]: [qoh#80] +Input [1]: [qoh#84] Keys: [] -Functions [1]: [partial_avg(qoh#80)] -Aggregate Attributes [2]: [sum#81, count#82] -Results [2]: [sum#83, count#84] +Functions [1]: [partial_avg(qoh#84)] +Aggregate Attributes [2]: [sum#85, count#86] +Results [2]: [sum#87, count#88] (42) Exchange 
-Input [2]: [sum#83, count#84] +Input [2]: [sum#87, count#88] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [plan_id=7] (43) HashAggregate [codegen id : 29] -Input [2]: [sum#83, count#84] +Input [2]: [sum#87, count#88] Keys: [] -Functions [1]: [avg(qoh#80)] -Aggregate Attributes [1]: [avg(qoh#80)#85] -Results [5]: [null AS i_product_name#86, null AS i_brand#87, null AS i_class#88, null AS i_category#89, avg(qoh#80)#85 AS qoh#90] +Functions [1]: [avg(qoh#84)] +Aggregate Attributes [1]: [avg(qoh#84)#89] +Results [5]: [null AS i_product_name#90, null AS i_brand#91, null AS i_class#92, null AS i_category#93, avg(qoh#84)#89 AS qoh#94] (44) Union (45) TakeOrderedAndProject -Input [5]: [i_product_name#11, i_brand#8, i_class#9, i_category#10, qoh#24] -Arguments: 100, [qoh#24 ASC NULLS FIRST, i_product_name#11 ASC NULLS FIRST, i_brand#8 ASC NULLS FIRST, i_class#9 ASC NULLS FIRST, i_category#10 ASC NULLS FIRST], [i_product_name#11, i_brand#8, i_class#9, i_category#10, qoh#24] +Input [5]: [i_product_name#24, i_brand#25, i_class#26, i_category#27, qoh#28] +Arguments: 100, [qoh#28 ASC NULLS FIRST, i_product_name#24 ASC NULLS FIRST, i_brand#25 ASC NULLS FIRST, i_class#26 ASC NULLS FIRST, i_category#27 ASC NULLS FIRST], [i_product_name#24, i_brand#25, i_class#26, i_category#27, qoh#28] ===== Subqueries ===== @@ -291,22 +291,22 @@ BroadcastExchange (50) (46) Scan parquet spark_catalog.default.date_dim -Output [2]: [d_date_sk#6, d_month_seq#91] +Output [2]: [d_date_sk#6, d_month_seq#95] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,1212), LessThanOrEqual(d_month_seq,1223), IsNotNull(d_date_sk)] ReadSchema: struct (47) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#6, d_month_seq#91] +Input [2]: [d_date_sk#6, d_month_seq#95] (48) Filter [codegen id : 1] -Input [2]: [d_date_sk#6, d_month_seq#91] -Condition : (((isnotnull(d_month_seq#91) AND (d_month_seq#91 >= 1212)) AND (d_month_seq#91 <= 1223)) AND isnotnull(d_date_sk#6)) +Input [2]: [d_date_sk#6, d_month_seq#95] +Condition : (((isnotnull(d_month_seq#95) AND (d_month_seq#95 >= 1212)) AND (d_month_seq#95 <= 1223)) AND isnotnull(d_date_sk#6)) (49) Project [codegen id : 1] Output [1]: [d_date_sk#6] -Input [2]: [d_date_sk#6, d_month_seq#91] +Input [2]: [d_date_sk#6, d_month_seq#95] (50) BroadcastExchange Input [1]: [d_date_sk#6] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/simplified.txt index 22f73cc9b9db5..d747066f5945b 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/simplified.txt @@ -1,7 +1,7 @@ TakeOrderedAndProject [qoh,i_product_name,i_brand,i_class,i_category] Union WholeStageCodegen (5) - HashAggregate [i_product_name,i_brand,i_class,i_category,sum,count] [avg(qoh),qoh,sum,count] + HashAggregate [i_product_name,i_brand,i_class,i_category,sum,count] [avg(qoh),i_product_name,i_brand,i_class,i_category,qoh,sum,count] HashAggregate [i_product_name,i_brand,i_class,i_category,qoh] [sum,count,sum,count] HashAggregate [i_product_name,i_brand,i_class,i_category,sum,count] [avg(inv_quantity_on_hand),qoh,sum,count] InputAdapter diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/explain.txt 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/explain.txt index 9c28ff9f351d8..a4c009f8219b4 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/explain.txt @@ -186,265 +186,265 @@ Input [10]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, Keys [8]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] Functions [1]: [sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))] Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))#22] -Results [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, cast(sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))#22 as decimal(38,2)) AS sumsales#23] +Results [9]: [i_category#16 AS i_category#23, i_class#15 AS i_class#24, i_brand#14 AS i_brand#25, i_product_name#17 AS i_product_name#26, d_year#8 AS d_year#27, d_qoy#10 AS d_qoy#28, d_moy#9 AS d_moy#29, s_store_id#12 AS s_store_id#30, cast(sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))#22 as decimal(38,2)) AS sumsales#31] (25) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, s_store_id#31, sum#32, isEmpty#33] +Output [10]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, s_store_id#39, sum#40, isEmpty#41] (26) HashAggregate [codegen id : 16] -Input [10]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, s_store_id#31, sum#32, isEmpty#33] -Keys [8]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, s_store_id#31] -Functions [1]: [sum(coalesce((ss_sales_price#34 * cast(ss_quantity#35 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#34 * cast(ss_quantity#35 as decimal(10,0))), 0.00))#22] -Results [8]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum(coalesce((ss_sales_price#34 * cast(ss_quantity#35 as decimal(10,0))), 0.00))#22 AS sumsales#36] +Input [10]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, s_store_id#39, sum#40, isEmpty#41] +Keys [8]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, s_store_id#39] +Functions [1]: [sum(coalesce((ss_sales_price#42 * cast(ss_quantity#43 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#42 * cast(ss_quantity#43 as decimal(10,0))), 0.00))#22] +Results [8]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum(coalesce((ss_sales_price#42 * cast(ss_quantity#43 as decimal(10,0))), 0.00))#22 AS sumsales#44] (27) HashAggregate [codegen id : 16] -Input [8]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sumsales#36] -Keys [7]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30] -Functions [1]: [partial_sum(sumsales#36)] -Aggregate Attributes [2]: [sum#37, isEmpty#38] -Results [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum#39, isEmpty#40] +Input [8]: [i_category#32, 
i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sumsales#44] +Keys [7]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38] +Functions [1]: [partial_sum(sumsales#44)] +Aggregate Attributes [2]: [sum#45, isEmpty#46] +Results [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum#47, isEmpty#48] (28) Exchange -Input [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum#39, isEmpty#40] -Arguments: hashpartitioning(i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, 5), ENSURE_REQUIREMENTS, [plan_id=5] +Input [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum#47, isEmpty#48] +Arguments: hashpartitioning(i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, 5), ENSURE_REQUIREMENTS, [plan_id=5] (29) HashAggregate [codegen id : 17] -Input [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum#39, isEmpty#40] -Keys [7]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30] -Functions [1]: [sum(sumsales#36)] -Aggregate Attributes [1]: [sum(sumsales#36)#41] -Results [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, null AS s_store_id#42, sum(sumsales#36)#41 AS sumsales#43] +Input [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum#47, isEmpty#48] +Keys [7]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38] +Functions [1]: [sum(sumsales#44)] +Aggregate Attributes [1]: [sum(sumsales#44)#49] +Results [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, null AS s_store_id#50, sum(sumsales#44)#49 AS sumsales#51] (30) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, d_moy#50, s_store_id#51, sum#52, isEmpty#53] +Output [10]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, d_moy#58, s_store_id#59, sum#60, isEmpty#61] (31) HashAggregate [codegen id : 25] -Input [10]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, d_moy#50, s_store_id#51, sum#52, isEmpty#53] -Keys [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, d_moy#50, s_store_id#51] -Functions [1]: [sum(coalesce((ss_sales_price#54 * cast(ss_quantity#55 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#54 * cast(ss_quantity#55 as decimal(10,0))), 0.00))#22] -Results [7]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum(coalesce((ss_sales_price#54 * cast(ss_quantity#55 as decimal(10,0))), 0.00))#22 AS sumsales#56] +Input [10]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, d_moy#58, s_store_id#59, sum#60, isEmpty#61] +Keys [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, d_moy#58, s_store_id#59] +Functions [1]: [sum(coalesce((ss_sales_price#62 * cast(ss_quantity#63 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#62 * cast(ss_quantity#63 as decimal(10,0))), 0.00))#22] +Results [7]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, 
sum(coalesce((ss_sales_price#62 * cast(ss_quantity#63 as decimal(10,0))), 0.00))#22 AS sumsales#64] (32) HashAggregate [codegen id : 25] -Input [7]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sumsales#56] -Keys [6]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49] -Functions [1]: [partial_sum(sumsales#56)] -Aggregate Attributes [2]: [sum#57, isEmpty#58] -Results [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum#59, isEmpty#60] +Input [7]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sumsales#64] +Keys [6]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57] +Functions [1]: [partial_sum(sumsales#64)] +Aggregate Attributes [2]: [sum#65, isEmpty#66] +Results [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum#67, isEmpty#68] (33) Exchange -Input [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum#59, isEmpty#60] -Arguments: hashpartitioning(i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, 5), ENSURE_REQUIREMENTS, [plan_id=6] +Input [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum#67, isEmpty#68] +Arguments: hashpartitioning(i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, 5), ENSURE_REQUIREMENTS, [plan_id=6] (34) HashAggregate [codegen id : 26] -Input [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum#59, isEmpty#60] -Keys [6]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49] -Functions [1]: [sum(sumsales#56)] -Aggregate Attributes [1]: [sum(sumsales#56)#61] -Results [9]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, null AS d_moy#62, null AS s_store_id#63, sum(sumsales#56)#61 AS sumsales#64] +Input [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum#67, isEmpty#68] +Keys [6]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57] +Functions [1]: [sum(sumsales#64)] +Aggregate Attributes [1]: [sum(sumsales#64)#69] +Results [9]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, null AS d_moy#70, null AS s_store_id#71, sum(sumsales#64)#69 AS sumsales#72] (35) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, d_qoy#70, d_moy#71, s_store_id#72, sum#73, isEmpty#74] +Output [10]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, d_qoy#78, d_moy#79, s_store_id#80, sum#81, isEmpty#82] (36) HashAggregate [codegen id : 34] -Input [10]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, d_qoy#70, d_moy#71, s_store_id#72, sum#73, isEmpty#74] -Keys [8]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, d_qoy#70, d_moy#71, s_store_id#72] -Functions [1]: [sum(coalesce((ss_sales_price#75 * cast(ss_quantity#76 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#75 * cast(ss_quantity#76 as decimal(10,0))), 0.00))#22] -Results [6]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum(coalesce((ss_sales_price#75 * cast(ss_quantity#76 as decimal(10,0))), 0.00))#22 AS sumsales#77] +Input [10]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, d_qoy#78, d_moy#79, s_store_id#80, 
sum#81, isEmpty#82] +Keys [8]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, d_qoy#78, d_moy#79, s_store_id#80] +Functions [1]: [sum(coalesce((ss_sales_price#83 * cast(ss_quantity#84 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#83 * cast(ss_quantity#84 as decimal(10,0))), 0.00))#22] +Results [6]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum(coalesce((ss_sales_price#83 * cast(ss_quantity#84 as decimal(10,0))), 0.00))#22 AS sumsales#85] (37) HashAggregate [codegen id : 34] -Input [6]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sumsales#77] -Keys [5]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69] -Functions [1]: [partial_sum(sumsales#77)] -Aggregate Attributes [2]: [sum#78, isEmpty#79] -Results [7]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum#80, isEmpty#81] +Input [6]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sumsales#85] +Keys [5]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77] +Functions [1]: [partial_sum(sumsales#85)] +Aggregate Attributes [2]: [sum#86, isEmpty#87] +Results [7]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum#88, isEmpty#89] (38) Exchange -Input [7]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum#80, isEmpty#81] -Arguments: hashpartitioning(i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, 5), ENSURE_REQUIREMENTS, [plan_id=7] +Input [7]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum#88, isEmpty#89] +Arguments: hashpartitioning(i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, 5), ENSURE_REQUIREMENTS, [plan_id=7] (39) HashAggregate [codegen id : 35] -Input [7]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum#80, isEmpty#81] -Keys [5]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69] -Functions [1]: [sum(sumsales#77)] -Aggregate Attributes [1]: [sum(sumsales#77)#82] -Results [9]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, null AS d_qoy#83, null AS d_moy#84, null AS s_store_id#85, sum(sumsales#77)#82 AS sumsales#86] +Input [7]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum#88, isEmpty#89] +Keys [5]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77] +Functions [1]: [sum(sumsales#85)] +Aggregate Attributes [1]: [sum(sumsales#85)#90] +Results [9]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, null AS d_qoy#91, null AS d_moy#92, null AS s_store_id#93, sum(sumsales#85)#90 AS sumsales#94] (40) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, d_year#91, d_qoy#92, d_moy#93, s_store_id#94, sum#95, isEmpty#96] +Output [10]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, d_year#99, d_qoy#100, d_moy#101, s_store_id#102, sum#103, isEmpty#104] (41) HashAggregate [codegen id : 43] -Input [10]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, d_year#91, d_qoy#92, d_moy#93, s_store_id#94, sum#95, isEmpty#96] -Keys [8]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, d_year#91, d_qoy#92, d_moy#93, s_store_id#94] -Functions [1]: [sum(coalesce((ss_sales_price#97 * cast(ss_quantity#98 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#97 * cast(ss_quantity#98 as decimal(10,0))), 
0.00))#22] -Results [5]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum(coalesce((ss_sales_price#97 * cast(ss_quantity#98 as decimal(10,0))), 0.00))#22 AS sumsales#99] +Input [10]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, d_year#99, d_qoy#100, d_moy#101, s_store_id#102, sum#103, isEmpty#104] +Keys [8]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, d_year#99, d_qoy#100, d_moy#101, s_store_id#102] +Functions [1]: [sum(coalesce((ss_sales_price#105 * cast(ss_quantity#106 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#105 * cast(ss_quantity#106 as decimal(10,0))), 0.00))#22] +Results [5]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum(coalesce((ss_sales_price#105 * cast(ss_quantity#106 as decimal(10,0))), 0.00))#22 AS sumsales#107] (42) HashAggregate [codegen id : 43] -Input [5]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sumsales#99] -Keys [4]: [i_category#87, i_class#88, i_brand#89, i_product_name#90] -Functions [1]: [partial_sum(sumsales#99)] -Aggregate Attributes [2]: [sum#100, isEmpty#101] -Results [6]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum#102, isEmpty#103] +Input [5]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sumsales#107] +Keys [4]: [i_category#95, i_class#96, i_brand#97, i_product_name#98] +Functions [1]: [partial_sum(sumsales#107)] +Aggregate Attributes [2]: [sum#108, isEmpty#109] +Results [6]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum#110, isEmpty#111] (43) Exchange -Input [6]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum#102, isEmpty#103] -Arguments: hashpartitioning(i_category#87, i_class#88, i_brand#89, i_product_name#90, 5), ENSURE_REQUIREMENTS, [plan_id=8] +Input [6]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum#110, isEmpty#111] +Arguments: hashpartitioning(i_category#95, i_class#96, i_brand#97, i_product_name#98, 5), ENSURE_REQUIREMENTS, [plan_id=8] (44) HashAggregate [codegen id : 44] -Input [6]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum#102, isEmpty#103] -Keys [4]: [i_category#87, i_class#88, i_brand#89, i_product_name#90] -Functions [1]: [sum(sumsales#99)] -Aggregate Attributes [1]: [sum(sumsales#99)#104] -Results [9]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, null AS d_year#105, null AS d_qoy#106, null AS d_moy#107, null AS s_store_id#108, sum(sumsales#99)#104 AS sumsales#109] +Input [6]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum#110, isEmpty#111] +Keys [4]: [i_category#95, i_class#96, i_brand#97, i_product_name#98] +Functions [1]: [sum(sumsales#107)] +Aggregate Attributes [1]: [sum(sumsales#107)#112] +Results [9]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, null AS d_year#113, null AS d_qoy#114, null AS d_moy#115, null AS s_store_id#116, sum(sumsales#107)#112 AS sumsales#117] (45) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#110, i_class#111, i_brand#112, i_product_name#113, d_year#114, d_qoy#115, d_moy#116, s_store_id#117, sum#118, isEmpty#119] +Output [10]: [i_category#118, i_class#119, i_brand#120, i_product_name#121, d_year#122, d_qoy#123, d_moy#124, s_store_id#125, sum#126, isEmpty#127] (46) HashAggregate [codegen id : 52] -Input [10]: [i_category#110, i_class#111, i_brand#112, i_product_name#113, d_year#114, d_qoy#115, d_moy#116, s_store_id#117, sum#118, isEmpty#119] -Keys [8]: [i_category#110, i_class#111, i_brand#112, i_product_name#113, d_year#114, 
d_qoy#115, d_moy#116, s_store_id#117] -Functions [1]: [sum(coalesce((ss_sales_price#120 * cast(ss_quantity#121 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#120 * cast(ss_quantity#121 as decimal(10,0))), 0.00))#22] -Results [4]: [i_category#110, i_class#111, i_brand#112, sum(coalesce((ss_sales_price#120 * cast(ss_quantity#121 as decimal(10,0))), 0.00))#22 AS sumsales#122] +Input [10]: [i_category#118, i_class#119, i_brand#120, i_product_name#121, d_year#122, d_qoy#123, d_moy#124, s_store_id#125, sum#126, isEmpty#127] +Keys [8]: [i_category#118, i_class#119, i_brand#120, i_product_name#121, d_year#122, d_qoy#123, d_moy#124, s_store_id#125] +Functions [1]: [sum(coalesce((ss_sales_price#128 * cast(ss_quantity#129 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#128 * cast(ss_quantity#129 as decimal(10,0))), 0.00))#22] +Results [4]: [i_category#118, i_class#119, i_brand#120, sum(coalesce((ss_sales_price#128 * cast(ss_quantity#129 as decimal(10,0))), 0.00))#22 AS sumsales#130] (47) HashAggregate [codegen id : 52] -Input [4]: [i_category#110, i_class#111, i_brand#112, sumsales#122] -Keys [3]: [i_category#110, i_class#111, i_brand#112] -Functions [1]: [partial_sum(sumsales#122)] -Aggregate Attributes [2]: [sum#123, isEmpty#124] -Results [5]: [i_category#110, i_class#111, i_brand#112, sum#125, isEmpty#126] +Input [4]: [i_category#118, i_class#119, i_brand#120, sumsales#130] +Keys [3]: [i_category#118, i_class#119, i_brand#120] +Functions [1]: [partial_sum(sumsales#130)] +Aggregate Attributes [2]: [sum#131, isEmpty#132] +Results [5]: [i_category#118, i_class#119, i_brand#120, sum#133, isEmpty#134] (48) Exchange -Input [5]: [i_category#110, i_class#111, i_brand#112, sum#125, isEmpty#126] -Arguments: hashpartitioning(i_category#110, i_class#111, i_brand#112, 5), ENSURE_REQUIREMENTS, [plan_id=9] +Input [5]: [i_category#118, i_class#119, i_brand#120, sum#133, isEmpty#134] +Arguments: hashpartitioning(i_category#118, i_class#119, i_brand#120, 5), ENSURE_REQUIREMENTS, [plan_id=9] (49) HashAggregate [codegen id : 53] -Input [5]: [i_category#110, i_class#111, i_brand#112, sum#125, isEmpty#126] -Keys [3]: [i_category#110, i_class#111, i_brand#112] -Functions [1]: [sum(sumsales#122)] -Aggregate Attributes [1]: [sum(sumsales#122)#127] -Results [9]: [i_category#110, i_class#111, i_brand#112, null AS i_product_name#128, null AS d_year#129, null AS d_qoy#130, null AS d_moy#131, null AS s_store_id#132, sum(sumsales#122)#127 AS sumsales#133] +Input [5]: [i_category#118, i_class#119, i_brand#120, sum#133, isEmpty#134] +Keys [3]: [i_category#118, i_class#119, i_brand#120] +Functions [1]: [sum(sumsales#130)] +Aggregate Attributes [1]: [sum(sumsales#130)#135] +Results [9]: [i_category#118, i_class#119, i_brand#120, null AS i_product_name#136, null AS d_year#137, null AS d_qoy#138, null AS d_moy#139, null AS s_store_id#140, sum(sumsales#130)#135 AS sumsales#141] (50) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#134, i_class#135, i_brand#136, i_product_name#137, d_year#138, d_qoy#139, d_moy#140, s_store_id#141, sum#142, isEmpty#143] +Output [10]: [i_category#142, i_class#143, i_brand#144, i_product_name#145, d_year#146, d_qoy#147, d_moy#148, s_store_id#149, sum#150, isEmpty#151] (51) HashAggregate [codegen id : 61] -Input [10]: [i_category#134, i_class#135, i_brand#136, i_product_name#137, d_year#138, d_qoy#139, d_moy#140, s_store_id#141, sum#142, isEmpty#143] -Keys [8]: [i_category#134, i_class#135, i_brand#136, 
i_product_name#137, d_year#138, d_qoy#139, d_moy#140, s_store_id#141] -Functions [1]: [sum(coalesce((ss_sales_price#144 * cast(ss_quantity#145 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#144 * cast(ss_quantity#145 as decimal(10,0))), 0.00))#22] -Results [3]: [i_category#134, i_class#135, sum(coalesce((ss_sales_price#144 * cast(ss_quantity#145 as decimal(10,0))), 0.00))#22 AS sumsales#146] +Input [10]: [i_category#142, i_class#143, i_brand#144, i_product_name#145, d_year#146, d_qoy#147, d_moy#148, s_store_id#149, sum#150, isEmpty#151] +Keys [8]: [i_category#142, i_class#143, i_brand#144, i_product_name#145, d_year#146, d_qoy#147, d_moy#148, s_store_id#149] +Functions [1]: [sum(coalesce((ss_sales_price#152 * cast(ss_quantity#153 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#152 * cast(ss_quantity#153 as decimal(10,0))), 0.00))#22] +Results [3]: [i_category#142, i_class#143, sum(coalesce((ss_sales_price#152 * cast(ss_quantity#153 as decimal(10,0))), 0.00))#22 AS sumsales#154] (52) HashAggregate [codegen id : 61] -Input [3]: [i_category#134, i_class#135, sumsales#146] -Keys [2]: [i_category#134, i_class#135] -Functions [1]: [partial_sum(sumsales#146)] -Aggregate Attributes [2]: [sum#147, isEmpty#148] -Results [4]: [i_category#134, i_class#135, sum#149, isEmpty#150] +Input [3]: [i_category#142, i_class#143, sumsales#154] +Keys [2]: [i_category#142, i_class#143] +Functions [1]: [partial_sum(sumsales#154)] +Aggregate Attributes [2]: [sum#155, isEmpty#156] +Results [4]: [i_category#142, i_class#143, sum#157, isEmpty#158] (53) Exchange -Input [4]: [i_category#134, i_class#135, sum#149, isEmpty#150] -Arguments: hashpartitioning(i_category#134, i_class#135, 5), ENSURE_REQUIREMENTS, [plan_id=10] +Input [4]: [i_category#142, i_class#143, sum#157, isEmpty#158] +Arguments: hashpartitioning(i_category#142, i_class#143, 5), ENSURE_REQUIREMENTS, [plan_id=10] (54) HashAggregate [codegen id : 62] -Input [4]: [i_category#134, i_class#135, sum#149, isEmpty#150] -Keys [2]: [i_category#134, i_class#135] -Functions [1]: [sum(sumsales#146)] -Aggregate Attributes [1]: [sum(sumsales#146)#151] -Results [9]: [i_category#134, i_class#135, null AS i_brand#152, null AS i_product_name#153, null AS d_year#154, null AS d_qoy#155, null AS d_moy#156, null AS s_store_id#157, sum(sumsales#146)#151 AS sumsales#158] +Input [4]: [i_category#142, i_class#143, sum#157, isEmpty#158] +Keys [2]: [i_category#142, i_class#143] +Functions [1]: [sum(sumsales#154)] +Aggregate Attributes [1]: [sum(sumsales#154)#159] +Results [9]: [i_category#142, i_class#143, null AS i_brand#160, null AS i_product_name#161, null AS d_year#162, null AS d_qoy#163, null AS d_moy#164, null AS s_store_id#165, sum(sumsales#154)#159 AS sumsales#166] (55) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#159, i_class#160, i_brand#161, i_product_name#162, d_year#163, d_qoy#164, d_moy#165, s_store_id#166, sum#167, isEmpty#168] +Output [10]: [i_category#167, i_class#168, i_brand#169, i_product_name#170, d_year#171, d_qoy#172, d_moy#173, s_store_id#174, sum#175, isEmpty#176] (56) HashAggregate [codegen id : 70] -Input [10]: [i_category#159, i_class#160, i_brand#161, i_product_name#162, d_year#163, d_qoy#164, d_moy#165, s_store_id#166, sum#167, isEmpty#168] -Keys [8]: [i_category#159, i_class#160, i_brand#161, i_product_name#162, d_year#163, d_qoy#164, d_moy#165, s_store_id#166] -Functions [1]: [sum(coalesce((ss_sales_price#169 * cast(ss_quantity#170 as decimal(10,0))), 0.00))] 
-Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#169 * cast(ss_quantity#170 as decimal(10,0))), 0.00))#22] -Results [2]: [i_category#159, sum(coalesce((ss_sales_price#169 * cast(ss_quantity#170 as decimal(10,0))), 0.00))#22 AS sumsales#171] +Input [10]: [i_category#167, i_class#168, i_brand#169, i_product_name#170, d_year#171, d_qoy#172, d_moy#173, s_store_id#174, sum#175, isEmpty#176] +Keys [8]: [i_category#167, i_class#168, i_brand#169, i_product_name#170, d_year#171, d_qoy#172, d_moy#173, s_store_id#174] +Functions [1]: [sum(coalesce((ss_sales_price#177 * cast(ss_quantity#178 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#177 * cast(ss_quantity#178 as decimal(10,0))), 0.00))#22] +Results [2]: [i_category#167, sum(coalesce((ss_sales_price#177 * cast(ss_quantity#178 as decimal(10,0))), 0.00))#22 AS sumsales#179] (57) HashAggregate [codegen id : 70] -Input [2]: [i_category#159, sumsales#171] -Keys [1]: [i_category#159] -Functions [1]: [partial_sum(sumsales#171)] -Aggregate Attributes [2]: [sum#172, isEmpty#173] -Results [3]: [i_category#159, sum#174, isEmpty#175] +Input [2]: [i_category#167, sumsales#179] +Keys [1]: [i_category#167] +Functions [1]: [partial_sum(sumsales#179)] +Aggregate Attributes [2]: [sum#180, isEmpty#181] +Results [3]: [i_category#167, sum#182, isEmpty#183] (58) Exchange -Input [3]: [i_category#159, sum#174, isEmpty#175] -Arguments: hashpartitioning(i_category#159, 5), ENSURE_REQUIREMENTS, [plan_id=11] +Input [3]: [i_category#167, sum#182, isEmpty#183] +Arguments: hashpartitioning(i_category#167, 5), ENSURE_REQUIREMENTS, [plan_id=11] (59) HashAggregate [codegen id : 71] -Input [3]: [i_category#159, sum#174, isEmpty#175] -Keys [1]: [i_category#159] -Functions [1]: [sum(sumsales#171)] -Aggregate Attributes [1]: [sum(sumsales#171)#176] -Results [9]: [i_category#159, null AS i_class#177, null AS i_brand#178, null AS i_product_name#179, null AS d_year#180, null AS d_qoy#181, null AS d_moy#182, null AS s_store_id#183, sum(sumsales#171)#176 AS sumsales#184] +Input [3]: [i_category#167, sum#182, isEmpty#183] +Keys [1]: [i_category#167] +Functions [1]: [sum(sumsales#179)] +Aggregate Attributes [1]: [sum(sumsales#179)#184] +Results [9]: [i_category#167, null AS i_class#185, null AS i_brand#186, null AS i_product_name#187, null AS d_year#188, null AS d_qoy#189, null AS d_moy#190, null AS s_store_id#191, sum(sumsales#179)#184 AS sumsales#192] (60) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#185, i_class#186, i_brand#187, i_product_name#188, d_year#189, d_qoy#190, d_moy#191, s_store_id#192, sum#193, isEmpty#194] +Output [10]: [i_category#193, i_class#194, i_brand#195, i_product_name#196, d_year#197, d_qoy#198, d_moy#199, s_store_id#200, sum#201, isEmpty#202] (61) HashAggregate [codegen id : 79] -Input [10]: [i_category#185, i_class#186, i_brand#187, i_product_name#188, d_year#189, d_qoy#190, d_moy#191, s_store_id#192, sum#193, isEmpty#194] -Keys [8]: [i_category#185, i_class#186, i_brand#187, i_product_name#188, d_year#189, d_qoy#190, d_moy#191, s_store_id#192] -Functions [1]: [sum(coalesce((ss_sales_price#195 * cast(ss_quantity#196 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#195 * cast(ss_quantity#196 as decimal(10,0))), 0.00))#22] -Results [1]: [sum(coalesce((ss_sales_price#195 * cast(ss_quantity#196 as decimal(10,0))), 0.00))#22 AS sumsales#197] +Input [10]: [i_category#193, i_class#194, i_brand#195, i_product_name#196, d_year#197, d_qoy#198, d_moy#199, s_store_id#200, 
sum#201, isEmpty#202] +Keys [8]: [i_category#193, i_class#194, i_brand#195, i_product_name#196, d_year#197, d_qoy#198, d_moy#199, s_store_id#200] +Functions [1]: [sum(coalesce((ss_sales_price#203 * cast(ss_quantity#204 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#203 * cast(ss_quantity#204 as decimal(10,0))), 0.00))#22] +Results [1]: [sum(coalesce((ss_sales_price#203 * cast(ss_quantity#204 as decimal(10,0))), 0.00))#22 AS sumsales#205] (62) HashAggregate [codegen id : 79] -Input [1]: [sumsales#197] +Input [1]: [sumsales#205] Keys: [] -Functions [1]: [partial_sum(sumsales#197)] -Aggregate Attributes [2]: [sum#198, isEmpty#199] -Results [2]: [sum#200, isEmpty#201] +Functions [1]: [partial_sum(sumsales#205)] +Aggregate Attributes [2]: [sum#206, isEmpty#207] +Results [2]: [sum#208, isEmpty#209] (63) Exchange -Input [2]: [sum#200, isEmpty#201] +Input [2]: [sum#208, isEmpty#209] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [plan_id=12] (64) HashAggregate [codegen id : 80] -Input [2]: [sum#200, isEmpty#201] +Input [2]: [sum#208, isEmpty#209] Keys: [] -Functions [1]: [sum(sumsales#197)] -Aggregate Attributes [1]: [sum(sumsales#197)#202] -Results [9]: [null AS i_category#203, null AS i_class#204, null AS i_brand#205, null AS i_product_name#206, null AS d_year#207, null AS d_qoy#208, null AS d_moy#209, null AS s_store_id#210, sum(sumsales#197)#202 AS sumsales#211] +Functions [1]: [sum(sumsales#205)] +Aggregate Attributes [1]: [sum(sumsales#205)#210] +Results [9]: [null AS i_category#211, null AS i_class#212, null AS i_brand#213, null AS i_product_name#214, null AS d_year#215, null AS d_qoy#216, null AS d_moy#217, null AS s_store_id#218, sum(sumsales#205)#210 AS sumsales#219] (65) Union (66) Sort [codegen id : 81] -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [i_category#16 ASC NULLS FIRST, sumsales#23 DESC NULLS LAST], false, 0 +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [i_category#23 ASC NULLS FIRST, sumsales#31 DESC NULLS LAST], false, 0 (67) WindowGroupLimit -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [i_category#16], [sumsales#23 DESC NULLS LAST], rank(sumsales#23), 100, Partial +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [i_category#23], [sumsales#31 DESC NULLS LAST], rank(sumsales#31), 100, Partial (68) Exchange -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: hashpartitioning(i_category#16, 5), ENSURE_REQUIREMENTS, [plan_id=13] +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: hashpartitioning(i_category#23, 5), ENSURE_REQUIREMENTS, [plan_id=13] (69) Sort [codegen id : 82] -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [i_category#16 ASC NULLS FIRST, sumsales#23 DESC NULLS LAST], false, 0 +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [i_category#23 ASC NULLS FIRST, sumsales#31 DESC NULLS LAST], false, 0 
(70) WindowGroupLimit -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [i_category#16], [sumsales#23 DESC NULLS LAST], rank(sumsales#23), 100, Final +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [i_category#23], [sumsales#31 DESC NULLS LAST], rank(sumsales#31), 100, Final (71) Window -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [rank(sumsales#23) windowspecdefinition(i_category#16, sumsales#23 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#212], [i_category#16], [sumsales#23 DESC NULLS LAST] +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [rank(sumsales#31) windowspecdefinition(i_category#23, sumsales#31 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#220], [i_category#23], [sumsales#31 DESC NULLS LAST] (72) Filter [codegen id : 83] -Input [10]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23, rk#212] -Condition : (rk#212 <= 100) +Input [10]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31, rk#220] +Condition : (rk#220 <= 100) (73) TakeOrderedAndProject -Input [10]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23, rk#212] -Arguments: 100, [i_category#16 ASC NULLS FIRST, i_class#15 ASC NULLS FIRST, i_brand#14 ASC NULLS FIRST, i_product_name#17 ASC NULLS FIRST, d_year#8 ASC NULLS FIRST, d_qoy#10 ASC NULLS FIRST, d_moy#9 ASC NULLS FIRST, s_store_id#12 ASC NULLS FIRST, sumsales#23 ASC NULLS FIRST, rk#212 ASC NULLS FIRST], [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23, rk#212] +Input [10]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31, rk#220] +Arguments: 100, [i_category#23 ASC NULLS FIRST, i_class#24 ASC NULLS FIRST, i_brand#25 ASC NULLS FIRST, i_product_name#26 ASC NULLS FIRST, d_year#27 ASC NULLS FIRST, d_qoy#28 ASC NULLS FIRST, d_moy#29 ASC NULLS FIRST, s_store_id#30 ASC NULLS FIRST, sumsales#31 ASC NULLS FIRST, rk#220 ASC NULLS FIRST], [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31, rk#220] ===== Subqueries ===== @@ -457,22 +457,22 @@ BroadcastExchange (78) (74) Scan parquet spark_catalog.default.date_dim -Output [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] +Output [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,1212), LessThanOrEqual(d_month_seq,1223), IsNotNull(d_date_sk)] ReadSchema: struct (75) ColumnarToRow [codegen id : 1] -Input [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] +Input [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] (76) Filter [codegen id : 1] -Input [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] -Condition : (((isnotnull(d_month_seq#213) AND (d_month_seq#213 >= 1212)) 
AND (d_month_seq#213 <= 1223)) AND isnotnull(d_date_sk#7)) +Input [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] +Condition : (((isnotnull(d_month_seq#221) AND (d_month_seq#221 >= 1212)) AND (d_month_seq#221 <= 1223)) AND isnotnull(d_date_sk#7)) (77) Project [codegen id : 1] Output [4]: [d_date_sk#7, d_year#8, d_moy#9, d_qoy#10] -Input [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] +Input [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] (78) BroadcastExchange Input [4]: [d_date_sk#7, d_year#8, d_moy#9, d_qoy#10] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/simplified.txt index 795fa297b9bad..b6a4358c4d43b 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/simplified.txt @@ -14,7 +14,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ InputAdapter Union WholeStageCodegen (8) - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce((ss_sales_price * cast(ss_quantity as decimal(10,0))), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce((ss_sales_price * cast(ss_quantity as decimal(10,0))), 0.00)),i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sumsales,sum,isEmpty] InputAdapter Exchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id] #2 WholeStageCodegen (7) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/explain.txt index 75d526da4ba71..417af4fe924ee 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/explain.txt @@ -171,265 +171,265 @@ Input [10]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, Keys [8]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] Functions [1]: [sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))] Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))#22] -Results [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, cast(sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))#22 as decimal(38,2)) AS sumsales#23] +Results [9]: [i_category#16 AS i_category#23, i_class#15 AS i_class#24, i_brand#14 AS i_brand#25, i_product_name#17 AS i_product_name#26, d_year#8 AS d_year#27, d_qoy#10 AS d_qoy#28, d_moy#9 AS d_moy#29, s_store_id#12 AS s_store_id#30, cast(sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))#22 as decimal(38,2)) AS sumsales#31] (22) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, s_store_id#31, sum#32, isEmpty#33] +Output [10]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, s_store_id#39, sum#40, isEmpty#41] (23) HashAggregate [codegen id : 10] -Input 
[10]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, s_store_id#31, sum#32, isEmpty#33] -Keys [8]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, s_store_id#31] -Functions [1]: [sum(coalesce((ss_sales_price#34 * cast(ss_quantity#35 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#34 * cast(ss_quantity#35 as decimal(10,0))), 0.00))#22] -Results [8]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum(coalesce((ss_sales_price#34 * cast(ss_quantity#35 as decimal(10,0))), 0.00))#22 AS sumsales#36] +Input [10]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, s_store_id#39, sum#40, isEmpty#41] +Keys [8]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, s_store_id#39] +Functions [1]: [sum(coalesce((ss_sales_price#42 * cast(ss_quantity#43 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#42 * cast(ss_quantity#43 as decimal(10,0))), 0.00))#22] +Results [8]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum(coalesce((ss_sales_price#42 * cast(ss_quantity#43 as decimal(10,0))), 0.00))#22 AS sumsales#44] (24) HashAggregate [codegen id : 10] -Input [8]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sumsales#36] -Keys [7]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30] -Functions [1]: [partial_sum(sumsales#36)] -Aggregate Attributes [2]: [sum#37, isEmpty#38] -Results [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum#39, isEmpty#40] +Input [8]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sumsales#44] +Keys [7]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38] +Functions [1]: [partial_sum(sumsales#44)] +Aggregate Attributes [2]: [sum#45, isEmpty#46] +Results [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum#47, isEmpty#48] (25) Exchange -Input [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum#39, isEmpty#40] -Arguments: hashpartitioning(i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, 5), ENSURE_REQUIREMENTS, [plan_id=4] +Input [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum#47, isEmpty#48] +Arguments: hashpartitioning(i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, 5), ENSURE_REQUIREMENTS, [plan_id=4] (26) HashAggregate [codegen id : 11] -Input [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum#39, isEmpty#40] -Keys [7]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30] -Functions [1]: [sum(sumsales#36)] -Aggregate Attributes [1]: [sum(sumsales#36)#41] -Results [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, null AS s_store_id#42, sum(sumsales#36)#41 AS sumsales#43] +Input [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum#47, isEmpty#48] +Keys [7]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38] +Functions [1]: 
[sum(sumsales#44)] +Aggregate Attributes [1]: [sum(sumsales#44)#49] +Results [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, null AS s_store_id#50, sum(sumsales#44)#49 AS sumsales#51] (27) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, d_moy#50, s_store_id#51, sum#52, isEmpty#53] +Output [10]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, d_moy#58, s_store_id#59, sum#60, isEmpty#61] (28) HashAggregate [codegen id : 16] -Input [10]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, d_moy#50, s_store_id#51, sum#52, isEmpty#53] -Keys [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, d_moy#50, s_store_id#51] -Functions [1]: [sum(coalesce((ss_sales_price#54 * cast(ss_quantity#55 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#54 * cast(ss_quantity#55 as decimal(10,0))), 0.00))#22] -Results [7]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum(coalesce((ss_sales_price#54 * cast(ss_quantity#55 as decimal(10,0))), 0.00))#22 AS sumsales#56] +Input [10]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, d_moy#58, s_store_id#59, sum#60, isEmpty#61] +Keys [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, d_moy#58, s_store_id#59] +Functions [1]: [sum(coalesce((ss_sales_price#62 * cast(ss_quantity#63 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#62 * cast(ss_quantity#63 as decimal(10,0))), 0.00))#22] +Results [7]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum(coalesce((ss_sales_price#62 * cast(ss_quantity#63 as decimal(10,0))), 0.00))#22 AS sumsales#64] (29) HashAggregate [codegen id : 16] -Input [7]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sumsales#56] -Keys [6]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49] -Functions [1]: [partial_sum(sumsales#56)] -Aggregate Attributes [2]: [sum#57, isEmpty#58] -Results [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum#59, isEmpty#60] +Input [7]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sumsales#64] +Keys [6]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57] +Functions [1]: [partial_sum(sumsales#64)] +Aggregate Attributes [2]: [sum#65, isEmpty#66] +Results [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum#67, isEmpty#68] (30) Exchange -Input [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum#59, isEmpty#60] -Arguments: hashpartitioning(i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, 5), ENSURE_REQUIREMENTS, [plan_id=5] +Input [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum#67, isEmpty#68] +Arguments: hashpartitioning(i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, 5), ENSURE_REQUIREMENTS, [plan_id=5] (31) HashAggregate [codegen id : 17] -Input [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum#59, isEmpty#60] -Keys [6]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49] -Functions [1]: [sum(sumsales#56)] 
-Aggregate Attributes [1]: [sum(sumsales#56)#61] -Results [9]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, null AS d_moy#62, null AS s_store_id#63, sum(sumsales#56)#61 AS sumsales#64] +Input [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum#67, isEmpty#68] +Keys [6]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57] +Functions [1]: [sum(sumsales#64)] +Aggregate Attributes [1]: [sum(sumsales#64)#69] +Results [9]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, null AS d_moy#70, null AS s_store_id#71, sum(sumsales#64)#69 AS sumsales#72] (32) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, d_qoy#70, d_moy#71, s_store_id#72, sum#73, isEmpty#74] +Output [10]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, d_qoy#78, d_moy#79, s_store_id#80, sum#81, isEmpty#82] (33) HashAggregate [codegen id : 22] -Input [10]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, d_qoy#70, d_moy#71, s_store_id#72, sum#73, isEmpty#74] -Keys [8]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, d_qoy#70, d_moy#71, s_store_id#72] -Functions [1]: [sum(coalesce((ss_sales_price#75 * cast(ss_quantity#76 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#75 * cast(ss_quantity#76 as decimal(10,0))), 0.00))#22] -Results [6]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum(coalesce((ss_sales_price#75 * cast(ss_quantity#76 as decimal(10,0))), 0.00))#22 AS sumsales#77] +Input [10]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, d_qoy#78, d_moy#79, s_store_id#80, sum#81, isEmpty#82] +Keys [8]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, d_qoy#78, d_moy#79, s_store_id#80] +Functions [1]: [sum(coalesce((ss_sales_price#83 * cast(ss_quantity#84 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#83 * cast(ss_quantity#84 as decimal(10,0))), 0.00))#22] +Results [6]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum(coalesce((ss_sales_price#83 * cast(ss_quantity#84 as decimal(10,0))), 0.00))#22 AS sumsales#85] (34) HashAggregate [codegen id : 22] -Input [6]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sumsales#77] -Keys [5]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69] -Functions [1]: [partial_sum(sumsales#77)] -Aggregate Attributes [2]: [sum#78, isEmpty#79] -Results [7]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum#80, isEmpty#81] +Input [6]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sumsales#85] +Keys [5]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77] +Functions [1]: [partial_sum(sumsales#85)] +Aggregate Attributes [2]: [sum#86, isEmpty#87] +Results [7]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum#88, isEmpty#89] (35) Exchange -Input [7]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum#80, isEmpty#81] -Arguments: hashpartitioning(i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, 5), ENSURE_REQUIREMENTS, [plan_id=6] +Input [7]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum#88, isEmpty#89] +Arguments: hashpartitioning(i_category#73, i_class#74, i_brand#75, 
i_product_name#76, d_year#77, 5), ENSURE_REQUIREMENTS, [plan_id=6] (36) HashAggregate [codegen id : 23] -Input [7]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum#80, isEmpty#81] -Keys [5]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69] -Functions [1]: [sum(sumsales#77)] -Aggregate Attributes [1]: [sum(sumsales#77)#82] -Results [9]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, null AS d_qoy#83, null AS d_moy#84, null AS s_store_id#85, sum(sumsales#77)#82 AS sumsales#86] +Input [7]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum#88, isEmpty#89] +Keys [5]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77] +Functions [1]: [sum(sumsales#85)] +Aggregate Attributes [1]: [sum(sumsales#85)#90] +Results [9]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, null AS d_qoy#91, null AS d_moy#92, null AS s_store_id#93, sum(sumsales#85)#90 AS sumsales#94] (37) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, d_year#91, d_qoy#92, d_moy#93, s_store_id#94, sum#95, isEmpty#96] +Output [10]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, d_year#99, d_qoy#100, d_moy#101, s_store_id#102, sum#103, isEmpty#104] (38) HashAggregate [codegen id : 28] -Input [10]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, d_year#91, d_qoy#92, d_moy#93, s_store_id#94, sum#95, isEmpty#96] -Keys [8]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, d_year#91, d_qoy#92, d_moy#93, s_store_id#94] -Functions [1]: [sum(coalesce((ss_sales_price#97 * cast(ss_quantity#98 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#97 * cast(ss_quantity#98 as decimal(10,0))), 0.00))#22] -Results [5]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum(coalesce((ss_sales_price#97 * cast(ss_quantity#98 as decimal(10,0))), 0.00))#22 AS sumsales#99] +Input [10]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, d_year#99, d_qoy#100, d_moy#101, s_store_id#102, sum#103, isEmpty#104] +Keys [8]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, d_year#99, d_qoy#100, d_moy#101, s_store_id#102] +Functions [1]: [sum(coalesce((ss_sales_price#105 * cast(ss_quantity#106 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#105 * cast(ss_quantity#106 as decimal(10,0))), 0.00))#22] +Results [5]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum(coalesce((ss_sales_price#105 * cast(ss_quantity#106 as decimal(10,0))), 0.00))#22 AS sumsales#107] (39) HashAggregate [codegen id : 28] -Input [5]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sumsales#99] -Keys [4]: [i_category#87, i_class#88, i_brand#89, i_product_name#90] -Functions [1]: [partial_sum(sumsales#99)] -Aggregate Attributes [2]: [sum#100, isEmpty#101] -Results [6]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum#102, isEmpty#103] +Input [5]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sumsales#107] +Keys [4]: [i_category#95, i_class#96, i_brand#97, i_product_name#98] +Functions [1]: [partial_sum(sumsales#107)] +Aggregate Attributes [2]: [sum#108, isEmpty#109] +Results [6]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum#110, isEmpty#111] (40) Exchange -Input [6]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum#102, isEmpty#103] -Arguments: hashpartitioning(i_category#87, i_class#88, i_brand#89, 
i_product_name#90, 5), ENSURE_REQUIREMENTS, [plan_id=7] +Input [6]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum#110, isEmpty#111] +Arguments: hashpartitioning(i_category#95, i_class#96, i_brand#97, i_product_name#98, 5), ENSURE_REQUIREMENTS, [plan_id=7] (41) HashAggregate [codegen id : 29] -Input [6]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum#102, isEmpty#103] -Keys [4]: [i_category#87, i_class#88, i_brand#89, i_product_name#90] -Functions [1]: [sum(sumsales#99)] -Aggregate Attributes [1]: [sum(sumsales#99)#104] -Results [9]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, null AS d_year#105, null AS d_qoy#106, null AS d_moy#107, null AS s_store_id#108, sum(sumsales#99)#104 AS sumsales#109] +Input [6]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum#110, isEmpty#111] +Keys [4]: [i_category#95, i_class#96, i_brand#97, i_product_name#98] +Functions [1]: [sum(sumsales#107)] +Aggregate Attributes [1]: [sum(sumsales#107)#112] +Results [9]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, null AS d_year#113, null AS d_qoy#114, null AS d_moy#115, null AS s_store_id#116, sum(sumsales#107)#112 AS sumsales#117] (42) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#110, i_class#111, i_brand#112, i_product_name#113, d_year#114, d_qoy#115, d_moy#116, s_store_id#117, sum#118, isEmpty#119] +Output [10]: [i_category#118, i_class#119, i_brand#120, i_product_name#121, d_year#122, d_qoy#123, d_moy#124, s_store_id#125, sum#126, isEmpty#127] (43) HashAggregate [codegen id : 34] -Input [10]: [i_category#110, i_class#111, i_brand#112, i_product_name#113, d_year#114, d_qoy#115, d_moy#116, s_store_id#117, sum#118, isEmpty#119] -Keys [8]: [i_category#110, i_class#111, i_brand#112, i_product_name#113, d_year#114, d_qoy#115, d_moy#116, s_store_id#117] -Functions [1]: [sum(coalesce((ss_sales_price#120 * cast(ss_quantity#121 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#120 * cast(ss_quantity#121 as decimal(10,0))), 0.00))#22] -Results [4]: [i_category#110, i_class#111, i_brand#112, sum(coalesce((ss_sales_price#120 * cast(ss_quantity#121 as decimal(10,0))), 0.00))#22 AS sumsales#122] +Input [10]: [i_category#118, i_class#119, i_brand#120, i_product_name#121, d_year#122, d_qoy#123, d_moy#124, s_store_id#125, sum#126, isEmpty#127] +Keys [8]: [i_category#118, i_class#119, i_brand#120, i_product_name#121, d_year#122, d_qoy#123, d_moy#124, s_store_id#125] +Functions [1]: [sum(coalesce((ss_sales_price#128 * cast(ss_quantity#129 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#128 * cast(ss_quantity#129 as decimal(10,0))), 0.00))#22] +Results [4]: [i_category#118, i_class#119, i_brand#120, sum(coalesce((ss_sales_price#128 * cast(ss_quantity#129 as decimal(10,0))), 0.00))#22 AS sumsales#130] (44) HashAggregate [codegen id : 34] -Input [4]: [i_category#110, i_class#111, i_brand#112, sumsales#122] -Keys [3]: [i_category#110, i_class#111, i_brand#112] -Functions [1]: [partial_sum(sumsales#122)] -Aggregate Attributes [2]: [sum#123, isEmpty#124] -Results [5]: [i_category#110, i_class#111, i_brand#112, sum#125, isEmpty#126] +Input [4]: [i_category#118, i_class#119, i_brand#120, sumsales#130] +Keys [3]: [i_category#118, i_class#119, i_brand#120] +Functions [1]: [partial_sum(sumsales#130)] +Aggregate Attributes [2]: [sum#131, isEmpty#132] +Results [5]: [i_category#118, i_class#119, i_brand#120, sum#133, isEmpty#134] (45) Exchange -Input [5]: [i_category#110, 
i_class#111, i_brand#112, sum#125, isEmpty#126] -Arguments: hashpartitioning(i_category#110, i_class#111, i_brand#112, 5), ENSURE_REQUIREMENTS, [plan_id=8] +Input [5]: [i_category#118, i_class#119, i_brand#120, sum#133, isEmpty#134] +Arguments: hashpartitioning(i_category#118, i_class#119, i_brand#120, 5), ENSURE_REQUIREMENTS, [plan_id=8] (46) HashAggregate [codegen id : 35] -Input [5]: [i_category#110, i_class#111, i_brand#112, sum#125, isEmpty#126] -Keys [3]: [i_category#110, i_class#111, i_brand#112] -Functions [1]: [sum(sumsales#122)] -Aggregate Attributes [1]: [sum(sumsales#122)#127] -Results [9]: [i_category#110, i_class#111, i_brand#112, null AS i_product_name#128, null AS d_year#129, null AS d_qoy#130, null AS d_moy#131, null AS s_store_id#132, sum(sumsales#122)#127 AS sumsales#133] +Input [5]: [i_category#118, i_class#119, i_brand#120, sum#133, isEmpty#134] +Keys [3]: [i_category#118, i_class#119, i_brand#120] +Functions [1]: [sum(sumsales#130)] +Aggregate Attributes [1]: [sum(sumsales#130)#135] +Results [9]: [i_category#118, i_class#119, i_brand#120, null AS i_product_name#136, null AS d_year#137, null AS d_qoy#138, null AS d_moy#139, null AS s_store_id#140, sum(sumsales#130)#135 AS sumsales#141] (47) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#134, i_class#135, i_brand#136, i_product_name#137, d_year#138, d_qoy#139, d_moy#140, s_store_id#141, sum#142, isEmpty#143] +Output [10]: [i_category#142, i_class#143, i_brand#144, i_product_name#145, d_year#146, d_qoy#147, d_moy#148, s_store_id#149, sum#150, isEmpty#151] (48) HashAggregate [codegen id : 40] -Input [10]: [i_category#134, i_class#135, i_brand#136, i_product_name#137, d_year#138, d_qoy#139, d_moy#140, s_store_id#141, sum#142, isEmpty#143] -Keys [8]: [i_category#134, i_class#135, i_brand#136, i_product_name#137, d_year#138, d_qoy#139, d_moy#140, s_store_id#141] -Functions [1]: [sum(coalesce((ss_sales_price#144 * cast(ss_quantity#145 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#144 * cast(ss_quantity#145 as decimal(10,0))), 0.00))#22] -Results [3]: [i_category#134, i_class#135, sum(coalesce((ss_sales_price#144 * cast(ss_quantity#145 as decimal(10,0))), 0.00))#22 AS sumsales#146] +Input [10]: [i_category#142, i_class#143, i_brand#144, i_product_name#145, d_year#146, d_qoy#147, d_moy#148, s_store_id#149, sum#150, isEmpty#151] +Keys [8]: [i_category#142, i_class#143, i_brand#144, i_product_name#145, d_year#146, d_qoy#147, d_moy#148, s_store_id#149] +Functions [1]: [sum(coalesce((ss_sales_price#152 * cast(ss_quantity#153 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#152 * cast(ss_quantity#153 as decimal(10,0))), 0.00))#22] +Results [3]: [i_category#142, i_class#143, sum(coalesce((ss_sales_price#152 * cast(ss_quantity#153 as decimal(10,0))), 0.00))#22 AS sumsales#154] (49) HashAggregate [codegen id : 40] -Input [3]: [i_category#134, i_class#135, sumsales#146] -Keys [2]: [i_category#134, i_class#135] -Functions [1]: [partial_sum(sumsales#146)] -Aggregate Attributes [2]: [sum#147, isEmpty#148] -Results [4]: [i_category#134, i_class#135, sum#149, isEmpty#150] +Input [3]: [i_category#142, i_class#143, sumsales#154] +Keys [2]: [i_category#142, i_class#143] +Functions [1]: [partial_sum(sumsales#154)] +Aggregate Attributes [2]: [sum#155, isEmpty#156] +Results [4]: [i_category#142, i_class#143, sum#157, isEmpty#158] (50) Exchange -Input [4]: [i_category#134, i_class#135, sum#149, isEmpty#150] -Arguments: hashpartitioning(i_category#134, 
i_class#135, 5), ENSURE_REQUIREMENTS, [plan_id=9] +Input [4]: [i_category#142, i_class#143, sum#157, isEmpty#158] +Arguments: hashpartitioning(i_category#142, i_class#143, 5), ENSURE_REQUIREMENTS, [plan_id=9] (51) HashAggregate [codegen id : 41] -Input [4]: [i_category#134, i_class#135, sum#149, isEmpty#150] -Keys [2]: [i_category#134, i_class#135] -Functions [1]: [sum(sumsales#146)] -Aggregate Attributes [1]: [sum(sumsales#146)#151] -Results [9]: [i_category#134, i_class#135, null AS i_brand#152, null AS i_product_name#153, null AS d_year#154, null AS d_qoy#155, null AS d_moy#156, null AS s_store_id#157, sum(sumsales#146)#151 AS sumsales#158] +Input [4]: [i_category#142, i_class#143, sum#157, isEmpty#158] +Keys [2]: [i_category#142, i_class#143] +Functions [1]: [sum(sumsales#154)] +Aggregate Attributes [1]: [sum(sumsales#154)#159] +Results [9]: [i_category#142, i_class#143, null AS i_brand#160, null AS i_product_name#161, null AS d_year#162, null AS d_qoy#163, null AS d_moy#164, null AS s_store_id#165, sum(sumsales#154)#159 AS sumsales#166] (52) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#159, i_class#160, i_brand#161, i_product_name#162, d_year#163, d_qoy#164, d_moy#165, s_store_id#166, sum#167, isEmpty#168] +Output [10]: [i_category#167, i_class#168, i_brand#169, i_product_name#170, d_year#171, d_qoy#172, d_moy#173, s_store_id#174, sum#175, isEmpty#176] (53) HashAggregate [codegen id : 46] -Input [10]: [i_category#159, i_class#160, i_brand#161, i_product_name#162, d_year#163, d_qoy#164, d_moy#165, s_store_id#166, sum#167, isEmpty#168] -Keys [8]: [i_category#159, i_class#160, i_brand#161, i_product_name#162, d_year#163, d_qoy#164, d_moy#165, s_store_id#166] -Functions [1]: [sum(coalesce((ss_sales_price#169 * cast(ss_quantity#170 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#169 * cast(ss_quantity#170 as decimal(10,0))), 0.00))#22] -Results [2]: [i_category#159, sum(coalesce((ss_sales_price#169 * cast(ss_quantity#170 as decimal(10,0))), 0.00))#22 AS sumsales#171] +Input [10]: [i_category#167, i_class#168, i_brand#169, i_product_name#170, d_year#171, d_qoy#172, d_moy#173, s_store_id#174, sum#175, isEmpty#176] +Keys [8]: [i_category#167, i_class#168, i_brand#169, i_product_name#170, d_year#171, d_qoy#172, d_moy#173, s_store_id#174] +Functions [1]: [sum(coalesce((ss_sales_price#177 * cast(ss_quantity#178 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#177 * cast(ss_quantity#178 as decimal(10,0))), 0.00))#22] +Results [2]: [i_category#167, sum(coalesce((ss_sales_price#177 * cast(ss_quantity#178 as decimal(10,0))), 0.00))#22 AS sumsales#179] (54) HashAggregate [codegen id : 46] -Input [2]: [i_category#159, sumsales#171] -Keys [1]: [i_category#159] -Functions [1]: [partial_sum(sumsales#171)] -Aggregate Attributes [2]: [sum#172, isEmpty#173] -Results [3]: [i_category#159, sum#174, isEmpty#175] +Input [2]: [i_category#167, sumsales#179] +Keys [1]: [i_category#167] +Functions [1]: [partial_sum(sumsales#179)] +Aggregate Attributes [2]: [sum#180, isEmpty#181] +Results [3]: [i_category#167, sum#182, isEmpty#183] (55) Exchange -Input [3]: [i_category#159, sum#174, isEmpty#175] -Arguments: hashpartitioning(i_category#159, 5), ENSURE_REQUIREMENTS, [plan_id=10] +Input [3]: [i_category#167, sum#182, isEmpty#183] +Arguments: hashpartitioning(i_category#167, 5), ENSURE_REQUIREMENTS, [plan_id=10] (56) HashAggregate [codegen id : 47] -Input [3]: [i_category#159, sum#174, isEmpty#175] -Keys [1]: [i_category#159] 
-Functions [1]: [sum(sumsales#171)] -Aggregate Attributes [1]: [sum(sumsales#171)#176] -Results [9]: [i_category#159, null AS i_class#177, null AS i_brand#178, null AS i_product_name#179, null AS d_year#180, null AS d_qoy#181, null AS d_moy#182, null AS s_store_id#183, sum(sumsales#171)#176 AS sumsales#184] +Input [3]: [i_category#167, sum#182, isEmpty#183] +Keys [1]: [i_category#167] +Functions [1]: [sum(sumsales#179)] +Aggregate Attributes [1]: [sum(sumsales#179)#184] +Results [9]: [i_category#167, null AS i_class#185, null AS i_brand#186, null AS i_product_name#187, null AS d_year#188, null AS d_qoy#189, null AS d_moy#190, null AS s_store_id#191, sum(sumsales#179)#184 AS sumsales#192] (57) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#185, i_class#186, i_brand#187, i_product_name#188, d_year#189, d_qoy#190, d_moy#191, s_store_id#192, sum#193, isEmpty#194] +Output [10]: [i_category#193, i_class#194, i_brand#195, i_product_name#196, d_year#197, d_qoy#198, d_moy#199, s_store_id#200, sum#201, isEmpty#202] (58) HashAggregate [codegen id : 52] -Input [10]: [i_category#185, i_class#186, i_brand#187, i_product_name#188, d_year#189, d_qoy#190, d_moy#191, s_store_id#192, sum#193, isEmpty#194] -Keys [8]: [i_category#185, i_class#186, i_brand#187, i_product_name#188, d_year#189, d_qoy#190, d_moy#191, s_store_id#192] -Functions [1]: [sum(coalesce((ss_sales_price#195 * cast(ss_quantity#196 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#195 * cast(ss_quantity#196 as decimal(10,0))), 0.00))#22] -Results [1]: [sum(coalesce((ss_sales_price#195 * cast(ss_quantity#196 as decimal(10,0))), 0.00))#22 AS sumsales#197] +Input [10]: [i_category#193, i_class#194, i_brand#195, i_product_name#196, d_year#197, d_qoy#198, d_moy#199, s_store_id#200, sum#201, isEmpty#202] +Keys [8]: [i_category#193, i_class#194, i_brand#195, i_product_name#196, d_year#197, d_qoy#198, d_moy#199, s_store_id#200] +Functions [1]: [sum(coalesce((ss_sales_price#203 * cast(ss_quantity#204 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#203 * cast(ss_quantity#204 as decimal(10,0))), 0.00))#22] +Results [1]: [sum(coalesce((ss_sales_price#203 * cast(ss_quantity#204 as decimal(10,0))), 0.00))#22 AS sumsales#205] (59) HashAggregate [codegen id : 52] -Input [1]: [sumsales#197] +Input [1]: [sumsales#205] Keys: [] -Functions [1]: [partial_sum(sumsales#197)] -Aggregate Attributes [2]: [sum#198, isEmpty#199] -Results [2]: [sum#200, isEmpty#201] +Functions [1]: [partial_sum(sumsales#205)] +Aggregate Attributes [2]: [sum#206, isEmpty#207] +Results [2]: [sum#208, isEmpty#209] (60) Exchange -Input [2]: [sum#200, isEmpty#201] +Input [2]: [sum#208, isEmpty#209] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [plan_id=11] (61) HashAggregate [codegen id : 53] -Input [2]: [sum#200, isEmpty#201] +Input [2]: [sum#208, isEmpty#209] Keys: [] -Functions [1]: [sum(sumsales#197)] -Aggregate Attributes [1]: [sum(sumsales#197)#202] -Results [9]: [null AS i_category#203, null AS i_class#204, null AS i_brand#205, null AS i_product_name#206, null AS d_year#207, null AS d_qoy#208, null AS d_moy#209, null AS s_store_id#210, sum(sumsales#197)#202 AS sumsales#211] +Functions [1]: [sum(sumsales#205)] +Aggregate Attributes [1]: [sum(sumsales#205)#210] +Results [9]: [null AS i_category#211, null AS i_class#212, null AS i_brand#213, null AS i_product_name#214, null AS d_year#215, null AS d_qoy#216, null AS d_moy#217, null AS s_store_id#218, sum(sumsales#205)#210 AS sumsales#219] (62) 
Union (63) Sort [codegen id : 54] -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [i_category#16 ASC NULLS FIRST, sumsales#23 DESC NULLS LAST], false, 0 +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [i_category#23 ASC NULLS FIRST, sumsales#31 DESC NULLS LAST], false, 0 (64) WindowGroupLimit -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [i_category#16], [sumsales#23 DESC NULLS LAST], rank(sumsales#23), 100, Partial +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [i_category#23], [sumsales#31 DESC NULLS LAST], rank(sumsales#31), 100, Partial (65) Exchange -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: hashpartitioning(i_category#16, 5), ENSURE_REQUIREMENTS, [plan_id=12] +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: hashpartitioning(i_category#23, 5), ENSURE_REQUIREMENTS, [plan_id=12] (66) Sort [codegen id : 55] -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [i_category#16 ASC NULLS FIRST, sumsales#23 DESC NULLS LAST], false, 0 +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [i_category#23 ASC NULLS FIRST, sumsales#31 DESC NULLS LAST], false, 0 (67) WindowGroupLimit -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [i_category#16], [sumsales#23 DESC NULLS LAST], rank(sumsales#23), 100, Final +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [i_category#23], [sumsales#31 DESC NULLS LAST], rank(sumsales#31), 100, Final (68) Window -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [rank(sumsales#23) windowspecdefinition(i_category#16, sumsales#23 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#212], [i_category#16], [sumsales#23 DESC NULLS LAST] +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [rank(sumsales#31) windowspecdefinition(i_category#23, sumsales#31 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#220], [i_category#23], [sumsales#31 DESC NULLS LAST] (69) Filter [codegen id : 56] -Input [10]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23, rk#212] -Condition : (rk#212 <= 100) +Input [10]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31, rk#220] +Condition : (rk#220 <= 100) (70) TakeOrderedAndProject -Input [10]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23, rk#212] -Arguments: 100, [i_category#16 ASC 
NULLS FIRST, i_class#15 ASC NULLS FIRST, i_brand#14 ASC NULLS FIRST, i_product_name#17 ASC NULLS FIRST, d_year#8 ASC NULLS FIRST, d_qoy#10 ASC NULLS FIRST, d_moy#9 ASC NULLS FIRST, s_store_id#12 ASC NULLS FIRST, sumsales#23 ASC NULLS FIRST, rk#212 ASC NULLS FIRST], [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23, rk#212] +Input [10]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31, rk#220] +Arguments: 100, [i_category#23 ASC NULLS FIRST, i_class#24 ASC NULLS FIRST, i_brand#25 ASC NULLS FIRST, i_product_name#26 ASC NULLS FIRST, d_year#27 ASC NULLS FIRST, d_qoy#28 ASC NULLS FIRST, d_moy#29 ASC NULLS FIRST, s_store_id#30 ASC NULLS FIRST, sumsales#31 ASC NULLS FIRST, rk#220 ASC NULLS FIRST], [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31, rk#220] ===== Subqueries ===== @@ -442,22 +442,22 @@ BroadcastExchange (75) (71) Scan parquet spark_catalog.default.date_dim -Output [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] +Output [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,1212), LessThanOrEqual(d_month_seq,1223), IsNotNull(d_date_sk)] ReadSchema: struct (72) ColumnarToRow [codegen id : 1] -Input [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] +Input [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] (73) Filter [codegen id : 1] -Input [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] -Condition : (((isnotnull(d_month_seq#213) AND (d_month_seq#213 >= 1212)) AND (d_month_seq#213 <= 1223)) AND isnotnull(d_date_sk#7)) +Input [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] +Condition : (((isnotnull(d_month_seq#221) AND (d_month_seq#221 >= 1212)) AND (d_month_seq#221 <= 1223)) AND isnotnull(d_date_sk#7)) (74) Project [codegen id : 1] Output [4]: [d_date_sk#7, d_year#8, d_moy#9, d_qoy#10] -Input [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] +Input [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] (75) BroadcastExchange Input [4]: [d_date_sk#7, d_year#8, d_moy#9, d_qoy#10] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/simplified.txt index 89393f265a49f..5a43dced056bd 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/simplified.txt @@ -14,7 +14,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ InputAdapter Union WholeStageCodegen (5) - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce((ss_sales_price * cast(ss_quantity as decimal(10,0))), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce((ss_sales_price * cast(ss_quantity as decimal(10,0))), 0.00)),i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sumsales,sum,isEmpty] InputAdapter Exchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id] #2 WholeStageCodegen (4) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala index 7b608b7438c29..7a2ce1d7836b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala @@ -714,6 +714,27 @@ abstract class CTEInlineSuiteBase |""".stripMargin) checkAnswer(df, Row(1)) } + + test("SPARK-49816: should only update out-going-ref-count for referenced outer CTE relation") { + withView("v") { + sql( + """ + |WITH + |t1 AS (SELECT 1 col), + |t2 AS (SELECT * FROM t1) + |SELECT * FROM t2 + |""".stripMargin).createTempView("v") + // r1 is un-referenced, but it should not decrease the ref count of t2 inside view v. + val df = sql( + """ + |WITH + |r1 AS (SELECT * FROM v), + |r2 AS (SELECT * FROM v) + |SELECT * FROM r2 + |""".stripMargin) + checkAnswer(df, Row(1)) + } + } } class CTEInlineSuiteAEOff extends CTEInlineSuiteBase with DisableAdaptiveExecutionSuite diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 2342722c0bb14..8600ec4f8787f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.sql.Timestamp import org.apache.spark.{SparkFunSuite, SparkRuntimeException} +import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.variant.ParseJson import org.apache.spark.sql.internal.SqlApiConf @@ -46,7 +47,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * * @param inputEntry - List of all input entries that need to be generated * @param collationType - Flag defining collation type to use - * @return + * @return - List of data generated for expression instance creation */ def generateData( inputEntry: Seq[Any], @@ -54,33 +55,21 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi inputEntry.map(generateSingleEntry(_, collationType)) } - /** - * Helper function to generate single entry of data as a string. - * @param inputEntry - Single input entry that requires generation - * @param collationType - Flag defining collation type to use - * @return - */ - def generateDataAsStrings( - inputEntry: Seq[AbstractDataType], - collationType: CollationType): Seq[Any] = { - inputEntry.map(generateInputAsString(_, collationType)) - } - /** * Helper function to generate single entry of data. 
* @param inputEntry - Single input entry that requires generation * @param collationType - Flag defining collation type to use - * @return + * @return - Single input entry data */ def generateSingleEntry( inputEntry: Any, collationType: CollationType): Any = inputEntry match { case e: Class[_] if e.isAssignableFrom(classOf[Expression]) => - generateLiterals(StringTypeAnyCollation, collationType) + generateLiterals(StringTypeWithCaseAccentSensitivity, collationType) case se: Class[_] if se.isAssignableFrom(classOf[Seq[Expression]]) => - CreateArray(Seq(generateLiterals(StringTypeAnyCollation, collationType), - generateLiterals(StringTypeAnyCollation, collationType))) + CreateArray(Seq(generateLiterals(StringTypeWithCaseAccentSensitivity, collationType), + generateLiterals(StringTypeWithCaseAccentSensitivity, collationType))) case oe: Class[_] if oe.isAssignableFrom(classOf[Option[Any]]) => None case b: Class[_] if b.isAssignableFrom(classOf[Boolean]) => false case dt: Class[_] if dt.isAssignableFrom(classOf[DataType]) => StringType @@ -100,7 +89,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * * @param inputType - Single input literal type that requires generation * @param collationType - Flag defining collation type to use - * @return + * @return - Literal/Expression containing expression ready for evaluation */ def generateLiterals( inputType: AbstractDataType, @@ -116,6 +105,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } case BooleanType => Literal(true) case _: DatetimeType => Literal(Timestamp.valueOf("2009-07-30 12:58:59")) + case DecimalType => Literal((new Decimal).set(5)) case _: DecimalType => Literal((new Decimal).set(5)) case _: DoubleType => Literal(5.0) case IntegerType | NumericType | IntegralType => Literal(5) @@ -152,21 +142,26 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) ).head case ArrayType => - generateLiterals(StringTypeAnyCollation, collationType).map( + generateLiterals(StringTypeWithCaseAccentSensitivity, collationType).map( lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) ).head case MapType => - val key = generateLiterals(StringTypeAnyCollation, collationType) - val value = generateLiterals(StringTypeAnyCollation, collationType) - Literal.create(Map(key -> value)) + val key = generateLiterals(StringTypeWithCaseAccentSensitivity, collationType) + val value = generateLiterals(StringTypeWithCaseAccentSensitivity, collationType) + CreateMap(Seq(key, value)) case MapType(keyType, valueType, _) => val key = generateLiterals(keyType, collationType) val value = generateLiterals(valueType, collationType) - Literal.create(Map(key -> value)) + CreateMap(Seq(key, value)) + case AbstractMapType(keyType, valueType) => + val key = generateLiterals(keyType, collationType) + val value = generateLiterals(valueType, collationType) + CreateMap(Seq(key, value)) case StructType => CreateNamedStruct( - Seq(Literal("start"), generateLiterals(StringTypeAnyCollation, collationType), - Literal("end"), generateLiterals(StringTypeAnyCollation, collationType))) + Seq(Literal("start"), + generateLiterals(StringTypeWithCaseAccentSensitivity, collationType), + Literal("end"), generateLiterals(StringTypeWithCaseAccentSensitivity, collationType))) } /** @@ -174,7 +169,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * * @param 
inputType - Single input type that requires generation * @param collationType - Flag defining collation type to use - * @return + * @return - String representation of a input ready for SQL query */ def generateInputAsString( inputType: AbstractDataType, @@ -189,6 +184,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } case BooleanType => "True" case _: DatetimeType => "date'2016-04-08'" + case DecimalType => "5.0" case _: DecimalType => "5.0" case _: DoubleType => "5.0" case IntegerType | NumericType | IntegralType => "5" @@ -214,16 +210,20 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case ArrayType(elementType, _) => "array(" + generateInputAsString(elementType, collationType) + ")" case ArrayType => - "array(" + generateInputAsString(StringTypeAnyCollation, collationType) + ")" + "array(" + generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ")" case MapType => - "map(" + generateInputAsString(StringTypeAnyCollation, collationType) + ", " + - generateInputAsString(StringTypeAnyCollation, collationType) + ")" + "map(" + generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ", " + + generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ")" case MapType(keyType, valueType, _) => "map(" + generateInputAsString(keyType, collationType) + ", " + generateInputAsString(valueType, collationType) + ")" + case AbstractMapType(keyType, valueType) => + "map(" + generateInputAsString(keyType, collationType) + ", " + + generateInputAsString(valueType, collationType) + ")" case StructType => - "named_struct( 'start', " + generateInputAsString(StringTypeAnyCollation, collationType) + - ", 'end', " + generateInputAsString(StringTypeAnyCollation, collationType) + ")" + "named_struct( 'start', " + + generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ", 'end', " + + generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ")" case StructType(fields) => "named_struct(" + fields.map(f => "'" + f.name + "', " + generateInputAsString(f.dataType, collationType)).mkString(", ") + ")" @@ -234,7 +234,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * * @param inputType - Single input type that requires generation * @param collationType - Flag defining collation type to use - * @return + * @return - String representation for SQL query of a inputType */ def generateInputTypeAsStrings( inputType: AbstractDataType, @@ -244,6 +244,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case BinaryType => "BINARY" case BooleanType => "BOOLEAN" case _: DatetimeType => "DATE" + case DecimalType => "DECIMAL(2, 1)" case _: DecimalType => "DECIMAL(2, 1)" case _: DoubleType => "DOUBLE" case IntegerType | NumericType | IntegralType => "INT" @@ -268,17 +269,23 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case ArrayType(elementType, _) => "array<" + generateInputTypeAsStrings(elementType, collationType) + ">" case ArrayType => - "array<" + generateInputTypeAsStrings(StringTypeAnyCollation, collationType) + ">" + "array<" + generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) + + ">" case MapType => - "map<" + generateInputTypeAsStrings(StringTypeAnyCollation, collationType) + ", " + - generateInputTypeAsStrings(StringTypeAnyCollation, collationType) + ">" + "map<" + 
generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) + + ", " + + generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) + ">" case MapType(keyType, valueType, _) => "map<" + generateInputTypeAsStrings(keyType, collationType) + ", " + generateInputTypeAsStrings(valueType, collationType) + ">" + case AbstractMapType(keyType, valueType) => + "map<" + generateInputTypeAsStrings(keyType, collationType) + ", " + + generateInputTypeAsStrings(valueType, collationType) + ">" case StructType => - "struct" + generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) + ">" case StructType(fields) => "named_struct<" + fields.map(f => "'" + f.name + "', " + generateInputTypeAsStrings(f.dataType, collationType)).mkString(", ") + ">" @@ -287,12 +294,12 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi /** * Helper function to extract types of relevance * @param inputType - * @return + * @return - Boolean that represents if inputType has/is a StringType */ def hasStringType(inputType: AbstractDataType): Boolean = { inputType match { - case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType => - true + case _: StringType | StringTypeWithCaseAccentSensitivity | StringTypeBinaryLcase | AnyDataType + => true case ArrayType => true case MapType => true case MapType(keyType, valueType, _) => hasStringType(keyType) || hasStringType(valueType) @@ -300,7 +307,6 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case AbstractArrayType(elementType) => hasStringType(elementType) case TypeCollection(typeCollection) => typeCollection.exists(hasStringType) - case StructType => true case StructType(fields) => fields.exists(sf => hasStringType(sf.dataType)) case _ => false } @@ -310,7 +316,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * Helper function to replace expected parameters with expected input types. * @param inputTypes - Input types generated by ExpectsInputType.inputTypes * @param params - Parameters that are read from expression info - * @return + * @return - List of parameters where Expressions are replaced with input types */ def replaceExpressions(inputTypes: Seq[AbstractDataType], params: Seq[Class[_]]): Seq[Any] = { (inputTypes, params) match { @@ -325,7 +331,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi /** * Helper method to extract relevant expressions that can be walked over. - * @return + * @return - (List of relevant expressions that expect input, List of expressions to skip) */ def extractRelevantExpressions(): (Array[ExpressionInfo], List[String]) = { var expressionCounter = 0 @@ -384,6 +390,47 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi (funInfos, toSkip) } + /** + * Helper method to extract relevant expressions that can be walked over but are built with + * expression builder. + * + * @return - (List of expressions that are relevant builders, List of expressions to skip) + */ + def extractRelevantBuilders(): (Array[ExpressionInfo], List[String]) = { + var builderExpressionCounter = 0 + val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId => + spark.sessionState.catalog.lookupFunctionInfo(funcId) + }.filter(funInfo => { + // make sure that there is a constructor. 
+ val cl = Utils.classForName(funInfo.getClassName) + cl.isAssignableFrom(classOf[ExpressionBuilder]) + }).filter(funInfo => { + builderExpressionCounter = builderExpressionCounter + 1 + val cl = Utils.classForName(funInfo.getClassName) + val method = cl.getMethod("build", + Utils.classForName("java.lang.String"), + Utils.classForName("scala.collection.Seq")) + var input: Seq[Expression] = Seq.empty + var i = 0 + for (_ <- 1 to 10) { + input = input :+ generateLiterals(StringTypeWithCaseAccentSensitivity, Utf8Binary) + try { + method.invoke(null, funInfo.getClassName, input).asInstanceOf[ExpectsInputTypes] + } + catch { + case _: Exception => i = i + 1 + } + } + if (i == 10) false + else true + }).toArray + + logInfo("Total number of expression that are built: " + builderExpressionCounter) + logInfo("Number of extracted expressions of relevance: " + funInfos.length) + + (funInfos, List()) + } + /** * Helper function to generate string of an expression suitable for execution. * @param expr - Expression that needs to be converted @@ -441,10 +488,36 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * 5) Otherwise, check if exceptions are the same */ test("SPARK-48280: Expression Walker for expression evaluation") { - val (funInfos, toSkip) = extractRelevantExpressions() + val (funInfosExpr, toSkip) = extractRelevantExpressions() + val (funInfosBuild, _) = extractRelevantBuilders() + val funInfos = funInfosExpr ++ funInfosBuild for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { - val cl = Utils.classForName(f.getClassName) + val TempCl = Utils.classForName(f.getClassName) + val cl = if (TempCl.isAssignableFrom(classOf[ExpressionBuilder])) { + val clTemp = Utils.classForName(f.getClassName) + val method = clTemp.getMethod("build", + Utils.classForName("java.lang.String"), + Utils.classForName("scala.collection.Seq")) + val instance = { + var input: Seq[Expression] = Seq.empty + var result: Expression = null + for (_ <- 1 to 10) { + input = input :+ generateLiterals(StringTypeWithCaseAccentSensitivity, Utf8Binary) + try { + val tempResult = method.invoke(null, f.getClassName, input) + if (result == null) result = tempResult.asInstanceOf[Expression] + } + catch { + case _: Exception => + } + } + result + } + instance.getClass + } + else Utils.classForName(f.getClassName) + val headConstructor = cl.getConstructors .zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => a._2)._1 val params = headConstructor.getParameters.map(p => p.getType) @@ -526,10 +599,36 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * 5) Otherwise, check if exceptions are the same */ test("SPARK-48280: Expression Walker for codeGen generation") { - val (funInfos, toSkip) = extractRelevantExpressions() + val (funInfosExpr, toSkip) = extractRelevantExpressions() + val (funInfosBuild, _) = extractRelevantBuilders() + val funInfos = funInfosExpr ++ funInfosBuild for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { - val cl = Utils.classForName(f.getClassName) + val TempCl = Utils.classForName(f.getClassName) + val cl = if (TempCl.isAssignableFrom(classOf[ExpressionBuilder])) { + val clTemp = Utils.classForName(f.getClassName) + val method = clTemp.getMethod("build", + Utils.classForName("java.lang.String"), + Utils.classForName("scala.collection.Seq")) + val instance = { + var input: Seq[Expression] = Seq.empty + var result: Expression = null + for (_ <- 1 to 10) { + input = input :+ 
generateLiterals(StringTypeWithCaseAccentSensitivity, Utf8Binary) + try { + val tempResult = method.invoke(null, f.getClassName, input) + if (result == null) result = tempResult.asInstanceOf[Expression] + } + catch { + case _: Exception => + } + } + result + } + instance.getClass + } + else Utils.classForName(f.getClassName) + val headConstructor = cl.getConstructors .zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => a._2)._1 val params = headConstructor.getParameters.map(p => p.getType) @@ -642,7 +741,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi assert(resultUTF8.collect() === resultUTF8Lcase.collect()) } } catch { - case e: SparkRuntimeException => assert(e.getErrorClass == "USER_RAISED_EXCEPTION") + case e: SparkRuntimeException => assert(e.getCondition == "USER_RAISED_EXCEPTION") case other: Throwable => throw other } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index f8cd840ecdbb9..ce6818652d2b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -23,7 +23,7 @@ import java.text.SimpleDateFormat import scala.collection.immutable.Seq import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException} -import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Mode import org.apache.spark.sql.internal.{SqlApiConf, SQLConf} @@ -49,9 +49,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( Md5TestCase("Spark", "UTF8_BINARY", "8cde774d6f7333752ed72cacddb05126"), + Md5TestCase("Spark", "UTF8_BINARY_RTRIM", "8cde774d6f7333752ed72cacddb05126"), Md5TestCase("Spark", "UTF8_LCASE", "8cde774d6f7333752ed72cacddb05126"), + Md5TestCase("Spark", "UTF8_LCASE_RTRIM", "8cde774d6f7333752ed72cacddb05126"), Md5TestCase("SQL", "UNICODE", "9778840a0100cb30c982876741b0b5a2"), - Md5TestCase("SQL", "UNICODE_CI", "9778840a0100cb30c982876741b0b5a2") + Md5TestCase("SQL", "UNICODE_RTRIM", "9778840a0100cb30c982876741b0b5a2"), + Md5TestCase("SQL", "UNICODE_CI", "9778840a0100cb30c982876741b0b5a2"), + Md5TestCase("SQL", "UNICODE_CI_RTRIM", "9778840a0100cb30c982876741b0b5a2") ) // Supported collations @@ -81,11 +85,19 @@ class CollationSQLExpressionsSuite val testCases = Seq( Sha2TestCase("Spark", "UTF8_BINARY", 256, "529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b"), + Sha2TestCase("Spark", "UTF8_BINARY_RTRIM", 256, + "529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b"), Sha2TestCase("Spark", "UTF8_LCASE", 256, "529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b"), + Sha2TestCase("Spark", "UTF8_LCASE_RTRIM", 256, + "529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b"), Sha2TestCase("SQL", "UNICODE", 256, "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35"), + Sha2TestCase("SQL", "UNICODE_RTRIM", 256, + "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35"), Sha2TestCase("SQL", "UNICODE_CI", 256, + "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35"), + Sha2TestCase("SQL", "UNICODE_CI_RTRIM", 256, "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35") ) @@ 
-114,9 +126,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( Sha1TestCase("Spark", "UTF8_BINARY", "85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c"), + Sha1TestCase("Spark", "UTF8_BINARY_RTRIM", "85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c"), Sha1TestCase("Spark", "UTF8_LCASE", "85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c"), + Sha1TestCase("Spark", "UTF8_LCASE_RTRIM", "85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c"), Sha1TestCase("SQL", "UNICODE", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d"), - Sha1TestCase("SQL", "UNICODE_CI", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d") + Sha1TestCase("SQL", "UNICODE_RTRIM", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d"), + Sha1TestCase("SQL", "UNICODE_CI", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d"), + Sha1TestCase("SQL", "UNICODE_CI_RTRIM", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d") ) // Supported collations @@ -144,9 +160,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( Crc321TestCase("Spark", "UTF8_BINARY", 1557323817), + Crc321TestCase("Spark", "UTF8_BINARY_RTRIM", 1557323817), Crc321TestCase("Spark", "UTF8_LCASE", 1557323817), + Crc321TestCase("Spark", "UTF8_LCASE_RTRIM", 1557323817), Crc321TestCase("SQL", "UNICODE", 1299261525), - Crc321TestCase("SQL", "UNICODE_CI", 1299261525) + Crc321TestCase("SQL", "UNICODE_RTRIM", 1299261525), + Crc321TestCase("SQL", "UNICODE_CI", 1299261525), + Crc321TestCase("SQL", "UNICODE_CI_RTRIM", 1299261525) ) // Supported collations @@ -172,9 +192,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( Murmur3HashTestCase("Spark", "UTF8_BINARY", 228093765), + Murmur3HashTestCase("Spark ", "UTF8_BINARY_RTRIM", 1779328737), Murmur3HashTestCase("Spark", "UTF8_LCASE", -1928694360), + Murmur3HashTestCase("Spark ", "UTF8_LCASE_RTRIM", -1928694360), Murmur3HashTestCase("SQL", "UNICODE", -1923567940), - Murmur3HashTestCase("SQL", "UNICODE_CI", 1029527950) + Murmur3HashTestCase("SQL ", "UNICODE_RTRIM", -1923567940), + Murmur3HashTestCase("SQL", "UNICODE_CI", 1029527950), + Murmur3HashTestCase("SQL ", "UNICODE_CI_RTRIM", 1029527950) ) // Supported collations @@ -200,9 +224,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( XxHash64TestCase("Spark", "UTF8_BINARY", -4294468057691064905L), + XxHash64TestCase("Spark ", "UTF8_BINARY_RTRIM", 6480371823304753502L), XxHash64TestCase("Spark", "UTF8_LCASE", -3142112654825786434L), + XxHash64TestCase("Spark ", "UTF8_LCASE_RTRIM", -3142112654825786434L), XxHash64TestCase("SQL", "UNICODE", 5964849564945649886L), - XxHash64TestCase("SQL", "UNICODE_CI", 3732497619779520590L) + XxHash64TestCase("SQL ", "UNICODE_RTRIM", 5964849564945649886L), + XxHash64TestCase("SQL", "UNICODE_CI", 3732497619779520590L), + XxHash64TestCase("SQL ", "UNICODE_CI_RTRIM", 3732497619779520590L) ) // Supported collations @@ -982,6 +1010,11 @@ class CollationSQLExpressionsSuite StringToMapTestCase("1/AX2/BX3/C", "x", "/", "UNICODE_CI", Map("1" -> "A", "2" -> "B", "3" -> "C")) ) + val unsupportedTestCases = Seq( + StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null), + StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_RTRIM", null), + StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UTF8_BINARY_RTRIM", null), + StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UTF8_LCASE_RTRIM", null)) testCases.foreach(t => { // Unit test. val text = Literal.create(t.text, StringType(t.collation)) @@ -996,6 +1029,31 @@ class CollationSQLExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(dataType)) } }) + // Test unsupported collation. 
+ unsupportedTestCases.foreach(t => { + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> t.collation) { + val query = + s"select str_to_map('${t.text}', '${t.pairDelim}', " + + s"'${t.keyValueDelim}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"str_to_map('a:1,b:2,c:3' collate " + s"${t.collation}, " + + "'?' collate " + s"${t.collation}, '?' collate ${t.collation})" + "\""), + "paramIndex" -> "first", + "inputSql" -> ("\"'a:1,b:2,c:3' collate " + s"${t.collation}" + "\""), + "inputType" -> ("\"STRING COLLATE " + s"${t.collation}" + "\""), + "requiredType" -> "\"STRING\""), + context = ExpectedContext( + fragment = "str_to_map('a:1,b:2,c:3', '?', '?')", + start = 7, + stop = 41)) + } + }) } test("Support RaiseError misc expression with collation") { @@ -1728,7 +1786,7 @@ class CollationSQLExpressionsSuite UTF8StringModeTestCase("unicode_ci", bufferValuesUTF8String, "b"), UTF8StringModeTestCase("unicode", bufferValuesUTF8String, "a")) - testCasesUTF8String.foreach(t => { + testCasesUTF8String.foreach ( t => { val buffer = new OpenHashMap[AnyRef, Long](5) val myMode = Mode(child = Literal.create("some_column_name", StringType(t.collationId))) t.bufferValues.foreach { case (k, v) => buffer.update(k, v) } @@ -1736,6 +1794,40 @@ class CollationSQLExpressionsSuite }) } + test("Support Mode.eval(buffer) with complex types") { + case class UTF8StringModeTestCase[R]( + collationId: String, + bufferValues: Map[InternalRow, Long], + result: R) + + val bufferValuesUTF8String: Map[Any, Long] = Map( + UTF8String.fromString("a") -> 5L, + UTF8String.fromString("b") -> 4L, + UTF8String.fromString("B") -> 3L, + UTF8String.fromString("d") -> 2L, + UTF8String.fromString("e") -> 1L) + + val bufferValuesComplex = bufferValuesUTF8String.map{ + case (k, v) => (InternalRow.fromSeq(Seq(k, k, k)), v) + } + val testCasesUTF8String = Seq( + UTF8StringModeTestCase("utf8_binary", bufferValuesComplex, "[a,a,a]"), + UTF8StringModeTestCase("UTF8_LCASE", bufferValuesComplex, "[b,b,b]"), + UTF8StringModeTestCase("unicode_ci", bufferValuesComplex, "[b,b,b]"), + UTF8StringModeTestCase("unicode", bufferValuesComplex, "[a,a,a]")) + + testCasesUTF8String.foreach { t => + val buffer = new OpenHashMap[AnyRef, Long](5) + val myMode = Mode(child = Literal.create(null, StructType(Seq( + StructField("f1", StringType(t.collationId), true), + StructField("f2", StringType(t.collationId), true), + StructField("f3", StringType(t.collationId), true) + )))) + t.bufferValues.foreach { case (k, v) => buffer.update(k, v) } + assert(myMode.eval(buffer).toString.toLowerCase() == t.result.toLowerCase()) + } + } + test("Support mode for string expression with collated strings in struct") { case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) val testCases = Seq( @@ -1756,33 +1848,7 @@ class CollationSQLExpressionsSuite t.collationId + ", f2: INT>) USING parquet") sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) val query = s"SELECT lower(mode(i).f1) FROM ${tableName}" - if(t.collationId == "UTF8_LCASE" || - t.collationId == "unicode_ci" || - t.collationId == "unicode") { - // Cannot resolve "mode(i)" due to data type mismatch: - // Input to function mode was a complex type with strings collated on non-binary - // collations, which is not yet supported.. 
SQLSTATE: 42K09; line 1 pos 13; - val params = Seq(("sqlExpr", "\"mode(i)\""), - ("msg", "The input to the function 'mode'" + - " was a type of binary-unstable type that is not currently supported by mode."), - ("hint", "")).toMap - checkError( - exception = intercept[AnalysisException] { - sql(query) - }, - condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", - parameters = params, - queryContext = Array( - ExpectedContext(objectType = "", - objectName = "", - startIndex = 13, - stopIndex = 19, - fragment = "mode(i)") - ) - ) - } else { - checkAnswer(sql(query), Row(t.result)) - } + checkAnswer(sql(query), Row(t.result)) } }) } @@ -1795,47 +1861,21 @@ class CollationSQLExpressionsSuite ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") ) - testCases.foreach(t => { + testCases.foreach { t => val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => (0L to numRepeats).map(_ => s"named_struct('f1', " + s"named_struct('f2', collate('$elt', '${t.collationId}')), 'f3', 1)").mkString(",") }.mkString(",") - val tableName = s"t_${t.collationId}_mode_nested_struct" + val tableName = s"t_${t.collationId}_mode_nested_struct1" withTable(tableName) { sql(s"CREATE TABLE ${tableName}(i STRUCT, f3: INT>) USING parquet") sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) val query = s"SELECT lower(mode(i).f1.f2) FROM ${tableName}" - if(t.collationId == "UTF8_LCASE" || - t.collationId == "unicode_ci" || - t.collationId == "unicode") { - // Cannot resolve "mode(i)" due to data type mismatch: - // Input to function mode was a complex type with strings collated on non-binary - // collations, which is not yet supported.. SQLSTATE: 42K09; line 1 pos 13; - val params = Seq(("sqlExpr", "\"mode(i)\""), - ("msg", "The input to the function 'mode' " + - "was a type of binary-unstable type that is not currently supported by mode."), - ("hint", "")).toMap - checkError( - exception = intercept[AnalysisException] { - sql(query) - }, - condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", - parameters = params, - queryContext = Array( - ExpectedContext(objectType = "", - objectName = "", - startIndex = 13, - stopIndex = 19, - fragment = "mode(i)") - ) - ) - } else { - checkAnswer(sql(query), Row(t.result)) - } + checkAnswer(sql(query), Row(t.result)) } - }) + } } test("Support mode for string expression with collated strings in array complex type") { @@ -1846,44 +1886,105 @@ class CollationSQLExpressionsSuite ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") ) - testCases.foreach(t => { + testCases.foreach { t => + val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => + (0L to numRepeats).map(_ => s"array(named_struct('f2', " + + s"collate('$elt', '${t.collationId}'), 'f3', 1))").mkString(",") + }.mkString(",") + + val tableName = s"t_${t.collationId}_mode_nested_struct2" + withTable(tableName) { + sql(s"CREATE TABLE ${tableName}(" + + s"i ARRAY< STRUCT>)" + + s" USING parquet") + sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) + val query = s"SELECT lower(element_at(mode(i).f2, 1)) FROM ${tableName}" + checkAnswer(sql(query), Row(t.result)) + } + } + } + + test("Support mode for string expression with collated strings in 3D array type") { + case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) + val testCases = Seq( + ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 
2L, "B" -> 2L), "a"), + ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") + ) + testCases.foreach { t => + val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => + (0L to numRepeats).map(_ => + s"array(array(array(collate('$elt', '${t.collationId}'))))").mkString(",") + }.mkString(",") + + val tableName = s"t_${t.collationId}_mode_nested_3d_array" + withTable(tableName) { + sql(s"CREATE TABLE ${tableName}(i ARRAY>>) USING parquet") + sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) + val query = s"SELECT lower(" + + s"element_at(element_at(element_at(mode(i),1),1),1)) FROM ${tableName}" + checkAnswer(sql(query), Row(t.result)) + } + } + } + + test("Support mode for string expression with collated complex type - Highly nested") { + case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) + val testCases = Seq( + ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") + ) + testCases.foreach { t => val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => (0L to numRepeats).map(_ => s"array(named_struct('s1', named_struct('a2', " + s"array(collate('$elt', '${t.collationId}'))), 'f3', 1))").mkString(",") }.mkString(",") - val tableName = s"t_${t.collationId}_mode_nested_struct" + val tableName = s"t_${t.collationId}_mode_highly_nested_struct" withTable(tableName) { sql(s"CREATE TABLE ${tableName}(" + s"i ARRAY>, f3: INT>>)" + s" USING parquet") sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) val query = s"SELECT lower(element_at(element_at(mode(i), 1).s1.a2, 1)) FROM ${tableName}" - if(t.collationId == "UTF8_LCASE" || - t.collationId == "unicode_ci" || t.collationId == "unicode") { - val params = Seq(("sqlExpr", "\"mode(i)\""), - ("msg", "The input to the function 'mode' was a type" + - " of binary-unstable type that is not currently supported by mode."), - ("hint", "")).toMap - checkError( - exception = intercept[AnalysisException] { - sql(query) - }, - condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", - parameters = params, - queryContext = Array( - ExpectedContext(objectType = "", - objectName = "", - startIndex = 35, - stopIndex = 41, - fragment = "mode(i)") - ) - ) - } else { + checkAnswer(sql(query), Row(t.result)) - } } - }) + } + } + + test("Support mode for string expression with collated complex type - nested map") { + case class ModeTestCase(collationId: String, bufferValues: Map[String, Long], result: String) + Seq( + ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 1}"), + ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 1}"), + ModeTestCase("utf8_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 1}"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 1}") + ).foreach { t1 => + def getValuesToAdd(t: ModeTestCase): String = { + val valuesToAdd = t.bufferValues.map { + case (elt, numRepeats) => + (0L to numRepeats).map(i => + s"named_struct('m1', map(collate('$elt', '${t.collationId}'), 1))" + ).mkString(",") + }.mkString(",") + valuesToAdd + } + val tableName = s"t_${t1.collationId}_mode_nested_map_struct1" + withTable(tableName) { + sql(s"CREATE 
TABLE ${tableName}(" + + s"i STRUCT>) USING parquet") + sql(s"INSERT INTO ${tableName} VALUES ${getValuesToAdd(t1)}") + val query = "SELECT lower(cast(mode(i).m1 as string))" + + s" FROM ${tableName}" + val queryResult = sql(query) + checkAnswer(queryResult, Row(t1.result)) + } + } } test("SPARK-48430: Map value extraction with collations") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 6804411d470b9..fe9872ddaf575 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -98,6 +98,7 @@ class CollationStringExpressionsSuite SplitPartTestCase("1a2", "A", 2, "UTF8_LCASE", "2"), SplitPartTestCase("1a2", "A", 2, "UNICODE_CI", "2") ) + val unsupportedTestCase = SplitPartTestCase("1a2", "a", 2, "UNICODE_AI", "2") testCases.foreach(t => { // Unit test. val str = Literal.create(t.str, StringType(t.collation)) @@ -111,6 +112,26 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select split_part('${unsupportedTestCase.str}', '${unsupportedTestCase.delimiter}', " + + s"${unsupportedTestCase.partNum})" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"split_part('1a2' collate UNICODE_AI, 'a' collate UNICODE_AI, 2)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'1a2' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "split_part('1a2', 'a', 2)", start = 7, stop = 31) + ) + } } test("Support `StringSplitSQL` string expression with collation") { @@ -166,6 +187,7 @@ class CollationStringExpressionsSuite ContainsTestCase("abcde", "FGH", "UTF8_LCASE", false), ContainsTestCase("abcde", "BCD", "UNICODE_CI", true) ) + val unsupportedTestCase = ContainsTestCase("abcde", "A", "UNICODE_AI", false) testCases.foreach(t => { // Unit test. val left = Literal.create(t.left, StringType(t.collation)) @@ -178,6 +200,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) } }) + // Test unsupported collation. 
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select contains('${unsupportedTestCase.left}', '${unsupportedTestCase.right}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"contains('abcde' collate UNICODE_AI, 'A' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'abcde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "contains('abcde', 'A')", start = 7, stop = 28) + ) + } } test("Support `SubstringIndex` expression with collation") { @@ -194,6 +235,7 @@ class CollationStringExpressionsSuite SubstringIndexTestCase("aaaaaaaaaa", "aa", 2, "UNICODE", "a"), SubstringIndexTestCase("wwwmapacheMorg", "M", -2, "UNICODE_CI", "apacheMorg") ) + val unsupportedTestCase = SubstringIndexTestCase("abacde", "a", 2, "UNICODE_AI", "cde") testCases.foreach(t => { // Unit test. val strExpr = Literal.create(t.strExpr, StringType(t.collation)) @@ -207,6 +249,29 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select substring_index('${unsupportedTestCase.strExpr}', " + + s"'${unsupportedTestCase.delimExpr}', ${unsupportedTestCase.countExpr})" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"substring_index('abacde' collate UNICODE_AI, " + + "'a' collate UNICODE_AI, 2)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'abacde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext( + fragment = "substring_index('abacde', 'a', 2)", + start = 7, + stop = 39)) + } } test("Support `StringInStr` string expression with collation") { @@ -219,6 +284,7 @@ class CollationStringExpressionsSuite StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE_CI", 8), StringInStrTestCase("abİo12", "i̇o", "UNICODE_CI", 3) ) + val unsupportedTestCase = StringInStrTestCase("a", "abcde", "UNICODE_AI", 0) testCases.foreach(t => { // Unit test. val str = Literal.create(t.str, StringType(t.collation)) @@ -231,6 +297,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) } }) + // Test unsupported collation. 
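// Illustrative sketch, not part of the original patch (assumes the suite's helpers): for the
// collations that instr does support, matching stays collation-aware, e.g. a case-insensitive
// search reports the 1-based position of the needle.
checkAnswer(
  sql("select instr(collate('test大千世界X大千世界', 'UNICODE_CI'), '界x')"),
  Row(8))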
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select instr('${unsupportedTestCase.str}', '${unsupportedTestCase.substr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"instr('a' collate UNICODE_AI, 'abcde' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'a' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "instr('a', 'abcde')", start = 7, stop = 25) + ) + } } test("Support `FindInSet` string expression with collation") { @@ -264,6 +349,7 @@ class CollationStringExpressionsSuite StartsWithTestCase("abcde", "FGH", "UTF8_LCASE", false), StartsWithTestCase("abcde", "ABC", "UNICODE_CI", true) ) + val unsupportedTestCase = StartsWithTestCase("abcde", "A", "UNICODE_AI", false) testCases.foreach(t => { // Unit test. val left = Literal.create(t.left, StringType(t.collation)) @@ -276,6 +362,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select startswith('${unsupportedTestCase.left}', '${unsupportedTestCase.right}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"startswith('abcde' collate UNICODE_AI, 'A' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'abcde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "startswith('abcde', 'A')", start = 7, stop = 30) + ) + } } test("Support `StringTranslate` string expression with collation") { @@ -291,6 +396,7 @@ class CollationStringExpressionsSuite StringTranslateTestCase("Translate", "Rn", "\u0000\u0000", "UNICODE", "Traslate"), StringTranslateTestCase("Translate", "Rn", "1234", "UNICODE_CI", "T1a2slate") ) + val unsupportedTestCase = StringTranslateTestCase("ABC", "AB", "12", "UNICODE_AI", "12C") testCases.foreach(t => { // Unit test. val srcExpr = Literal.create(t.srcExpr, StringType(t.collation)) @@ -304,6 +410,27 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. 
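// Illustrative sketch, not part of the original patch (assumes the suite's helpers): for
// supported collations, translate matches source characters using the collation, so under
// UNICODE_CI both 'R' and 'r' are rewritten to '1' and 'n' to '2', while replacement
// characters without a matching source character ('3', '4') are left unused.
checkAnswer(
  sql("select translate(collate('Translate', 'UNICODE_CI'), 'Rn', '1234')"),
  Row("T1a2slate"))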
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select translate('${unsupportedTestCase.srcExpr}', " + + s"'${unsupportedTestCase.matchingExpr}', '${unsupportedTestCase.replaceExpr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"translate('ABC' collate UNICODE_AI, 'AB' collate UNICODE_AI, " + + "'12' collate UNICODE_AI)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'ABC' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "translate('ABC', 'AB', '12')", start = 7, stop = 34) + ) + } } test("Support `StringReplace` string expression with collation") { @@ -321,6 +448,7 @@ class CollationStringExpressionsSuite StringReplaceTestCase("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy"), StringReplaceTestCase("abİo12i̇o", "i̇o", "xx", "UNICODE_CI", "abxx12xx") ) + val unsupportedTestCase = StringReplaceTestCase("abcde", "A", "B", "UNICODE_AI", "abcde") testCases.foreach(t => { // Unit test. val srcExpr = Literal.create(t.srcExpr, StringType(t.collation)) @@ -334,6 +462,27 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select replace('${unsupportedTestCase.srcExpr}', '${unsupportedTestCase.searchExpr}', " + + s"'${unsupportedTestCase.replaceExpr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"replace('abcde' collate UNICODE_AI, 'A' collate UNICODE_AI, " + + "'B' collate UNICODE_AI)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'abcde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "replace('abcde', 'A', 'B')", start = 7, stop = 32) + ) + } } test("Support `EndsWith` string expression with collation") { @@ -344,6 +493,7 @@ class CollationStringExpressionsSuite EndsWithTestCase("abcde", "FGH", "UTF8_LCASE", false), EndsWithTestCase("abcde", "CDE", "UNICODE_CI", true) ) + val unsupportedTestCase = EndsWithTestCase("abcde", "A", "UNICODE_AI", false) testCases.foreach(t => { // Unit test. val left = Literal.create(t.left, StringType(t.collation)) @@ -355,6 +505,25 @@ class CollationStringExpressionsSuite checkAnswer(sql(query), Row(t.result)) assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) } + // Test unsupported collation. 
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select endswith('${unsupportedTestCase.left}', '${unsupportedTestCase.right}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"endswith('abcde' collate UNICODE_AI, 'A' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'abcde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "endswith('abcde', 'A')", start = 7, stop = 28) + ) + } }) } @@ -1097,6 +1266,7 @@ class CollationStringExpressionsSuite StringLocateTestCase("aa", "Aaads", 0, "UNICODE_CI", 0), StringLocateTestCase("界x", "test大千世界X大千世界", 1, "UNICODE_CI", 8) ) + val unsupportedTestCase = StringLocateTestCase("aa", "Aaads", 0, "UNICODE_AI", 1) testCases.foreach(t => { // Unit test. val substr = Literal.create(t.substr, StringType(t.collation)) @@ -1110,6 +1280,26 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select locate('${unsupportedTestCase.substr}', '${unsupportedTestCase.str}', " + + s"${unsupportedTestCase.start})" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"locate('aa' collate UNICODE_AI, 'Aaads' collate UNICODE_AI, 0)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'aa' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "locate('aa', 'Aaads', 0)", start = 7, stop = 30) + ) + } } test("Support `StringTrimLeft` string expression with collation") { @@ -1124,6 +1314,7 @@ class CollationStringExpressionsSuite StringTrimLeftTestCase("xxasdxx", Some("y"), "UNICODE", "xxasdxx"), StringTrimLeftTestCase(" asd ", None, "UNICODE_CI", "asd ") ) + val unsupportedTestCase = StringTrimLeftTestCase("xxasdxx", Some("x"), "UNICODE_AI", null) testCases.foreach(t => { // Unit test. val srcStr = Literal.create(t.srcStr, StringType(t.collation)) @@ -1137,6 +1328,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. 
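// Illustrative sketch, not part of the original patch (assumes the suite's helpers): note the
// SQL argument order for the trim family, ltrim(trimStr, srcStr), which the analyzer renders
// as TRIM(LEADING trimStr FROM srcStr) in the error messages checked below. The positive path
// for a supported collation:
checkAnswer(sql("select ltrim('x', collate('xxasdxx', 'UTF8_LCASE'))"), Row("asdxx"))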
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val trimString = s"'${unsupportedTestCase.trimStr.get}', " + val query = s"select ltrim($trimString'${unsupportedTestCase.srcStr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"TRIM(LEADING 'x' collate UNICODE_AI FROM 'xxasdxx' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'xxasdxx' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "ltrim('x', 'xxasdxx')", start = 7, stop = 27) + ) + } } test("Support `StringTrimRight` string expression with collation") { @@ -1151,6 +1361,7 @@ class CollationStringExpressionsSuite StringTrimRightTestCase("xxasdxx", Some("y"), "UNICODE", "xxasdxx"), StringTrimRightTestCase(" asd ", None, "UNICODE_CI", " asd") ) + val unsupportedTestCase = StringTrimRightTestCase("xxasdxx", Some("x"), "UNICODE_AI", "xxasd") testCases.foreach(t => { // Unit test. val srcStr = Literal.create(t.srcStr, StringType(t.collation)) @@ -1164,6 +1375,26 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val trimString = s"'${unsupportedTestCase.trimStr.get}', " + val query = s"select rtrim($trimString'${unsupportedTestCase.srcStr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"TRIM(TRAILING 'x' collate UNICODE_AI FROM 'xxasdxx'" + + " collate UNICODE_AI)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'xxasdxx' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "rtrim('x', 'xxasdxx')", start = 7, stop = 27) + ) + } } test("Support `StringTrim` string expression with collation") { @@ -1178,6 +1409,7 @@ class CollationStringExpressionsSuite StringTrimTestCase("xxasdxx", Some("y"), "UNICODE", "xxasdxx"), StringTrimTestCase(" asd ", None, "UNICODE_CI", "asd") ) + val unsupportedTestCase = StringTrimTestCase("xxasdxx", Some("x"), "UNICODE_AI", "asd") testCases.foreach(t => { // Unit test. val srcStr = Literal.create(t.srcStr, StringType(t.collation)) @@ -1191,6 +1423,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. 
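// Clarifying sketch, not part of the original patch (assumes the suite's helpers): the RTRIM
// collation specifier exercised elsewhere in this patch only changes comparison semantics
// (trailing spaces are ignored when values are compared), whereas the rtrim() function tested
// here physically removes trailing characters from the value.
checkAnswer(
  sql("select collate('aaa ', 'UTF8_BINARY_RTRIM') = 'aaa', length(rtrim('aaa '))"),
  Row(true, 3))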
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val trimString = s"'${unsupportedTestCase.trimStr.get}', " + val query = s"select trim($trimString'${unsupportedTestCase.srcStr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"TRIM(BOTH 'x' collate UNICODE_AI FROM 'xxasdxx' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'xxasdxx' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "trim('x', 'xxasdxx')", start = 7, stop = 26) + ) + } } test("Support `StringTrimBoth` string expression with collation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index d5d18b1ab081c..b6da0b169f050 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -44,42 +44,82 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { private val allFileBasedDataSources = collationPreservingSources ++ collationNonPreservingSources test("collate returns proper type") { - Seq("utf8_binary", "utf8_lcase", "unicode", "unicode_ci").foreach { collationName => + Seq( + "utf8_binary", + "utf8_lcase", + "unicode", + "unicode_ci", + "unicode_rtrim_ci", + "utf8_lcase_rtrim", + "utf8_binary_rtrim" + ).foreach { collationName => checkAnswer(sql(s"select 'aaa' collate $collationName"), Row("aaa")) val collationId = CollationFactory.collationNameToId(collationName) - assert(sql(s"select 'aaa' collate $collationName").schema(0).dataType - == StringType(collationId)) + assert( + sql(s"select 'aaa' collate $collationName").schema(0).dataType + == StringType(collationId) + ) } } test("collation name is case insensitive") { - Seq("uTf8_BiNaRy", "utf8_lcase", "uNicOde", "UNICODE_ci").foreach { collationName => + Seq( + "uTf8_BiNaRy", + "utf8_lcase", + "uNicOde", + "UNICODE_ci", + "uNiCoDE_rtRIm_cI", + "UtF8_lCaSE_rtRIM", + "utf8_biNAry_RtRiM" + ).foreach { collationName => checkAnswer(sql(s"select 'aaa' collate $collationName"), Row("aaa")) val collationId = CollationFactory.collationNameToId(collationName) - assert(sql(s"select 'aaa' collate $collationName").schema(0).dataType - == StringType(collationId)) + assert( + sql(s"select 'aaa' collate $collationName").schema(0).dataType + == StringType(collationId) + ) } } test("collation expression returns name of collation") { - Seq("utf8_binary", "utf8_lcase", "unicode", "unicode_ci").foreach { collationName => + Seq( + "utf8_binary", + "utf8_lcase", + "unicode", + "unicode_ci", + "unicode_ci_rtrim", + "utf8_lcase_rtrim", + "utf8_binary_rtrim" + ).foreach { collationName => checkAnswer( - sql(s"select collation('aaa' collate $collationName)"), Row(collationName.toUpperCase())) + sql(s"select collation('aaa' collate $collationName)"), + Row(collationName.toUpperCase()) + ) } } test("collate function syntax") { assert(sql(s"select collate('aaa', 'utf8_binary')").schema(0).dataType == StringType("UTF8_BINARY")) + assert(sql(s"select collate('aaa', 'utf8_binary_rtrim')").schema(0).dataType == + StringType("UTF8_BINARY_RTRIM")) assert(sql(s"select collate('aaa', 'utf8_lcase')").schema(0).dataType == StringType("UTF8_LCASE")) + assert(sql(s"select collate('aaa', 
'utf8_lcase_rtrim')").schema(0).dataType == + StringType("UTF8_LCASE_RTRIM")) } test("collate function syntax with default collation set") { withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UTF8_LCASE") { - assert(sql(s"select collate('aaa', 'utf8_lcase')").schema(0).dataType == - StringType("UTF8_LCASE")) + assert( + sql(s"select collate('aaa', 'utf8_lcase')").schema(0).dataType == + StringType("UTF8_LCASE") + ) assert(sql(s"select collate('aaa', 'UNICODE')").schema(0).dataType == StringType("UNICODE")) + assert( + sql(s"select collate('aaa', 'UNICODE_RTRIM')").schema(0).dataType == + StringType("UNICODE_RTRIM") + ) } } @@ -162,9 +202,14 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { withTable(tableName) { sql( s""" - |CREATE TABLE $tableName - |(id INT, c1 STRING COLLATE UNICODE, c2 string) - |USING parquet + |CREATE TABLE $tableName ( + | id INT, + | c1 STRING COLLATE UNICODE, + | c2 STRING, + | struct_col STRUCT, + | array_col ARRAY, + | map_col MAP + |) USING parquet |CLUSTERED BY (${bucketColumns.mkString(",")}) |INTO 4 BUCKETS""".stripMargin ) @@ -175,14 +220,20 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { createTable("c2") createTable("id", "c2") - Seq(Seq("c1"), Seq("c1", "id"), Seq("c1", "c2")).foreach { bucketColumns => + val failBucketingColumns = Seq( + Seq("c1"), Seq("c1", "id"), Seq("c1", "c2"), + Seq("struct_col"), Seq("array_col"), Seq("map_col") + ) + + failBucketingColumns.foreach { bucketColumns => checkError( exception = intercept[AnalysisException] { createTable(bucketColumns: _*) }, condition = "INVALID_BUCKET_COLUMN_DATA_TYPE", - parameters = Map("type" -> "\"STRING COLLATE UNICODE\"") - ); + parameters = Map("type" -> ".*STRING COLLATE UNICODE.*"), + matchPVals = true + ) } } @@ -213,14 +264,23 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq( ("utf8_binary", "aaa", "AAA", false), ("utf8_binary", "aaa", "aaa", true), + ("utf8_binary_rtrim", "aaa", "AAA", false), + ("utf8_binary_rtrim", "aaa", "aaa ", true), ("utf8_lcase", "aaa", "aaa", true), ("utf8_lcase", "aaa", "AAA", true), ("utf8_lcase", "aaa", "bbb", false), + ("utf8_lcase_rtrim", "aaa", "AAA ", true), + ("utf8_lcase_rtrim", "aaa", "bbb", false), ("unicode", "aaa", "aaa", true), ("unicode", "aaa", "AAA", false), + ("unicode_rtrim", "aaa ", "aaa ", true), + ("unicode_rtrim", "aaa", "AAA", false), ("unicode_CI", "aaa", "aaa", true), ("unicode_CI", "aaa", "AAA", true), - ("unicode_CI", "aaa", "bbb", false) + ("unicode_CI", "aaa", "bbb", false), + ("unicode_CI_rtrim", "aaa", "aaa", true), + ("unicode_CI_rtrim", "aaa ", "AAA ", true), + ("unicode_CI_rtrim", "aaa", "bbb", false) ).foreach { case (collationName, left, right, expected) => checkAnswer( @@ -237,15 +297,19 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ("utf8_binary", "AAA", "aaa", true), ("utf8_binary", "aaa", "aaa", false), ("utf8_binary", "aaa", "BBB", false), + ("utf8_binary_rtrim", "aaa ", "aaa ", false), ("utf8_lcase", "aaa", "aaa", false), ("utf8_lcase", "AAA", "aaa", false), ("utf8_lcase", "aaa", "bbb", true), + ("utf8_lcase_rtrim", "AAA ", "aaa", false), ("unicode", "aaa", "aaa", false), ("unicode", "aaa", "AAA", true), ("unicode", "aaa", "BBB", true), + ("unicode_rtrim", "aaa ", "aaa", false), ("unicode_CI", "aaa", "aaa", false), ("unicode_CI", "aaa", "AAA", false), - ("unicode_CI", "aaa", "bbb", true) + ("unicode_CI", "aaa", "bbb", true), + ("unicode_CI_rtrim", "aaa ", "aaa", false) ).foreach { case 
(collationName, left, right, expected) => checkAnswer( @@ -308,18 +372,22 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { test("aggregates count respects collation") { Seq( + ("utf8_binary_rtrim", Seq("aaa", "aaa "), Seq(Row(2, "aaa"))), ("utf8_binary", Seq("AAA", "aaa"), Seq(Row(1, "AAA"), Row(1, "aaa"))), ("utf8_binary", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), ("utf8_binary", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), ("utf8_lcase", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), ("utf8_lcase", Seq("AAA", "aaa"), Seq(Row(2, "AAA"))), ("utf8_lcase", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), + ("utf8_lcase_rtrim", Seq("aaa", "AAA "), Seq(Row(2, "aaa"))), ("unicode", Seq("AAA", "aaa"), Seq(Row(1, "AAA"), Row(1, "aaa"))), ("unicode", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), ("unicode", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), + ("unicode_rtrim", Seq("aaa", "aaa "), Seq(Row(2, "aaa"))), ("unicode_CI", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), ("unicode_CI", Seq("AAA", "aaa"), Seq(Row(2, "AAA"))), - ("unicode_CI", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))) + ("unicode_CI", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), + ("unicode_CI_rtrim", Seq("aaa", "AAA "), Seq(Row(2, "aaa"))) ).foreach { case (collationName: String, input: Seq[String], expected: Seq[Row]) => checkAnswer(sql( @@ -1054,11 +1122,218 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - for (collation <- Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI", "")) { + test("Check order by on table with collated string column") { + val tableName = "t" + Seq( + // (collationName, data, expResult) + ( + "", // non-collated + Seq((5, "bbb"), (3, "a"), (1, "A"), (4, "aaaa"), (6, "cc"), (2, "BbB")), + Seq(1, 2, 3, 4, 5, 6) + ), + ( + "UTF8_BINARY", + Seq((5, "bbb"), (3, "a"), (1, "A"), (4, "aaaa"), (6, "cc"), (2, "BbB")), + Seq(1, 2, 3, 4, 5, 6) + ), + ( + "UTF8_LCASE", + Seq((2, "bbb"), (1, "a"), (1, "A"), (1, "aaaa"), (3, "cc"), (2, "BbB")), + Seq(1, 1, 1, 2, 2, 3) + ), + ( + "UNICODE", + Seq((4, "bbb"), (1, "a"), (2, "A"), (3, "aaaa"), (6, "cc"), (5, "BbB")), + Seq(1, 2, 3, 4, 5, 6) + ), + ( + "UNICODE_CI", + Seq((2, "bbb"), (1, "a"), (1, "A"), (1, "aaaa"), (3, "cc"), (2, "BbB")), + Seq(1, 1, 1, 2, 2, 3) + ) + ).foreach { + case (collationName, data, expResult) => + val collationSetup = if (collationName.isEmpty) "" else "collate " + collationName + withTable(tableName) { + sql(s"create table $tableName (c1 integer, c2 string $collationSetup)") + data.foreach { + case (c1, c2) => + sql(s"insert into $tableName values ($c1, '$c2')") + } + checkAnswer(sql(s"select c1 from $tableName order by c2"), expResult.map(Row(_))) + } + } + } + + test("Check order by on StructType") { + Seq( + // (collationName, data, expResult) + ( + "", // non-collated + Seq((5, "b", "A"), (3, "aa", "A"), (6, "b", "B"), (2, "A", "c"), (1, "A", "D"), + (4, "aa", "B")), + Seq(1, 2, 3, 4, 5, 6) + ), + ( + "UTF8_BINARY", + Seq((5, "b", "A"), (3, "aa", "A"), (6, "b", "B"), (2, "A", "c"), (1, "A", "D"), + (4, "aa", "B")), + Seq(1, 2, 3, 4, 5, 6) + ), + ( + "UTF8_LCASE", + Seq((3, "A", "C"), (2, "A", "b"), (2, "a", "b"), (4, "B", "c"), (1, "a", "a"), + (5, "b", "d")), + Seq(1, 2, 2, 3, 4, 5) + ), + ( + "UNICODE", + Seq((4, "A", "C"), (3, "A", "b"), (2, "a", "b"), (5, "b", "c"), (1, "a", "a"), + (6, "b", "d")), + Seq(1, 2, 3, 4, 5, 6) + ), + ( + "UNICODE_CI", + Seq((3, "A", "C"), (2, "A", "b"), (2, "a", "b"), (4, "B", "c"), (1, "a", "a"), + (5, 
"b", "d")), + Seq(1, 2, 2, 3, 4, 5) + ) + ).foreach { + case (collationName, data, expResult) => + val collationSetup = if (collationName.isEmpty) "" else "collate " + collationName + val tableName = "t" + withTable(tableName) { + sql(s"create table $tableName (c1 integer, c2 struct<" + + s"s1: string $collationSetup," + + s"s2: string $collationSetup>)") + data.foreach { + case (c1, s1, s2) => + sql(s"insert into $tableName values ($c1, struct('$s1', '$s2'))") + } + checkAnswer(sql(s"select c1 from $tableName order by c2"), expResult.map(Row(_))) + } + } + } + + test("Check order by on StructType with few collated fields") { + val data = Seq( + (2, "b", "a", "a", "a", "a"), + (4, "b", "b", "B", "a", "a"), + (1, "a", "a", "a", "a", "a"), + (6, "b", "b", "b", "B", "B"), + (3, "b", "b", "a", "a", "a"), + (5, "b", "b", "b", "B", "a")) + val tableName = "t" + withTable(tableName) { + sql(s"create table $tableName (c1 integer, c2 struct<" + + s"s1: string, " + + s"s2: string collate UTF8_BINARY, " + + s"s3: string collate UTF8_LCASE, " + + s"s4: string collate UNICODE, " + + s"s5: string collate UNICODE_CI>)") + data.foreach { + case (order, s1, s2, s3, s4, s5) => + sql(s"insert into $tableName values ($order, struct('$s1', '$s2', '$s3', '$s4', '$s5'))") + } + val expResult = Seq(1, 2, 3, 4, 5, 6) + checkAnswer(sql(s"select c1 from $tableName order by c2"), expResult.map(Row(_))) + } + } + + test("Check order by on ArrayType with collated strings") { + Seq( + // (collationName, order, data) + ( + "", + Seq((3, Seq("b", "Aa", "c")), (2, Seq("A", "b")), (1, Seq("A")), (2, Seq("A", "b"))), + Seq(1, 2, 2, 3) + ), + ( + "UTF8_BINARY", + Seq((3, Seq("b", "Aa", "c")), (2, Seq("A", "b")), (1, Seq("A")), (2, Seq("A", "b"))), + Seq(1, 2, 2, 3) + ), + ( + "UTF8_LCASE", + Seq((4, Seq("B", "a")), (4, Seq("b", "A")), (2, Seq("aa")), (1, Seq("A")), + (5, Seq("b", "e")), (3, Seq("b"))), + Seq(1, 2, 3, 4, 4, 5) + ), + ( + "UNICODE", + Seq((5, Seq("b", "C")), (4, Seq("b", "AA")), (1, Seq("a")), (4, Seq("b", "AA")), + (3, Seq("b")), (2, Seq("A", "a"))), + Seq(1, 2, 3, 4, 4, 5) + ), + ( + "UNICODE_CI", + Seq((4, Seq("B", "a")), (4, Seq("b", "A")), (2, Seq("aa")), (1, Seq("A")), + (5, Seq("b", "e")), (3, Seq("b"))), + Seq(1, 2, 3, 4, 4, 5) + ) + ).foreach { + case (collationName, dataWithOrder, expResult) => + val collationSetup = if (collationName.isEmpty) "" else "collate " + collationName + val tableName1 = "t1" + val tableName2 = "t2" + withTable(tableName1, tableName2) { + sql(s"create table $tableName1 (c1 integer, c2 array)") + sql(s"create table $tableName2 (c1 integer," + + s" c2 struct>)") + dataWithOrder.foreach { + case (order, data) => + val arrayData = data.map(d => s"'$d'").mkString(", ") + sql(s"insert into $tableName1 values ($order, array($arrayData))") + sql(s"insert into $tableName2 values ($order, struct(array($arrayData)))") + } + checkAnswer(sql(s"select c1 from $tableName1 order by c2"), expResult.map(Row(_))) + checkAnswer(sql(s"select c1 from $tableName2 order by c2"), expResult.map(Row(_))) + } + } + } + + test("Check order by on StructType with different types containing collated strings") { + val data = Seq( + (5, ("b", Seq(("b", "B", "a"), ("a", "a", "a")), "a")), + (2, ("b", Seq(("a", "a", "a")), "a")), + (2, ("b", Seq(("a", "a", "a")), "a")), + (4, ("b", Seq(("b", "a", "a")), "a")), + (3, ("b", Seq(("a", "a", "a"), ("a", "a", "a")), "a")), + (5, ("b", Seq(("b", "B", "a")), "a")), + (4, ("b", Seq(("b", "a", "a")), "a")), + (6, ("b", Seq(("b", "b", "B")), "A")), + (5, ("b", Seq(("b", 
"b", "a")), "a")), + (1, ("a", Seq(("a", "a", "a")), "a")), + (7, ("b", Seq(("b", "b", "B")), "b")), + (6, ("b", Seq(("b", "b", "B")), "a")), + (5, ("b", Seq(("b", "b", "a")), "a")) + ) + val tableName = "t" + withTable(tableName) { + sql(s"create table $tableName " + + s"(c1 integer," + + s"c2 string," + + s"c3 array>," + + s"c4 string collate UNICODE_CI)") + data.foreach { + case (c1, (c2, c3, c4)) => + val c3String = c3.map { case (f1, f2, f3) => s"struct('$f1', '$f2', '$f3')"} + .mkString(", ") + sql(s"insert into $tableName values ($c1, '$c2', array($c3String), '$c4')") + } + val expResult = Seq(1, 2, 2, 3, 4, 4, 5, 5, 5, 5, 6, 6, 7) + checkAnswer(sql(s"select c1 from $tableName order by c2, c3, c4"), expResult.map(Row(_))) + } + } + + for (collation <- Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI", + "UNICODE_CI_RTRIM", "")) { for (codeGen <- Seq("NO_CODEGEN", "CODEGEN_ONLY")) { val collationSetup = if (collation.isEmpty) "" else " COLLATE " + collation val supportsBinaryEquality = collation.isEmpty || collation == "UNICODE" || - CollationFactory.fetchCollation(collation).supportsBinaryEquality + CollationFactory.fetchCollation(collation).isUtf8BinaryType test(s"Group by on map containing$collationSetup strings ($codeGen)") { val tableName = "t" @@ -1247,21 +1522,23 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { val t1 = "T_1" val t2 = "T_2" - case class HashJoinTestCase[R](collation: String, result: R) + case class HashJoinTestCase[R](collation: String, data1: String, data2: String, result: R) val testCases = Seq( - HashJoinTestCase("UTF8_BINARY", Seq(Row("aa", 1, "aa", 2))), - HashJoinTestCase("UTF8_LCASE", Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2))), - HashJoinTestCase("UNICODE", Seq(Row("aa", 1, "aa", 2))), - HashJoinTestCase("UNICODE_CI", Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2))) + HashJoinTestCase("UTF8_BINARY", "aa", "AA", Seq(Row("aa", 1, "aa", 2))), + HashJoinTestCase("UTF8_LCASE", "aa", "AA", Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2))), + HashJoinTestCase("UNICODE", "aa", "AA", Seq(Row("aa", 1, "aa", 2))), + HashJoinTestCase("UNICODE_CI", "aa", "AA", Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2))), + HashJoinTestCase("UNICODE_CI_RTRIM", "aa", "AA ", Seq(Row("aa", 1, "AA ", 2), + Row("aa", 1, "aa", 2))) ) testCases.foreach(t => { withTable(t1, t2) { sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES ('aa', 1)") + sql(s"INSERT INTO $t1 VALUES ('${t.data1}', 1)") sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES ('AA', 2), ('aa', 2)") + sql(s"INSERT INTO $t2 VALUES ('${t.data2}', 2), ('${t.data1}', 2)") val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") checkAnswer(df, t.result) @@ -1281,7 +1558,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. 
- if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(collectFirst(queryPlan) { case b: HashJoin => b.leftKeys.head }.head.isInstanceOf[CollationKey]) @@ -1298,25 +1575,27 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { val t1 = "T_1" val t2 = "T_2" - case class HashJoinTestCase[R](collation: String, result: R) + case class HashJoinTestCase[R](collation: String, data1: String, data2: String, result: R) val testCases = Seq( - HashJoinTestCase("UTF8_BINARY", + HashJoinTestCase("UTF8_BINARY", "aa", "AA", Seq(Row(Seq("aa"), 1, Seq("aa"), 2))), - HashJoinTestCase("UTF8_LCASE", + HashJoinTestCase("UTF8_LCASE", "aa", "AA", Seq(Row(Seq("aa"), 1, Seq("AA"), 2), Row(Seq("aa"), 1, Seq("aa"), 2))), - HashJoinTestCase("UNICODE", + HashJoinTestCase("UNICODE", "aa", "AA", Seq(Row(Seq("aa"), 1, Seq("aa"), 2))), - HashJoinTestCase("UNICODE_CI", - Seq(Row(Seq("aa"), 1, Seq("AA"), 2), Row(Seq("aa"), 1, Seq("aa"), 2))) + HashJoinTestCase("UNICODE_CI", "aa", "AA", + Seq(Row(Seq("aa"), 1, Seq("AA"), 2), Row(Seq("aa"), 1, Seq("aa"), 2))), + HashJoinTestCase("UNICODE_CI_RTRIM", "aa", "AA ", + Seq(Row(Seq("aa"), 1, Seq("AA "), 2), Row(Seq("aa"), 1, Seq("aa"), 2))) ) testCases.foreach(t => { withTable(t1, t2) { sql(s"CREATE TABLE $t1 (x ARRAY, i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES (array('aa'), 1)") + sql(s"INSERT INTO $t1 VALUES (array('${t.data1}'), 1)") sql(s"CREATE TABLE $t2 (y ARRAY, j int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES (array('AA'), 2), (array('aa'), 2)") + sql(s"INSERT INTO $t2 VALUES (array('${t.data2}'), 2), (array('${t.data1}'), 2)") val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") checkAnswer(df, t.result) @@ -1336,7 +1615,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(collectFirst(queryPlan) { case b: BroadcastHashJoinExec => b.leftKeys.head }.head.asInstanceOf[ArrayTransform].function.asInstanceOf[LambdaFunction]. 
@@ -1354,27 +1633,30 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { val t1 = "T_1" val t2 = "T_2" - case class HashJoinTestCase[R](collation: String, result: R) + case class HashJoinTestCase[R](collation: String, data1: String, data2: String, result: R) val testCases = Seq( - HashJoinTestCase("UTF8_BINARY", + HashJoinTestCase("UTF8_BINARY", "aa", "AA", Seq(Row(Seq(Seq("aa")), 1, Seq(Seq("aa")), 2))), - HashJoinTestCase("UTF8_LCASE", + HashJoinTestCase("UTF8_LCASE", "aa", "AA", Seq(Row(Seq(Seq("aa")), 1, Seq(Seq("AA")), 2), Row(Seq(Seq("aa")), 1, Seq(Seq("aa")), 2))), - HashJoinTestCase("UNICODE", + HashJoinTestCase("UNICODE", "aa", "AA", Seq(Row(Seq(Seq("aa")), 1, Seq(Seq("aa")), 2))), - HashJoinTestCase("UNICODE_CI", - Seq(Row(Seq(Seq("aa")), 1, Seq(Seq("AA")), 2), Row(Seq(Seq("aa")), 1, Seq(Seq("aa")), 2))) + HashJoinTestCase("UNICODE_CI", "aa", "AA", + Seq(Row(Seq(Seq("aa")), 1, Seq(Seq("AA")), 2), Row(Seq(Seq("aa")), 1, Seq(Seq("aa")), 2))), + HashJoinTestCase("UNICODE_CI_RTRIM", "aa", "AA ", + Seq(Row(Seq(Seq("aa")), 1, Seq(Seq("AA ")), 2), Row(Seq(Seq("aa")), 1, Seq(Seq("aa")), 2))) ) testCases.foreach(t => { withTable(t1, t2) { sql(s"CREATE TABLE $t1 (x ARRAY>, i int) USING " + s"PARQUET") - sql(s"INSERT INTO $t1 VALUES (array(array('aa')), 1)") + sql(s"INSERT INTO $t1 VALUES (array(array('${t.data1}')), 1)") sql(s"CREATE TABLE $t2 (y ARRAY>, j int) USING " + s"PARQUET") - sql(s"INSERT INTO $t2 VALUES (array(array('AA')), 2), (array(array('aa')), 2)") + sql(s"INSERT INTO $t2 VALUES (array(array('${t.data2}')), 2)," + + s" (array(array('${t.data1}')), 2)") val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") checkAnswer(df, t.result) @@ -1394,7 +1676,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(collectFirst(queryPlan) { case b: BroadcastHashJoinExec => b.leftKeys.head }.head.asInstanceOf[ArrayTransform].function. 
@@ -1413,24 +1695,27 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { val t1 = "T_1" val t2 = "T_2" - case class HashJoinTestCase[R](collation: String, result: R) + case class HashJoinTestCase[R](collation: String, data1 : String, data2: String, result: R) val testCases = Seq( - HashJoinTestCase("UTF8_BINARY", + HashJoinTestCase("UTF8_BINARY", "aa", "AA", Seq(Row(Row("aa"), 1, Row("aa"), 2))), - HashJoinTestCase("UTF8_LCASE", + HashJoinTestCase("UTF8_LCASE", "aa", "AA", Seq(Row(Row("aa"), 1, Row("AA"), 2), Row(Row("aa"), 1, Row("aa"), 2))), - HashJoinTestCase("UNICODE", + HashJoinTestCase("UNICODE", "aa", "AA", Seq(Row(Row("aa"), 1, Row("aa"), 2))), - HashJoinTestCase("UNICODE_CI", - Seq(Row(Row("aa"), 1, Row("AA"), 2), Row(Row("aa"), 1, Row("aa"), 2))) + HashJoinTestCase("UNICODE_CI", "aa", "AA", + Seq(Row(Row("aa"), 1, Row("AA"), 2), Row(Row("aa"), 1, Row("aa"), 2))), + HashJoinTestCase("UNICODE_CI_RTRIM", "aa", "AA ", + Seq(Row(Row("aa"), 1, Row("AA "), 2), Row(Row("aa"), 1, Row("aa"), 2))) ) testCases.foreach(t => { withTable(t1, t2) { sql(s"CREATE TABLE $t1 (x STRUCT, i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES (named_struct('f', 'aa'), 1)") + sql(s"INSERT INTO $t1 VALUES (named_struct('f', '${t.data1}'), 1)") sql(s"CREATE TABLE $t2 (y STRUCT, j int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES (named_struct('f', 'AA'), 2), (named_struct('f', 'aa'), 2)") + sql(s"INSERT INTO $t2 VALUES (named_struct('f', '${t.data2}'), 2)," + + s" (named_struct('f', '${t.data1}'), 2)") val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") checkAnswer(df, t.result) @@ -1450,7 +1735,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. 
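// Illustrative sketch, not part of the original patch: the guard on these plan assertions is
// renamed from supportsBinaryEquality to isUtf8BinaryType, but the intent is unchanged — only
// for collations that are not the UTF8 binary type (e.g. UTF8_LCASE, UNICODE_CI, and the new
// UNICODE_CI_RTRIM) should the planner rewrite the hash-join keys through CollationKey.
assert(CollationFactory.fetchCollation("UTF8_BINARY").isUtf8BinaryType)
assert(!CollationFactory.fetchCollation("UNICODE_CI").isUtf8BinaryType)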
- if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(queryPlan.toString().contains("collationkey")) } else { assert(!queryPlan.toString().contains("collationkey")) @@ -1463,29 +1748,33 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { val t1 = "T_1" val t2 = "T_2" - case class HashJoinTestCase[R](collation: String, result: R) + case class HashJoinTestCase[R](collation: String, data1: String, data2: String, result: R) val testCases = Seq( - HashJoinTestCase("UTF8_BINARY", + HashJoinTestCase("UTF8_BINARY", "aa", "AA", Seq(Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("aa"))), 2))), - HashJoinTestCase("UTF8_LCASE", + HashJoinTestCase("UTF8_LCASE", "aa", "AA", Seq(Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("AA"))), 2), Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("aa"))), 2))), - HashJoinTestCase("UNICODE", + HashJoinTestCase("UNICODE", "aa", "AA", Seq(Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("aa"))), 2))), - HashJoinTestCase("UNICODE_CI", + HashJoinTestCase("UNICODE_CI", "aa", "AA", Seq(Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("AA"))), 2), + Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("aa"))), 2))), + HashJoinTestCase("UNICODE_CI_RTRIM", "aa", "AA ", + Seq(Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("AA "))), 2), Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("aa"))), 2))) ) testCases.foreach(t => { withTable(t1, t2) { sql(s"CREATE TABLE $t1 (x STRUCT>>, " + s"i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES (named_struct('f', array(named_struct('f', 'aa'))), 1)") + sql(s"INSERT INTO $t1 VALUES (named_struct('f', array(named_struct('f', '${t.data1}'))), 1)" + ) sql(s"CREATE TABLE $t2 (y STRUCT>>, " + s"j int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES (named_struct('f', array(named_struct('f', 'AA'))), 2), " + - s"(named_struct('f', array(named_struct('f', 'aa'))), 2)") + sql(s"INSERT INTO $t2 VALUES (named_struct('f', array(named_struct('f', '${t.data2}'))), 2)" + + s", (named_struct('f', array(named_struct('f', '${t.data1}'))), 2)") val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") checkAnswer(df, t.result) @@ -1505,7 +1794,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. 
- if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(queryPlan.toString().contains("collationkey")) } else { assert(!queryPlan.toString().contains("collationkey")) @@ -1566,7 +1855,9 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY", "STRING COLLATE UTF8_LCASE", "'a', 'a', 1", "'a', 'A', 1", Row("a", "a", 1, "a", "A", 1)), HashMultiJoinTestCase("STRING COLLATE UTF8_LCASE", "STRING COLLATE UNICODE_CI", - "'a', 'a', 1", "'A', 'A', 1", Row("a", "a", 1, "A", "A", 1)) + "'a', 'a', 1", "'A', 'A', 1", Row("a", "a", 1, "A", "A", 1)), + HashMultiJoinTestCase("STRING COLLATE UTF8_LCASE", "STRING COLLATE UNICODE_CI_RTRIM", + "'a', 'a', 1", "'A', 'A ', 1", Row("a", "a", 1, "A", "A ", 1)) ) testCases.foreach(t => { @@ -1599,15 +1890,19 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { test("hll sketch aggregate should respect collation") { case class HllSketchAggTestCase[R](c: String, result: R) val testCases = Seq( - HllSketchAggTestCase("UTF8_BINARY", 4), - HllSketchAggTestCase("UTF8_LCASE", 3), - HllSketchAggTestCase("UNICODE", 4), - HllSketchAggTestCase("UNICODE_CI", 3) + HllSketchAggTestCase("UTF8_BINARY", 5), + HllSketchAggTestCase("UTF8_BINARY_RTRIM", 4), + HllSketchAggTestCase("UTF8_LCASE", 4), + HllSketchAggTestCase("UTF8_LCASE_RTRIM", 3), + HllSketchAggTestCase("UNICODE", 5), + HllSketchAggTestCase("UNICODE_RTRIM", 4), + HllSketchAggTestCase("UNICODE_CI", 4), + HllSketchAggTestCase("UNICODE_CI_RTRIM", 3) ) testCases.foreach(t => { withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.c) { val q = "SELECT hll_sketch_estimate(hll_sketch_agg(col)) FROM " + - "VALUES ('a'), ('A'), ('b'), ('b'), ('c') tab(col)" + "VALUES ('a'), ('A'), ('b'), ('b'), ('c'), ('c ') tab(col)" val df = sql(q) checkAnswer(df, Seq(Row(t.result))) } @@ -1661,17 +1956,17 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Row("SYSTEM", "BUILTIN", "UNICODE", "", "", "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "UNICODE_AI", "", "", - "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), - Row("SYSTEM", "BUILTIN", "UNICODE_CI", "", "", "ACCENT_INSENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "UNICODE_CI", "", "", + "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "UNICODE_CI_AI", "", "", "ACCENT_INSENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "af", "Afrikaans", "", "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "af_AI", "Afrikaans", "", - "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), - Row("SYSTEM", "BUILTIN", "af_CI", "Afrikaans", "", "ACCENT_INSENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "af_CI", "Afrikaans", "", + "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "af_CI_AI", "Afrikaans", "", "ACCENT_INSENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"))) @@ -1683,9 +1978,9 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq(Row("SYSTEM", "BUILTIN", "zh_Hant_HKG", "Chinese", "Hong Kong SAR China", "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "zh_Hant_HKG_AI", "Chinese", "Hong Kong SAR China", - "ACCENT_SENSITIVE", 
"CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), - Row("SYSTEM", "BUILTIN", "zh_Hant_HKG_CI", "Chinese", "Hong Kong SAR China", "ACCENT_INSENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "zh_Hant_HKG_CI", "Chinese", "Hong Kong SAR China", + "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "zh_Hant_HKG_CI_AI", "Chinese", "Hong Kong SAR China", "ACCENT_INSENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"))) @@ -1693,9 +1988,9 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq(Row("SYSTEM", "BUILTIN", "zh_Hans_SGP", "Chinese", "Singapore", "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "zh_Hans_SGP_AI", "Chinese", "Singapore", - "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), - Row("SYSTEM", "BUILTIN", "zh_Hans_SGP_CI", "Chinese", "Singapore", "ACCENT_INSENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "zh_Hans_SGP_CI", "Chinese", "Singapore", + "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "zh_Hans_SGP_CI_AI", "Chinese", "Singapore", "ACCENT_INSENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"))) @@ -1704,17 +1999,17 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq(Row("SYSTEM", "BUILTIN", "en_USA", "English", "United States", "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "en_USA_AI", "English", "United States", - "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), - Row("SYSTEM", "BUILTIN", "en_USA_CI", "English", "United States", "ACCENT_INSENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "en_USA_CI", "English", "United States", + "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "en_USA_CI_AI", "English", "United States", "ACCENT_INSENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"))) checkAnswer(sql("SELECT NAME, LANGUAGE, ACCENT_SENSITIVITY, CASE_SENSITIVITY " + "FROM collations() WHERE COUNTRY = 'United States'"), Seq(Row("en_USA", "English", "ACCENT_SENSITIVE", "CASE_SENSITIVE"), - Row("en_USA_AI", "English", "ACCENT_SENSITIVE", "CASE_INSENSITIVE"), - Row("en_USA_CI", "English", "ACCENT_INSENSITIVE", "CASE_SENSITIVE"), + Row("en_USA_AI", "English", "ACCENT_INSENSITIVE", "CASE_SENSITIVE"), + Row("en_USA_CI", "English", "ACCENT_SENSITIVE", "CASE_INSENSITIVE"), Row("en_USA_CI_AI", "English", "ACCENT_INSENSITIVE", "CASE_INSENSITIVE"))) checkAnswer(sql("SELECT NAME FROM collations() WHERE ICU_VERSION is null"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index 6589282fd3a51..e6907b8656482 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -24,7 +24,8 @@ import java.util.Locale import scala.jdk.CollectionConverters._ -import org.apache.spark.{SparkException, SparkUnsupportedOperationException, SparkUpgradeException} +import org.apache.spark.{SparkException, SparkRuntimeException, + SparkUnsupportedOperationException, SparkUpgradeException} import org.apache.spark.sql.errors.DataTypeErrors.toSQLType import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -234,7 +235,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { val schema = new 
StructType().add("str", StringType) val options = Map("maxCharsPerColumn" -> "2") - val exception = intercept[SparkException] { + val exception = intercept[SparkRuntimeException] { df.select(from_csv($"value", schema, options)).collect() }.getCause.getMessage diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index f16171940df21..47691e1ccd40f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql -import java.lang.reflect.Modifier import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import scala.reflect.runtime.universe.runtimeMirror import scala.util.Random import org.apache.spark.{QueryContextType, SPARK_DOC_ROOT, SparkException, SparkRuntimeException} @@ -82,7 +82,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "bucket", "days", "hours", "months", "years", // Datasource v2 partition transformations "product", // Discussed in https://github.com/apache/spark/pull/30745 "unwrap_udt", - "collect_top_k", "timestamp_add", "timestamp_diff" ) @@ -92,10 +91,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { val word_pattern = """\w*""" // Set of DataFrame functions in org.apache.spark.sql.functions - val dataFrameFunctions = functions.getClass - .getDeclaredMethods - .filter(m => Modifier.isPublic(m.getModifiers)) - .map(_.getName) + val dataFrameFunctions = runtimeMirror(getClass.getClassLoader) + .reflect(functions) + .symbol + .typeSignature + .decls + .filter(s => s.isMethod && s.isPublic) + .map(_.name.toString) .toSet .filter(_.matches(word_pattern)) .diff(excludedDataFrameFunctions) @@ -313,6 +315,44 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(df.select(isnotnull(col("a"))), Seq(Row(false))) } + test("nullif function") { + Seq(true, false).foreach { alwaysInlineCommonExpr => + withSQLConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR.key -> alwaysInlineCommonExpr.toString) { + Seq( + "SELECT NULLIF(1, 1)" -> Seq(Row(null)), + "SELECT NULLIF(1, 2)" -> Seq(Row(1)), + "SELECT NULLIF(NULL, 1)" -> Seq(Row(null)), + "SELECT NULLIF(1, NULL)" -> Seq(Row(1)), + "SELECT NULLIF(NULL, NULL)" -> Seq(Row(null)), + "SELECT NULLIF('abc', 'abc')" -> Seq(Row(null)), + "SELECT NULLIF('abc', 'xyz')" -> Seq(Row("abc")), + "SELECT NULLIF(id, 1) " + + "FROM range(10) " + + "GROUP BY NULLIF(id, 1)" -> Seq(Row(null), Row(2), Row(3), Row(4), Row(5), Row(6), + Row(7), Row(8), Row(9), Row(0)), + "SELECT NULLIF(id, 1), COUNT(*)" + + "FROM range(10) " + + "GROUP BY NULLIF(id, 1) " + + "HAVING COUNT(*) > 1" -> Seq.empty[Row] + ).foreach { + case (sqlText, expected) => checkAnswer(sql(sqlText), expected) + } + + checkError( + exception = intercept[AnalysisException] { + sql("SELECT NULLIF(id, 1), COUNT(*) " + + "FROM range(10) " + + "GROUP BY NULLIF(id, 2)") + }, + condition = "MISSING_AGGREGATION", + parameters = Map( + "expression" -> "\"id\"", + "expressionAnyValue" -> "\"any_value(id)\"") + ) + } + } + } + test("equal_null function") { val df = Seq[(Integer, Integer)]((null, 8)).toDF("a", "b") checkAnswer(df.selectExpr("equal_null(a, b)"), Seq(Row(false))) @@ -322,15 +362,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(df.select(equal_null(col("a"), col("a"))), Seq(Row(true))) } - 
test("nullif function") { - val df = Seq((5, 8)).toDF("a", "b") - checkAnswer(df.selectExpr("nullif(5, 8)"), Seq(Row(5))) - checkAnswer(df.select(nullif(lit(5), lit(8))), Seq(Row(5))) - - checkAnswer(df.selectExpr("nullif(a, a)"), Seq(Row(null))) - checkAnswer(df.select(nullif(lit(5), lit(5))), Seq(Row(null))) - } - test("nullifzero function") { withTable("t") { // Here we exercise a non-nullable, non-foldable column. @@ -409,6 +440,110 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(df.select(nvl2(col("b"), col("a"), col("c"))), Seq(Row(null))) } + test("randstr function") { + withTable("t") { + sql("create table t(col int not null) using csv") + sql("insert into t values (0)") + val df = sql("select col from t") + checkAnswer( + df.select(randstr(lit(5), lit(0)).alias("x")).select(length(col("x"))), + Seq(Row(5))) + // The random seed is optional. + checkAnswer( + df.select(randstr(lit(5)).alias("x")).select(length(col("x"))), + Seq(Row(5))) + } + // Here we exercise some error cases. + val df = Seq((0)).toDF("a") + var expr = randstr(lit(10), lit("a")) + checkError( + intercept[AnalysisException](df.select(expr)), + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"randstr(10, a)\"", + "paramIndex" -> "second", + "inputSql" -> "\"a\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "INT or SMALLINT"), + context = ExpectedContext( + contextType = QueryContextType.DataFrame, + fragment = "randstr", + objectType = "", + objectName = "", + callSitePattern = "", + startIndex = 0, + stopIndex = 0)) + expr = randstr(col("a"), lit(10)) + checkError( + intercept[AnalysisException](df.select(expr)), + condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + parameters = Map( + "inputName" -> "length", + "inputType" -> "INT or SMALLINT", + "inputExpr" -> "\"a\"", + "sqlExpr" -> "\"randstr(a, 10)\""), + context = ExpectedContext( + contextType = QueryContextType.DataFrame, + fragment = "randstr", + objectType = "", + objectName = "", + callSitePattern = "", + startIndex = 0, + stopIndex = 0)) + } + + test("uniform function") { + withTable("t") { + sql("create table t(col int not null) using csv") + sql("insert into t values (0)") + val df = sql("select col from t") + checkAnswer( + df.select(uniform(lit(10), lit(20), lit(0)).alias("x")).selectExpr("x > 5"), + Seq(Row(true))) + // The random seed is optional. + checkAnswer( + df.select(uniform(lit(10), lit(20)).alias("x")).selectExpr("x > 5"), + Seq(Row(true))) + } + // Here we exercise some error cases. 
+ val df = Seq((0)).toDF("a") + var expr = uniform(lit(10), lit("a")) + checkError( + intercept[AnalysisException](df.select(expr)), + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"uniform(10, a)\"", + "paramIndex" -> "second", + "inputSql" -> "\"a\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "integer or floating-point"), + context = ExpectedContext( + contextType = QueryContextType.DataFrame, + fragment = "uniform", + objectType = "", + objectName = "", + callSitePattern = "", + startIndex = 0, + stopIndex = 0)) + expr = uniform(col("a"), lit(10)) + checkError( + intercept[AnalysisException](df.select(expr)), + condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + parameters = Map( + "inputName" -> "min", + "inputType" -> "integer or floating-point", + "inputExpr" -> "\"a\"", + "sqlExpr" -> "\"uniform(a, 10)\""), + context = ExpectedContext( + contextType = QueryContextType.DataFrame, + fragment = "uniform", + objectType = "", + objectName = "", + callSitePattern = "", + startIndex = 0, + stopIndex = 0)) + } + test("zeroifnull function") { withTable("t") { // Here we exercise a non-nullable, non-foldable column. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala index 1d7698df2f1be..f0ed2241fd286 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import org.apache.spark.api.python.PythonEvalType -import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan} import org.apache.spark.sql.expressions.Window @@ -405,7 +404,7 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y"))) // Test for AttachDistributedSequence - val df13 = df1.select(distributed_sequence_id().alias("seq"), col("*")) + val df13 = df1.select(Column.internalFn("distributed_sequence_id").alias("seq"), col("*")) val df14 = df13.filter($"value" === "A2") assertAmbiguousSelfJoin(df13.join(df14, df13("key1") === df14("key2"))) assertAmbiguousSelfJoin(df14.join(df13, df13("key1") === df14("key2"))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala index 1ac1dda374fa7..6c1ca94a03079 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -547,4 +547,55 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession } } } + + test("SPARK-49836 using window fn with window as parameter should preserve parent operator") { + withTempView("clicks") { + val df = Seq( + // small window: [00:00, 01:00), user1, 2 + ("2024-09-30 00:00:00", "user1"), ("2024-09-30 00:00:30", "user1"), + // small window: [01:00, 02:00), user2, 2 + ("2024-09-30 00:01:00", "user2"), ("2024-09-30 00:01:30", "user2"), + // small window: [03:00, 04:00), user1, 1 + ("2024-09-30 00:03:30", "user1"), + // small window: [11:00, 
12:00), user1, 3 + ("2024-09-30 00:11:00", "user1"), ("2024-09-30 00:11:30", "user1"), + ("2024-09-30 00:11:45", "user1") + ).toDF("eventTime", "userId") + + // session window: (01:00, 09:00), user1, 3 / (02:00, 07:00), user2, 2 / + // (12:00, 12:05), user1, 3 + + df.createOrReplaceTempView("clicks") + + val aggregatedData = spark.sql( + """ + |SELECT + | userId, + | avg(cpu_large.numClicks) AS clicksPerSession + |FROM + |( + | SELECT + | session_window(small_window, '5 minutes') AS session, + | userId, + | sum(numClicks) AS numClicks + | FROM + | ( + | SELECT + | window(eventTime, '1 minute') AS small_window, + | userId, + | count(*) AS numClicks + | FROM clicks + | GROUP BY window, userId + | ) cpu_small + | GROUP BY session_window, userId + |) cpu_large + |GROUP BY userId + |""".stripMargin) + + checkAnswer( + aggregatedData, + Seq(Row("user1", 3), Row("user2", 2)) + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameShowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameShowSuite.scala index d728cc5810a21..86d3ca45fd08e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameShowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameShowSuite.scala @@ -112,13 +112,12 @@ class DataFrameShowSuite extends QueryTest with SharedSparkSession { ||key|value| |+---+-----+ |+---+-----+ - |only showing top 0 rows - |""".stripMargin + |only showing top 0 rows""".stripMargin assert(testData.select($"*").showString(-1) === expectedAnswer) } test("showString(negative), vertical = true") { - val expectedAnswer = "(0 rows)\n" + val expectedAnswer = "(0 rows)" assert(testData.select($"*").showString(-1, vertical = true) === expectedAnswer) } @@ -127,8 +126,7 @@ class DataFrameShowSuite extends QueryTest with SharedSparkSession { ||key|value| |+---+-----+ |+---+-----+ - |only showing top 0 rows - |""".stripMargin + |only showing top 0 rows""".stripMargin assert(testData.select($"*").showString(0) === expectedAnswer) } @@ -145,7 +143,7 @@ class DataFrameShowSuite extends QueryTest with SharedSparkSession { } test("showString(0), vertical = true") { - val expectedAnswer = "(0 rows)\n" + val expectedAnswer = "(0 rows)" assert(testData.select($"*").showString(0, vertical = true) === expectedAnswer) } @@ -286,8 +284,7 @@ class DataFrameShowSuite extends QueryTest with SharedSparkSession { |+---+-----+ || 1| 1| |+---+-----+ - |only showing top 1 row - |""".stripMargin + |only showing top 1 row""".stripMargin assert(testData.select($"*").showString(1) === expectedAnswer) } @@ -295,7 +292,7 @@ class DataFrameShowSuite extends QueryTest with SharedSparkSession { val expectedAnswer = "-RECORD 0----\n" + " key | 1 \n" + " value | 1 \n" + - "only showing top 1 row\n" + "only showing top 1 row" assert(testData.select($"*").showString(1, vertical = true) === expectedAnswer) } @@ -337,7 +334,7 @@ class DataFrameShowSuite extends QueryTest with SharedSparkSession { } test("SPARK-7327 show with empty dataFrame, vertical = true") { - assert(testData.select($"*").filter($"key" < 0).showString(1, vertical = true) === "(0 rows)\n") + assert(testData.select($"*").filter($"key" < 0).showString(1, vertical = true) === "(0 rows)") } test("SPARK-18350 show with session local timezone") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e1774cab4a0de..2c0d9e29bb273 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -29,7 +29,6 @@ import org.scalatest.matchers.should.Matchers._ import org.apache.spark.SparkException import org.apache.spark.api.python.PythonEvalType -import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -2318,7 +2317,7 @@ class DataFrameSuite extends QueryTest test("SPARK-36338: DataFrame.withSequenceColumn should append unique sequence IDs") { val ids = spark.range(10).repartition(5).select( - distributed_sequence_id().alias("default_index"), col("id")) + Column.internalFn("distributed_sequence_id").alias("default_index"), col("id")) assert(ids.collect().map(_.getLong(0)).toSet === Range(0, 10).toSet) assert(ids.take(5).map(_.getLong(0)).toSet === Range(0, 5).toSet) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 6ee173bc6af67..c52d428cd5dd4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import java.sql.Timestamp import java.time.LocalDateTime import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -714,4 +715,56 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession { ) } } + + test("SPARK-49836 using window fn with window as parameter should preserve parent operator") { + withTempView("clicks") { + val df = Seq( + // small window: [00:00, 01:00), user1, 2 + ("2024-09-30 00:00:00", "user1"), ("2024-09-30 00:00:30", "user1"), + // small window: [01:00, 02:00), user2, 2 + ("2024-09-30 00:01:00", "user2"), ("2024-09-30 00:01:30", "user2"), + // small window: [07:00, 08:00), user1, 1 + ("2024-09-30 00:07:00", "user1"), + // small window: [11:00, 12:00), user1, 3 + ("2024-09-30 00:11:00", "user1"), ("2024-09-30 00:11:30", "user1"), + ("2024-09-30 00:11:45", "user1") + ).toDF("eventTime", "userId") + + // large window: [00:00, 10:00), user1, 3, [00:00, 10:00), user2, 2, [10:00, 20:00), user1, 3 + + df.createOrReplaceTempView("clicks") + + val aggregatedData = spark.sql( + """ + |SELECT + | cpu_large.large_window.end AS timestamp, + | avg(cpu_large.numClicks) AS avgClicksPerUser + |FROM + |( + | SELECT + | window(small_window, '10 minutes') AS large_window, + | userId, + | sum(numClicks) AS numClicks + | FROM + | ( + | SELECT + | window(eventTime, '1 minute') AS small_window, + | userId, + | count(*) AS numClicks + | FROM clicks + | GROUP BY window, userId + | ) cpu_small + | GROUP BY window, userId + |) cpu_large + |GROUP BY timestamp + |""".stripMargin) + + checkAnswer( + aggregatedData, + Seq( + Row(Timestamp.valueOf("2024-09-30 00:10:00"), 2.5), + Row(Timestamp.valueOf("2024-09-30 00:20:00"), 3)) + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 089ce79201dd8..45c34d9c73367 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1273,7 +1273,7 @@ class DatasetSuite extends QueryTest // Just check the error class here to avoid flakiness due to 
different parameters. assert(intercept[SparkRuntimeException] { buildDataset(Row(Row("hello", null))).collect() - }.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + }.getCondition == "NOT_NULL_ASSERT_VIOLATION") } test("SPARK-12478: top level null field") { @@ -1416,7 +1416,7 @@ class DatasetSuite extends QueryTest val ex = intercept[SparkRuntimeException] { spark.createDataFrame(rdd, schema).collect() } - assert(ex.getErrorClass == "EXPRESSION_ENCODING_FAILED") + assert(ex.getCondition == "EXPRESSION_ENCODING_FAILED") assert(ex.getCause.getMessage.contains("The 1th field 'b' of input row cannot be null")) } @@ -1612,7 +1612,7 @@ class DatasetSuite extends QueryTest test("Dataset should throw RuntimeException if top-level product input object is null") { val e = intercept[SparkRuntimeException](Seq(ClassData("a", 1), null).toDS()) - assert(e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + assert(e.getCondition == "NOT_NULL_ASSERT_VIOLATION") } test("dropDuplicates") { @@ -1849,6 +1849,26 @@ class DatasetSuite extends QueryTest } } + test("Dataset().localCheckpoint() lazy with StorageLevel") { + val df = spark.range(10).repartition($"id" % 2) + val checkpointedDf = df.localCheckpoint(eager = false, StorageLevel.DISK_ONLY) + val checkpointedPlan = checkpointedDf.queryExecution.analyzed + val rdd = checkpointedPlan.asInstanceOf[LogicalRDD].rdd + assert(rdd.getStorageLevel == StorageLevel.DISK_ONLY) + assert(!rdd.isCheckpointed) + checkpointedDf.collect() + assert(rdd.isCheckpointed) + } + + test("Dataset().localCheckpoint() eager with StorageLevel") { + val df = spark.range(10).repartition($"id" % 2) + val checkpointedDf = df.localCheckpoint(eager = true, StorageLevel.DISK_ONLY) + val checkpointedPlan = checkpointedDf.queryExecution.analyzed + val rdd = checkpointedPlan.asInstanceOf[LogicalRDD].rdd + assert(rdd.isCheckpointed) + assert(rdd.getStorageLevel == StorageLevel.DISK_ONLY) + } + test("identity map for primitive arrays") { val arrayByte = Array(1.toByte, 2.toByte, 3.toByte) val arrayInt = Array(1, 2, 3) @@ -2101,7 +2121,7 @@ class DatasetSuite extends QueryTest test("SPARK-23835: null primitive data type should throw NullPointerException") { val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS() val exception = intercept[SparkRuntimeException](ds.as[(Int, Int)].collect()) - assert(exception.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + assert(exception.getCondition == "NOT_NULL_ASSERT_VIOLATION") } test("SPARK-24569: Option of primitive types are mistakenly mapped to struct type") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index e44bd5de4f4c4..9c529d1422119 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -128,13 +128,20 @@ class FileBasedDataSourceSuite extends QueryTest allFileBasedDataSources.foreach { format => test(s"SPARK-23372 error while writing empty schema files using $format") { + val formatMapping = Map( + "csv" -> "CSV", + "json" -> "JSON", + "parquet" -> "Parquet", + "orc" -> "ORC", + "text" -> "Text" + ) withTempPath { outputPath => checkError( exception = intercept[AnalysisException] { spark.emptyDataFrame.write.format(format).save(outputPath.toString) }, - condition = "_LEGACY_ERROR_TEMP_1142", - parameters = Map.empty + condition = "EMPTY_SCHEMA_NOT_SUPPORTED_FOR_DATASOURCE", + parameters = Map("format" -> 
formatMapping(format)) ) } @@ -150,8 +157,8 @@ class FileBasedDataSourceSuite extends QueryTest exception = intercept[AnalysisException] { df.write.format(format).save(outputPath.toString) }, - condition = "_LEGACY_ERROR_TEMP_1142", - parameters = Map.empty + condition = "EMPTY_SCHEMA_NOT_SUPPORTED_FOR_DATASOURCE", + parameters = Map("format" -> formatMapping(format)) ) } } @@ -506,14 +513,23 @@ class FileBasedDataSourceSuite extends QueryTest withSQLConf( SQLConf.USE_V1_SOURCE_LIST.key -> useV1List, SQLConf.LEGACY_INTERVAL_ENABLED.key -> "true") { + val formatMapping = Map( + "csv" -> "CSV", + "json" -> "JSON", + "parquet" -> "Parquet", + "orc" -> "ORC" + ) // write path Seq("csv", "json", "parquet", "orc").foreach { format => checkError( exception = intercept[AnalysisException] { sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir) }, - condition = "_LEGACY_ERROR_TEMP_1136", - parameters = Map.empty + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + parameters = Map( + "format" -> formatMapping(format), + "columnName" -> "`INTERVAL '1 days'`", + "columnType" -> "\"INTERVAL\"") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index 9afba65183974..3f921618297d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.scalactic.source.Position import org.scalatest.Tag +import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ExpressionSet} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.Aggregate @@ -204,7 +205,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { withLCAOn { checkAnswer(sql(query), expectedAnswerLCAOn) } withLCAOff { assert(intercept[AnalysisException]{ sql(query) } - .getErrorClass == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + .getCondition == "UNRESOLVED_COLUMN.WITH_SUGGESTION") } } @@ -215,8 +216,8 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { errorParams: Map[String, String]): Unit = { val e1 = intercept[AnalysisException] { sql(q1) } val e2 = intercept[AnalysisException] { sql(q2) } - assert(e1.getErrorClass == condition) - assert(e2.getErrorClass == condition) + assert(e1.getCondition == condition) + assert(e2.getCondition == condition) errorParams.foreach { case (k, v) => assert(e1.messageParameters.get(k).exists(_ == v)) assert(e2.messageParameters.get(k).exists(_ == v)) @@ -554,7 +555,15 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { | FROM (SELECT dept * 2.0 AS id, id + 1 AS id2 FROM $testTable)) > 5 |ORDER BY id |""".stripMargin - withLCAOff { intercept[AnalysisException] { sql(query4) } } + withLCAOff { + val exception = intercept[SparkRuntimeException] { + sql(query4).collect() + } + checkError( + exception, + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS" + ) + } withLCAOn { val analyzedPlan = sql(query4).queryExecution.analyzed assert(!analyzedPlan.containsPattern(OUTER_REFERENCE)) @@ -1178,7 +1187,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { "sum_avg * 1.0 as sum_avg1, sum_avg1 + dept " + s"from $testTable group by dept, properties.joinYear $havingSuffix" ).foreach { query => - 
assert(intercept[AnalysisException](sql(query)).getErrorClass == + assert(intercept[AnalysisException](sql(query)).getCondition == "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_WITH_WINDOW_AND_HAVING") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala index 352197f96acb6..009fe55664a2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.internal.config -import org.apache.spark.sql.internal.RuntimeConfigImpl +import org.apache.spark.sql.internal.{RuntimeConfigImpl, SQLConf} import org.apache.spark.sql.internal.SQLConf.CHECKPOINT_LOCATION import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE @@ -81,4 +81,24 @@ class RuntimeConfigSuite extends SparkFunSuite { } assert(ex.getMessage.contains("Spark config")) } + + test("set and get a config with defaultValue") { + val conf = newConf() + val key = SQLConf.SESSION_LOCAL_TIMEZONE.key + // By default, the value when getting an unset config entry is its defaultValue. + assert(conf.get(key) == SQLConf.SESSION_LOCAL_TIMEZONE.defaultValue.get) + assert(conf.getOption(key).contains(SQLConf.SESSION_LOCAL_TIMEZONE.defaultValue.get)) + // Get the unset config entry with a different default value, which should return the given + // default parameter. + assert(conf.get(key, "Europe/Amsterdam") == "Europe/Amsterdam") + + // Set a config entry. + conf.set(key, "Europe/Berlin") + // Get the set config entry. + assert(conf.get(key) == "Europe/Berlin") + // Unset the config entry. + conf.unset(key) + // Get the unset config entry, which should return its defaultValue again. 
+ assert(conf.get(key) == SQLConf.SESSION_LOCAL_TIMEZONE.defaultValue.get) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala index 754c46cc5cd3e..b48ff7121c767 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala @@ -64,7 +64,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS sql("INSERT INTO t VALUES ('txt', null)") } } - assert(e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + assert(e.getCondition == "NOT_NULL_ASSERT_VIOLATION") } } @@ -404,7 +404,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS private def assertNotNullException(e: SparkRuntimeException, colPath: Seq[String]): Unit = { e.getCause match { - case _ if e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION" => + case _ if e.getCondition == "NOT_NULL_ASSERT_VIOLATION" => case other => fail(s"Unexpected exception cause: $other") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ce88f7dc475d6..e3346684285a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -111,10 +111,13 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } test("SPARK-34678: describe functions for table-valued functions") { + sql("describe function range").show(false) checkKeywordsExist(sql("describe function range"), "Function: range", "Class: org.apache.spark.sql.catalyst.plans.logical.Range", - "range(end: long)" + "range(start[, end[, step[, numSlices]]])", + "range(end)", + "Returns a table of values within a specified range." ) } @@ -4925,6 +4928,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark ) ) } + + test("SPARK-49743: OptimizeCsvJsonExpr does not change schema when pruning struct") { + val df = sql(""" + | SELECT + | from_json('[{"a": '||id||', "b": '|| (2*id) ||'}]', 'array>').a, + | from_json('[{"a": '||id||', "b": '|| (2*id) ||'}]', 'array>').A + | FROM + | range(3) as t + |""".stripMargin) + val expectedAnswer = Seq( + Row(Array(0), Array(0)), Row(Array(1), Array(1)), Row(Array(2), Array(2))) + checkAnswer(df, expectedAnswer) + } } case class Foo(bar: Option[String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala index 38e004e0b7209..4bd20bc245613 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala @@ -148,7 +148,7 @@ trait SQLQueryTestHelper extends Logging { try { result } catch { - case e: SparkThrowable with Throwable if e.getErrorClass != null => + case e: SparkThrowable with Throwable if e.getCondition != null => (emptySchema, Seq(e.getClass.getName, getMessage(e, format))) case a: AnalysisException => // Do not output the logical plan tree which contains expression IDs. @@ -160,7 +160,7 @@ trait SQLQueryTestHelper extends Logging { // information of stage, task ID, etc. // To make result matching simpler, here we match the cause of the exception if it exists. 
s.getCause match { - case e: SparkThrowable with Throwable if e.getErrorClass != null => + case e: SparkThrowable with Throwable if e.getCondition != null => (emptySchema, Seq(e.getClass.getName, getMessage(e, format))) case cause => (emptySchema, Seq(cause.getClass.getName, cause.getMessage)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 16118526f2fe4..76919d6583106 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -163,9 +163,9 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSparkSession exception = intercept[SparkUnsupportedOperationException] { Seq(InvalidInJava(1)).toDS() }, - condition = "_LEGACY_ERROR_TEMP_2140", + condition = "INVALID_JAVA_IDENTIFIER_AS_FIELD_NAME", parameters = Map( - "fieldName" -> "abstract", + "fieldName" -> "`abstract`", "walkedTypePath" -> "- root class: \"org.apache.spark.sql.InvalidInJava\"")) } @@ -174,9 +174,9 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSparkSession exception = intercept[SparkUnsupportedOperationException] { Seq(InvalidInJava2(1)).toDS() }, - condition = "_LEGACY_ERROR_TEMP_2140", + condition = "INVALID_JAVA_IDENTIFIER_AS_FIELD_NAME", parameters = Map( - "fieldName" -> "0", + "fieldName" -> "`0`", "walkedTypePath" -> "- root class: \"org.apache.spark.sql.ScalaReflectionRelationSuite.InvalidInJava2\"")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala new file mode 100644 index 0000000000000..c4fd16ca5ce59 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSparkSession + +/** + * Make sure the api.SparkSessionBuilder binds to Classic implementation. 
+ */ +class SparkSessionBuilderImplementationBindingSuite + extends SharedSparkSession + with api.SparkSessionBuilderImplementationBindingSuite diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala index 7fa29dd38fd96..74329ac0e0d23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala @@ -25,7 +25,7 @@ import java.time.LocalDateTime import scala.collection.mutable import scala.util.Random -import org.apache.spark.sql.catalyst.{FullQualifiedTableName, TableIdentifier} +import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, CatalogTable, HiveTableRelation} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.AttributeMap @@ -270,7 +270,7 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils def getTableFromCatalogCache(tableName: String): LogicalPlan = { val catalog = spark.sessionState.catalog - val qualifiedTableName = FullQualifiedTableName( + val qualifiedTableName = QualifiedTableName( CatalogManager.SESSION_CATALOG_NAME, catalog.getCurrentDatabase, tableName) catalog.getCachedTable(qualifiedTableName) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 23c4d51983bb4..f8f7fd246832f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import scala.collection.mutable.ArrayBuffer +import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LogicalPlan, Project, Sort, Union} @@ -527,43 +528,30 @@ class SubquerySuite extends QueryTest test("SPARK-18504 extra GROUP BY column in correlated scalar subquery is not permitted") { withTempView("v") { Seq((1, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("v") - - val exception = intercept[AnalysisException] { - sql("select (select sum(-1) from v t2 where t1.c2 = t2.c1 group by t2.c2) sum from v t1") + val exception = intercept[SparkRuntimeException] { + sql("select (select sum(-1) from v t2 where t1.c2 = t2.c1 group by t2.c2) sum from v t1"). + collect() } checkError( exception, - condition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + - "NON_CORRELATED_COLUMNS_IN_GROUP_BY", - parameters = Map("value" -> "c2"), - sqlState = None, - context = ExpectedContext( - fragment = "(select sum(-1) from v t2 where t1.c2 = t2.c1 group by t2.c2)", - start = 7, stop = 67)) } + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS" + ) + } } test("non-aggregated correlated scalar subquery") { - val exception1 = intercept[AnalysisException] { - sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1") + val exception1 = intercept[SparkRuntimeException] { + sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1").collect() } checkError( exception1, - condition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." 
+ - "MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY", - parameters = Map.empty, - context = ExpectedContext( - fragment = "(select b from l l2 where l2.a = l1.a)", start = 10, stop = 47)) - val exception2 = intercept[AnalysisException] { - sql("select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b from l l1") - } - checkErrorMatchPVals( - exception2, - condition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + - "MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY", - parameters = Map.empty[String, String], - sqlState = None, - context = ExpectedContext( - fragment = "(select b from l l2 where l2.a = l1.a group by 1)", start = 10, stop = 58)) + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS" + ) + checkAnswer( + sql("select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b from l l1"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: + Row(null, null) :: Row(null, null) :: Row(6, null) :: Nil + ) } test("non-equal correlated scalar subquery") { @@ -937,12 +925,12 @@ class SubquerySuite extends QueryTest withSQLConf(SQLConf.DECORRELATE_INNER_QUERY_ENABLED.key -> "false") { val error = intercept[AnalysisException] { sql(query) } - assert(error.getErrorClass == "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + assert(error.getCondition == "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + "ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED") } withSQLConf(SQLConf.DECORRELATE_SET_OPS_ENABLED.key -> "false") { val error = intercept[AnalysisException] { sql(query) } - assert(error.getErrorClass == "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + assert(error.getCondition == "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + "ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED") } @@ -1016,12 +1004,12 @@ class SubquerySuite extends QueryTest withSQLConf(SQLConf.DECORRELATE_INNER_QUERY_ENABLED.key -> "false") { val error = intercept[AnalysisException] { sql(query) } - assert(error.getErrorClass == "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + assert(error.getCondition == "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + "ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED") } withSQLConf(SQLConf.DECORRELATE_SET_OPS_ENABLED.key -> "false") { val error = intercept[AnalysisException] { sql(query) } - assert(error.getErrorClass == "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + assert(error.getCondition == "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." 
+ "ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED") } } @@ -2154,6 +2142,24 @@ class SubquerySuite extends QueryTest } } + test("SPARK-49819: Do not collapse projects with exist subqueries") { + withTempView("v") { + Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("v") + checkAnswer( + sql(""" + |SELECT m, CASE WHEN EXISTS (SELECT SUM(c2) FROM v WHERE c1 = m) THEN 1 ELSE 0 END + |FROM (SELECT MIN(c2) AS m FROM v) + |""".stripMargin), + Row(1, 1) :: Nil) + checkAnswer( + sql(""" + |SELECT c, CASE WHEN EXISTS (SELECT SUM(c2) FROM v WHERE c1 = c) THEN 1 ELSE 0 END + |FROM (SELECT c1 AS c FROM v GROUP BY c1) + |""".stripMargin), + Row(0, 1) :: Row(1, 1) :: Nil) + } + } + test("SPARK-37199: deterministic in QueryPlan considers subquery") { val deterministicQueryPlan = sql("select (select 1 as b) as b") .queryExecution.executedPlan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 2e072e5afc926..d550d0f94f236 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -821,14 +821,14 @@ class UDFSuite extends QueryTest with SharedSparkSession { val e1 = intercept[SparkException] { Seq("20").toDF("col").select(udf(f1).apply(Column("col"))).collect() } - assert(e1.getErrorClass == "FAILED_EXECUTE_UDF") + assert(e1.getCondition == "FAILED_EXECUTE_UDF") assert(e1.getCause.getStackTrace.head.toString.contains( "UDFSuite$MalformedClassObject$MalformedNonPrimitiveFunction")) val e2 = intercept[SparkException] { Seq(20).toDF("col").select(udf(f2).apply(Column("col"))).collect() } - assert(e2.getErrorClass == "FAILED_EXECUTE_UDF") + assert(e2.getCondition == "FAILED_EXECUTE_UDF") assert(e2.getCause.getStackTrace.head.toString.contains( "UDFSuite$MalformedClassObject$MalformedPrimitiveFunction")) } @@ -938,7 +938,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { val e = intercept[SparkException] { input.select(overflowFunc($"dateTime")).collect() } - assert(e.getErrorClass == "FAILED_EXECUTE_UDF") + assert(e.getCondition == "FAILED_EXECUTE_UDF") assert(e.getCause.isInstanceOf[java.lang.ArithmeticException]) } @@ -1053,7 +1053,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { val e = intercept[SparkException] { input.select(overflowFunc($"d")).collect() } - assert(e.getErrorClass == "FAILED_EXECUTE_UDF") + assert(e.getCondition == "FAILED_EXECUTE_UDF") assert(e.getCause.isInstanceOf[java.lang.ArithmeticException]) } @@ -1101,7 +1101,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { val e = intercept[SparkException] { input.select(overflowFunc($"p")).collect() } - assert(e.getErrorClass == "FAILED_EXECUTE_UDF") + assert(e.getCondition == "FAILED_EXECUTE_UDF") assert(e.getCause.isInstanceOf[java.lang.ArithmeticException]) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala index 3224baf42f3e5..fe5c6ef004920 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql +import org.apache.spark.{SparkException, SparkRuntimeException} import org.apache.spark.sql.QueryTest.sameRows import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} @@ -28,6 +29,7 @@ import 
org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarArray import org.apache.spark.types.variant.VariantBuilder +import org.apache.spark.types.variant.VariantUtil._ import org.apache.spark.unsafe.types.VariantVal class VariantEndToEndSuite extends QueryTest with SharedSparkSession { @@ -37,8 +39,10 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { def check(input: String, output: String = null): Unit = { val df = Seq(input).toDF("v") val variantDF = df.select(to_json(parse_json(col("v")))) + val variantDF2 = df.select(to_json(from_json(col("v"), VariantType))) val expected = if (output != null) output else input checkAnswer(variantDF, Seq(Row(expected))) + checkAnswer(variantDF2, Seq(Row(expected))) } check("null") @@ -339,4 +343,40 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { } } } + + test("from_json(_, 'variant') with duplicate keys") { + val json: String = """{"a": 1, "b": 2, "c": "3", "a": 4}""" + withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "true") { + val df = Seq(json).toDF("j") + .selectExpr("from_json(j,'variant')") + val actual = df.collect().head(0).asInstanceOf[VariantVal] + val expectedValue: Array[Byte] = Array(objectHeader(false, 1, 1), + /* size */ 3, + /* id list */ 0, 1, 2, + /* offset list */ 4, 0, 2, 6, + /* field data */ primitiveHeader(INT1), 2, shortStrHeader(1), '3', + primitiveHeader(INT1), 4) + val expectedMetadata: Array[Byte] = Array(VERSION, 3, 0, 1, 2, 3, 'a', 'b', 'c') + assert(actual === new VariantVal(expectedValue, expectedMetadata)) + } + // Check whether the parse_json and from_json expressions throw the correct exception. + Seq("from_json(j, 'variant')", "parse_json(j)").foreach { expr => + withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "false") { + val df = Seq(json).toDF("j").selectExpr(expr) + val exception = intercept[SparkException] { + df.collect() + } + checkError( + exception = exception, + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + parameters = Map("badRecord" -> json, "failFastMode" -> "FAILFAST") + ) + checkError( + exception = exception.getCause.asInstanceOf[SparkRuntimeException], + condition = "VARIANT_DUPLICATE_KEY", + parameters = Map("key" -> "a") + ) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala index d6599debd3b11..6b0fd6084099c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -414,8 +414,8 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { new JavaStrLen(new JavaStrLenNoImpl)) checkError( exception = intercept[AnalysisException](sql("SELECT testcat.ns.strlen('abc')").collect()), - condition = "_LEGACY_ERROR_TEMP_3055", - parameters = Map("scalarFunc" -> "strlen"), + condition = "SCALAR_FUNCTION_NOT_FULLY_IMPLEMENTED", + parameters = Map("scalarFunc" -> "`strlen`"), context = ExpectedContext( fragment = "testcat.ns.strlen('abc')", start = 7, @@ -448,8 +448,8 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { addFunction(Identifier.of(Array("ns"), "add"), new JavaLongAdd(new JavaLongAddMismatchMagic)) checkError( exception = intercept[AnalysisException](sql("SELECT testcat.ns.add(1L, 2L)").collect()), - condition = "_LEGACY_ERROR_TEMP_3055", - 
parameters = Map("scalarFunc" -> "long_add_mismatch_magic"), + condition = "SCALAR_FUNCTION_NOT_FULLY_IMPLEMENTED", + parameters = Map("scalarFunc" -> "`long_add_mismatch_magic`"), context = ExpectedContext( fragment = "testcat.ns.add(1L, 2L)", start = 7, @@ -458,6 +458,23 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { ) } + test("SPARK-49549: scalar function w/ mismatch a compatible ScalarFunction#produceResult") { + case object CharLength extends ScalarFunction[Int] { + override def inputTypes(): Array[DataType] = Array(StringType) + override def resultType(): DataType = IntegerType + override def name(): String = "CHAR_LENGTH" + } + + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "my_strlen"), StrLen(CharLength)) + checkError( + exception = intercept[SparkUnsupportedOperationException] + (sql("SELECT testcat.ns.my_strlen('abc')").collect()), + condition = "SCALAR_FUNCTION_NOT_COMPATIBLE", + parameters = Map("scalarFunc" -> "`CHAR_LENGTH`") + ) + } + test("SPARK-35390: scalar function w/ type coercion") { catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) addFunction(Identifier.of(Array("ns"), "add"), new JavaLongAdd(new JavaLongAddDefault(false))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 7aaec6d500ba0..52ae1bf5d9d3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -27,7 +27,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.{SparkException, SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{FullQualifiedTableName, InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.{InternalRow, QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchNamespaceException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils} @@ -840,7 +840,7 @@ class DataSourceV2SQLSuiteV1Filter val exception = intercept[SparkRuntimeException] { insertNullValueAndCheck() } - assert(exception.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + assert(exception.getCondition == "NOT_NULL_ASSERT_VIOLATION") } } } @@ -2887,6 +2887,48 @@ class DataSourceV2SQLSuiteV1Filter "config" -> "\"spark.sql.catalog.not_exist_catalog\"")) } + test("SPARK-49757: SET CATALOG statement with IDENTIFIER should work") { + val catalogManager = spark.sessionState.catalogManager + assert(catalogManager.currentCatalog.name() == SESSION_CATALOG_NAME) + + sql("SET CATALOG IDENTIFIER('testcat')") + assert(catalogManager.currentCatalog.name() == "testcat") + + spark.sql("SET CATALOG IDENTIFIER(:param)", Map("param" -> "testcat2")) + assert(catalogManager.currentCatalog.name() == "testcat2") + + checkError( + exception = intercept[CatalogNotFoundException] { + sql("SET CATALOG IDENTIFIER('not_exist_catalog')") + }, + condition = "CATALOG_NOT_FOUND", + parameters = Map( + "catalogName" -> "`not_exist_catalog`", + "config" -> "\"spark.sql.catalog.not_exist_catalog\"") + ) + } + + test("SPARK-49757: SET 
CATALOG statement with IDENTIFIER with multipart name should fail") { + val catalogManager = spark.sessionState.catalogManager + assert(catalogManager.currentCatalog.name() == SESSION_CATALOG_NAME) + + val sqlText = "SET CATALOG IDENTIFIER(:param)" + checkError( + exception = intercept[ParseException] { + spark.sql(sqlText, Map("param" -> "testcat.ns1")) + }, + condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", + parameters = Map( + "name" -> "`testcat`.`ns1`", + "statement" -> "SET CATALOG" + ), + context = ExpectedContext( + fragment = sqlText, + start = 0, + stop = 29) + ) + } + test("SPARK-35973: ShowCatalogs") { val schema = new StructType() .add("catalog", StringType, nullable = false) @@ -3713,7 +3755,7 @@ class DataSourceV2SQLSuiteV1Filter // Reset CatalogManager to clear the materialized `spark_catalog` instance, so that we can // configure a new implementation. - val table1 = FullQualifiedTableName(SESSION_CATALOG_NAME, "default", "t") + val table1 = QualifiedTableName(SESSION_CATALOG_NAME, "default", "t") spark.sessionState.catalogManager.reset() withSQLConf( V2_SESSION_CATALOG_IMPLEMENTATION.key -> @@ -3722,7 +3764,7 @@ class DataSourceV2SQLSuiteV1Filter checkParquet(table1.toString, path.getAbsolutePath) } } - val table2 = FullQualifiedTableName("testcat3", "default", "t") + val table2 = QualifiedTableName("testcat3", "default", "t") withSQLConf( "spark.sql.catalog.testcat3" -> classOf[V2CatalogSupportBuiltinDataSource].getName) { withTempPath { path => @@ -3741,7 +3783,7 @@ class DataSourceV2SQLSuiteV1Filter // Reset CatalogManager to clear the materialized `spark_catalog` instance, so that we can // configure a new implementation. spark.sessionState.catalogManager.reset() - val table1 = FullQualifiedTableName(SESSION_CATALOG_NAME, "default", "t") + val table1 = QualifiedTableName(SESSION_CATALOG_NAME, "default", "t") withSQLConf( V2_SESSION_CATALOG_IMPLEMENTATION.key -> classOf[V2CatalogSupportBuiltinDataSource].getName) { @@ -3750,7 +3792,7 @@ class DataSourceV2SQLSuiteV1Filter } } - val table2 = FullQualifiedTableName("testcat3", "default", "t") + val table2 = QualifiedTableName("testcat3", "default", "t") withSQLConf( "spark.sql.catalog.testcat3" -> classOf[V2CatalogSupportBuiltinDataSource].getName) { withTempPath { path => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index 9d4e4fc016722..053616c88d638 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -1326,7 +1326,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase { | UPDATE SET s = named_struct('n_i', null, 'n_l', -1L) |""".stripMargin) } - assert(e1.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + assert(e1.getCondition == "NOT_NULL_ASSERT_VIOLATION") val e2 = intercept[SparkRuntimeException] { sql( @@ -1337,7 +1337,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase { | UPDATE SET s = named_struct('n_i', null, 'n_l', -1L) |""".stripMargin) } - assert(e2.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + assert(e2.getCondition == "NOT_NULL_ASSERT_VIOLATION") val e3 = intercept[SparkRuntimeException] { sql( @@ -1348,7 +1348,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase { | INSERT (pk, s, dep) VALUES (s.pk, named_struct('n_i', null, 'n_l', -1L), 'invalid') 
|""".stripMargin) } - assert(e3.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + assert(e3.getCondition == "NOT_NULL_ASSERT_VIOLATION") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala new file mode 100644 index 0000000000000..c8faf5a874f5f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala @@ -0,0 +1,656 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import java.util.Collections + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkNumberFormatException} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId +import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog} +import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure} +import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter.Mode +import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter.Mode.{IN, INOUT, OUT} +import org.apache.spark.sql.connector.read.{LocalScan, Scan} +import org.apache.spark.sql.errors.DataTypeErrors.{toSQLType, toSQLValue} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String + +class ProcedureSuite extends QueryTest with SharedSparkSession with BeforeAndAfter { + + before { + spark.conf.set(s"spark.sql.catalog.cat", classOf[InMemoryCatalog].getName) + } + + after { + spark.sessionState.catalogManager.reset() + spark.sessionState.conf.unsetConf(s"spark.sql.catalog.cat") + } + + private def catalog: InMemoryCatalog = { + val catalog = spark.sessionState.catalogManager.catalog("cat") + catalog.asInstanceOf[InMemoryCatalog] + } + + test("position arguments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer(sql("CALL cat.ns.sum(5, 5)"), Row(10) :: Nil) + } + + test("named arguments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer(sql("CALL cat.ns.sum(in2 => 3, in1 => 5)"), Row(8) :: Nil) + } + + test("position and named arguments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer(sql("CALL cat.ns.sum(3, in2 => 1)"), Row(4) :: Nil) + } + + test("foldable expressions") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + 
checkAnswer(sql("CALL cat.ns.sum(1 + 1, in2 => 2)"), Row(4) :: Nil) + checkAnswer(sql("CALL cat.ns.sum(in2 => 1, in1 => 2 + 1)"), Row(4) :: Nil) + checkAnswer(sql("CALL cat.ns.sum((1 + 1) * 2, in2 => (2 + 1) / 3)"), Row(5) :: Nil) + } + + test("type coercion") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundLongSum) + checkAnswer(sql("CALL cat.ns.sum(1, 2)"), Row(3) :: Nil) + checkAnswer(sql("CALL cat.ns.sum(1L, 2)"), Row(3) :: Nil) + checkAnswer(sql("CALL cat.ns.sum(1, 2L)"), Row(3) :: Nil) + } + + test("multiple output rows") { + catalog.createProcedure(Identifier.of(Array("ns"), "complex"), UnboundComplexProcedure) + checkAnswer( + sql("CALL cat.ns.complex('X', 'Y', 3)"), + Row(1, "X1", "Y1") :: Row(2, "X2", "Y2") :: Row(3, "X3", "Y3") :: Nil) + } + + test("parameters with default values") { + catalog.createProcedure(Identifier.of(Array("ns"), "complex"), UnboundComplexProcedure) + checkAnswer(sql("CALL cat.ns.complex()"), Row(1, "A1", "B1") :: Nil) + checkAnswer(sql("CALL cat.ns.complex('X', 'Y')"), Row(1, "X1", "Y1") :: Nil) + } + + test("parameters with invalid default values") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundInvalidDefaultProcedure) + checkError( + exception = intercept[AnalysisException]( + sql("CALL cat.ns.sum()") + ), + condition = "INVALID_DEFAULT_VALUE.DATA_TYPE", + parameters = Map( + "statement" -> "CALL", + "colName" -> toSQLId("in2"), + "defaultValue" -> toSQLValue("B"), + "expectedType" -> toSQLType("INT"), + "actualType" -> toSQLType("STRING"))) + } + + test("IDENTIFIER") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer( + spark.sql("CALL IDENTIFIER(:p1)(1, 2)", Map("p1" -> "cat.ns.sum")), + Row(3) :: Nil) + } + + test("parameterized statements") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer( + spark.sql("CALL cat.ns.sum(?, ?)", Array(2, 3)), + Row(5) :: Nil) + } + + test("undefined procedure") { + checkError( + exception = intercept[AnalysisException]( + sql("CALL cat.non_exist(1, 2)") + ), + sqlState = Some("38000"), + condition = "FAILED_TO_LOAD_ROUTINE", + parameters = Map("routineName" -> "`cat`.`non_exist`") + ) + } + + test("non-procedure catalog") { + withSQLConf("spark.sql.catalog.testcat" -> classOf[BasicInMemoryTableCatalog].getName) { + checkError( + exception = intercept[AnalysisException]( + sql("CALL testcat.procedure(1, 2)") + ), + condition = "_LEGACY_ERROR_TEMP_1184", + parameters = Map("plugin" -> "testcat", "ability" -> "procedures") + ) + } + } + + test("too many arguments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException]( + sql("CALL cat.ns.sum(1, 2, 3)") + ), + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + parameters = Map( + "functionName" -> toSQLId("sum"), + "expectedNum" -> "2", + "actualNum" -> "3", + "docroot" -> SPARK_DOC_ROOT)) + } + + test("custom default catalog") { + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "cat") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + val df = sql("CALL ns.sum(1, 2)") + checkAnswer(df, Row(3) :: Nil) + } + } + + test("custom default catalog and namespace") { + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "cat") { + catalog.createNamespace(Array("ns"), Collections.emptyMap) + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + sql("USE ns") + val df = sql("CALL sum(1, 2)") + checkAnswer(df, Row(3) :: Nil) + } + } + + test("required 
parameter not found") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum()") + }, + condition = "REQUIRED_PARAMETER_NOT_FOUND", + parameters = Map( + "routineName" -> toSQLId("sum"), + "parameterName" -> toSQLId("in1"), + "index" -> "0")) + } + + test("conflicting position and named parameter assignments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum(1, in1 => 2)") + }, + condition = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.BOTH_POSITIONAL_AND_NAMED", + parameters = Map( + "routineName" -> toSQLId("sum"), + "parameterName" -> toSQLId("in1"))) + } + + test("duplicate named parameter assignments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum(in1 => 1, in1 => 2)") + }, + condition = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + parameters = Map( + "routineName" -> toSQLId("sum"), + "parameterName" -> toSQLId("in1"))) + } + + test("unknown parameter name") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum(in1 => 1, in5 => 2)") + }, + condition = "UNRECOGNIZED_PARAMETER_NAME", + parameters = Map( + "routineName" -> toSQLId("sum"), + "argumentName" -> toSQLId("in5"), + "proposal" -> (toSQLId("in1") + " " + toSQLId("in2")))) + } + + test("position parameter after named parameter") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum(in1 => 1, 2)") + }, + condition = "UNEXPECTED_POSITIONAL_ARGUMENT", + parameters = Map( + "routineName" -> toSQLId("sum"), + "parameterName" -> toSQLId("in1"))) + } + + test("invalid argument type") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + val call = "CALL cat.ns.sum(1, TIMESTAMP '2016-11-15 20:54:00.000')" + checkError( + exception = intercept[AnalysisException] { + sql(call) + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "CALL", + "paramIndex" -> "second", + "inputSql" -> "\"TIMESTAMP '2016-11-15 20:54:00'\"", + "inputType" -> toSQLType("TIMESTAMP"), + "requiredType" -> toSQLType("INT")), + context = ExpectedContext(fragment = call, start = 0, stop = call.length - 1)) + } + + test("malformed input to implicit cast") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> true.toString) { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + val call = "CALL cat.ns.sum('A', 2)" + checkError( + exception = intercept[SparkNumberFormatException]( + sql(call) + ), + condition = "CAST_INVALID_INPUT", + parameters = Map( + "expression" -> toSQLValue("A"), + "sourceType" -> toSQLType("STRING"), + "targetType" -> toSQLType("INT")), + context = ExpectedContext(fragment = call, start = 0, stop = call.length - 1)) + } + } + + test("required parameters after optional") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundInvalidSum) + val e = intercept[SparkException] { + sql("CALL cat.ns.sum(in2 => 1)") + } + assert(e.getMessage.contains("required arguments should come before optional arguments")) + } + + test("INOUT parameters are not supported") { + catalog.createProcedure(Identifier.of(Array("ns"), 
"procedure"), UnboundInoutProcedure) + val e = intercept[SparkException] { + sql("CALL cat.ns.procedure(1)") + } + assert(e.getMessage.contains(" Unsupported parameter mode: INOUT")) + } + + test("OUT parameters are not supported") { + catalog.createProcedure(Identifier.of(Array("ns"), "procedure"), UnboundOutProcedure) + val e = intercept[SparkException] { + sql("CALL cat.ns.procedure(1)") + } + assert(e.getMessage.contains("Unsupported parameter mode: OUT")) + } + + test("EXPLAIN") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundNonExecutableSum) + val explain1 = sql("EXPLAIN CALL cat.ns.sum(5, 5)").head().get(0) + assert(explain1.toString.contains("cat.ns.sum(5, 5)")) + val explain2 = sql("EXPLAIN EXTENDED CALL cat.ns.sum(10, 10)").head().get(0) + assert(explain2.toString.contains("cat.ns.sum(10, 10)")) + } + + test("void procedure") { + catalog.createProcedure(Identifier.of(Array("ns"), "proc"), UnboundVoidProcedure) + checkAnswer(sql("CALL cat.ns.proc('A', 'B')"), Nil) + } + + test("multi-result procedure") { + catalog.createProcedure(Identifier.of(Array("ns"), "proc"), UnboundMultiResultProcedure) + checkAnswer(sql("CALL cat.ns.proc()"), Row("last") :: Nil) + } + + test("invalid input to struct procedure") { + catalog.createProcedure(Identifier.of(Array("ns"), "proc"), UnboundStructProcedure) + val actualType = + StructType(Seq( + StructField("X", DataTypes.DateType, nullable = false), + StructField("Y", DataTypes.IntegerType, nullable = false))) + val expectedType = StructProcedure.parameters.head.dataType + val call = "CALL cat.ns.proc(named_struct('X', DATE '2011-11-11', 'Y', 2), 'VALUE')" + checkError( + exception = intercept[AnalysisException](sql(call)), + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "CALL", + "paramIndex" -> "first", + "inputSql" -> "\"named_struct(X, DATE '2011-11-11', Y, 2)\"", + "inputType" -> toSQLType(actualType), + "requiredType" -> toSQLType(expectedType)), + context = ExpectedContext(fragment = call, start = 0, stop = call.length - 1)) + } + + test("save execution summary") { + withTable("summary") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + val result = sql("CALL cat.ns.sum(1, 2)") + result.write.saveAsTable("summary") + checkAnswer(spark.table("summary"), Row(3) :: Nil) + } + } + + object UnboundVoidProcedure extends UnboundProcedure { + override def name: String = "void" + override def description: String = "void procedure" + override def bind(inputType: StructType): BoundProcedure = VoidProcedure + } + + object VoidProcedure extends BoundProcedure { + override def name: String = "void" + + override def description: String = "void procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.StringType).build(), + ProcedureParameter.in("in2", DataTypes.StringType).build() + ) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + Collections.emptyIterator + } + } + + object UnboundMultiResultProcedure extends UnboundProcedure { + override def name: String = "multi" + override def description: String = "multi-result procedure" + override def bind(inputType: StructType): BoundProcedure = MultiResultProcedure + } + + object MultiResultProcedure extends BoundProcedure { + override def name: String = "multi" + + override def description: String = "multi-result procedure" + + override def isDeterministic: Boolean = true + + override def 
parameters: Array[ProcedureParameter] = Array() + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + val scans = java.util.Arrays.asList[Scan]( + Result( + new StructType().add("out", DataTypes.IntegerType), + Array(InternalRow(1))), + Result( + new StructType().add("out", DataTypes.StringType), + Array(InternalRow(UTF8String.fromString("last")))) + ) + scans.iterator() + } + } + + object UnboundNonExecutableSum extends UnboundProcedure { + override def name: String = "sum" + override def description: String = "sum integers" + override def bind(inputType: StructType): BoundProcedure = NonExecutableSum + } + + object NonExecutableSum extends BoundProcedure { + override def name: String = "sum" + + override def description: String = "sum integers" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.IntegerType).build(), + ProcedureParameter.in("in2", DataTypes.IntegerType).build() + ) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + object UnboundSum extends UnboundProcedure { + override def name: String = "sum" + override def description: String = "sum integers" + override def bind(inputType: StructType): BoundProcedure = Sum + } + + object Sum extends BoundProcedure { + override def name: String = "sum" + + override def description: String = "sum integers" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.IntegerType).build(), + ProcedureParameter.in("in2", DataTypes.IntegerType).build() + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + val in1 = input.getInt(0) + val in2 = input.getInt(1) + val result = Result(outputType, Array(InternalRow(in1 + in2))) + Collections.singleton[Scan](result).iterator() + } + } + + object UnboundLongSum extends UnboundProcedure { + override def name: String = "long_sum" + override def description: String = "sum longs" + override def bind(inputType: StructType): BoundProcedure = LongSum + } + + object LongSum extends BoundProcedure { + override def name: String = "long_sum" + + override def description: String = "sum longs" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.LongType).build(), + ProcedureParameter.in("in2", DataTypes.LongType).build() + ) + + def outputType: StructType = new StructType().add("out", DataTypes.LongType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + val in1 = input.getLong(0) + val in2 = input.getLong(1) + val result = Result(outputType, Array(InternalRow(in1 + in2))) + Collections.singleton[Scan](result).iterator() + } + } + + object UnboundInvalidSum extends UnboundProcedure { + override def name: String = "invalid" + override def description: String = "sum integers" + override def bind(inputType: StructType): BoundProcedure = InvalidSum + } + + object InvalidSum extends BoundProcedure { + override def name: String = "invalid" + + override def description: String = "sum integers" + + override def isDeterministic: Boolean = false + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.IntegerType).defaultValue("1").build(), + ProcedureParameter.in("in2",
DataTypes.IntegerType).build() + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + object UnboundInvalidDefaultProcedure extends UnboundProcedure { + override def name: String = "sum" + override def description: String = "invalid default value procedure" + override def bind(inputType: StructType): BoundProcedure = InvalidDefaultProcedure + } + + object InvalidDefaultProcedure extends BoundProcedure { + override def name: String = "sum" + + override def description: String = "invalid default value procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.IntegerType).defaultValue("10").build(), + ProcedureParameter.in("in2", DataTypes.IntegerType).defaultValue("'B'").build() + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + object UnboundComplexProcedure extends UnboundProcedure { + override def name: String = "complex" + override def description: String = "complex procedure" + override def bind(inputType: StructType): BoundProcedure = ComplexProcedure + } + + object ComplexProcedure extends BoundProcedure { + override def name: String = "complex" + + override def description: String = "complex procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.StringType).defaultValue("'A'").build(), + ProcedureParameter.in("in2", DataTypes.StringType).defaultValue("'B'").build(), + ProcedureParameter.in("in3", DataTypes.IntegerType).defaultValue("1 + 1 - 1").build() + ) + + def outputType: StructType = new StructType() + .add("out1", DataTypes.IntegerType) + .add("out2", DataTypes.StringType) + .add("out3", DataTypes.StringType) + + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + val in1 = input.getString(0) + val in2 = input.getString(1) + val in3 = input.getInt(2) + + val rows = (1 to in3).map { index => + val v1 = UTF8String.fromString(s"$in1$index") + val v2 = UTF8String.fromString(s"$in2$index") + InternalRow(index, v1, v2) + }.toArray + + val result = Result(outputType, rows) + Collections.singleton[Scan](result).iterator() + } + } + + object UnboundStructProcedure extends UnboundProcedure { + override def name: String = "struct_input" + override def description: String = "struct procedure" + override def bind(inputType: StructType): BoundProcedure = StructProcedure + } + + object StructProcedure extends BoundProcedure { + override def name: String = "struct_input" + + override def description: String = "struct procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter + .in( + "in1", + StructType(Seq( + StructField("nested1", DataTypes.IntegerType), + StructField("nested2", DataTypes.StringType)))) + .build(), + ProcedureParameter.in("in2", DataTypes.StringType).build() + ) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + Collections.emptyIterator + } + } + + object UnboundInoutProcedure extends UnboundProcedure { + override def name: String = "procedure" + override def description: String = "inout procedure" + override def 
bind(inputType: StructType): BoundProcedure = InoutProcedure + } + + object InoutProcedure extends BoundProcedure { + override def name: String = "procedure" + + override def description: String = "inout procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + CustomParameterImpl(INOUT, "in1", DataTypes.IntegerType) + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + object UnboundOutProcedure extends UnboundProcedure { + override def name: String = "procedure" + override def description: String = "out procedure" + override def bind(inputType: StructType): BoundProcedure = OutProcedure + } + + object OutProcedure extends BoundProcedure { + override def name: String = "procedure" + + override def description: String = "out procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + CustomParameterImpl(IN, "in1", DataTypes.IntegerType), + CustomParameterImpl(OUT, "out1", DataTypes.IntegerType) + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + case class Result(readSchema: StructType, rows: Array[InternalRow]) extends LocalScan + + case class CustomParameterImpl( + mode: Mode, + name: String, + dataType: DataType) extends ProcedureParameter { + override def defaultValueExpression: String = null + override def comment: String = null + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 370c118de9a93..92c175fe2f94a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -868,6 +868,39 @@ class QueryCompilationErrorsSuite "inputTypes" -> "[\"INT\", \"STRING\", \"STRING\"]")) } + test("SPARK-49666: the trim collation feature is off without collate builder call") { + withSQLConf(SQLConf.TRIM_COLLATION_ENABLED.key -> "false") { + Seq( + "CREATE TABLE t(col STRING COLLATE EN_RTRIM_CI) USING parquet", + "CREATE TABLE t(col STRING COLLATE UTF8_LCASE_RTRIM) USING parquet", + "SELECT 'aaa' COLLATE UNICODE_LTRIM_CI" + ).foreach { sqlText => + checkError( + exception = intercept[AnalysisException](sql(sqlText)), + condition = "UNSUPPORTED_FEATURE.TRIM_COLLATION" + ) + } + } + } + + test("SPARK-49666: the trim collation feature is off with collate builder call") { + withSQLConf(SQLConf.TRIM_COLLATION_ENABLED.key -> "false") { + Seq( + "SELECT collate('aaa', 'UNICODE_RTRIM')", + "SELECT collate('aaa', 'UTF8_BINARY_RTRIM')", + "SELECT collate('aaa', 'EN_AI_RTRIM')" + ).foreach { sqlText => + checkError( + exception = intercept[AnalysisException](sql(sqlText)), + condition = "UNSUPPORTED_FEATURE.TRIM_COLLATION", + parameters = Map.empty, + context = + ExpectedContext(fragment = sqlText.substring(7), start = 7, stop = sqlText.length - 1) + ) + } + } + } + test("UNSUPPORTED_CALL: call the unsupported method update()") { checkError( exception = intercept[SparkUnsupportedOperationException] { @@ -941,11 +974,67 @@ class QueryCompilationErrorsSuite cmd.run(spark) }, condition = 
"DATA_SOURCE_EXTERNAL_ERROR", - sqlState = "KD00F", + sqlState = "KD010", parameters = Map.empty ) } + test("SPARK-49895: trailing comma in select statement") { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 INT, c2 INT) USING PARQUET") + + val queries = Seq( + "SELECT *? FROM t1", + "SELECT c1? FROM t1", + "SELECT c1? FROM t1 WHERE c1 = 1", + "SELECT c1? FROM t1 GROUP BY c1", + "SELECT *, RANK() OVER (ORDER BY c1)? FROM t1", + "SELECT c1? FROM t1 ORDER BY c1", + "WITH cte AS (SELECT c1? FROM t1) SELECT * FROM cte", + "WITH cte AS (SELECT c1 FROM t1) SELECT *? FROM cte", + "SELECT * FROM (SELECT c1? FROM t1)") + + queries.foreach { query => + val queryWithoutTrailingComma = query.replaceAll("\\?", "") + val queryWithTrailingComma = query.replaceAll("\\?", ",") + + sql(queryWithoutTrailingComma) + print(queryWithTrailingComma) + val exception = intercept[AnalysisException] { + sql(queryWithTrailingComma) + } + assert(exception.getCondition === "TRAILING_COMMA_IN_SELECT") + } + + val unresolvedColumnErrors = Seq( + "SELECT c3 FROM t1", + "SELECT from FROM t1", + "SELECT from FROM (SELECT 'a' as c1)", + "SELECT from AS col FROM t1", + "SELECT from AS from FROM t1", + "SELECT from from FROM t1") + unresolvedColumnErrors.foreach { query => + val exception = intercept[AnalysisException] { + sql(query) + } + assert(exception.getCondition === "UNRESOLVED_COLUMN.WITH_SUGGESTION") + } + + // sanity checks + withTable("from") { + sql(s"CREATE TABLE from (from INT) USING PARQUET") + + sql(s"SELECT from FROM from") + sql(s"SELECT from as from FROM from") + sql(s"SELECT from from FROM from from") + sql(s"SELECT c1, from FROM VALUES(1, 2) AS T(c1, from)") + + intercept[ParseException] { + sql("SELECT 1,") + } + } + } + } } class MyCastToString extends SparkUserDefinedFunction( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala index ec92e0b700e31..2e0983fe0319c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala @@ -391,4 +391,14 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest } } } + + test("SPARK-49773: INVALID_TIMEZONE for bad timezone") { + checkError( + exception = intercept[SparkDateTimeException] { + sql("select make_timestamp(1, 2, 28, 23, 1, 1, -100)").collect() + }, + condition = "INVALID_TIMEZONE", + parameters = Map("timeZone" -> "-100") + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index 00dfd3451d577..1adb1fdf05032 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -35,11 +35,12 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Encoder, Kry import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{NamedParameter, UnresolvedGenerator} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Concat, CreateArray, EmptyRow, Expression, Flatten, Grouping, Literal, RowNumber, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Concat, 
CreateArray, EmptyRow, Expression, Flatten, Grouping, Literal, RowNumber, UnaryExpression, Years} import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.expressions.objects.InitializeJavaBean import org.apache.spark.sql.catalyst.rules.RuleIdCollection +import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLExpr import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions} import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider import org.apache.spark.sql.execution.datasources.orc.OrcTest @@ -292,7 +293,7 @@ class QueryExecutionErrorsSuite val e = intercept[SparkException] { df.write.parquet(dir.getCanonicalPath) } - assert(e.getErrorClass == "TASK_WRITE_FAILED") + assert(e.getCondition == "TASK_WRITE_FAILED") val format = "Parquet" val config = "\"" + SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key + "\"" @@ -311,7 +312,7 @@ class QueryExecutionErrorsSuite val ex = intercept[SparkException] { spark.read.schema("time timestamp_ntz").orc(file.getCanonicalPath).collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[SparkUnsupportedOperationException], condition = "UNSUPPORTED_FEATURE.ORC_TYPE_CAST", @@ -333,7 +334,7 @@ class QueryExecutionErrorsSuite val ex = intercept[SparkException] { spark.read.schema("time timestamp_ltz").orc(file.getCanonicalPath).collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[SparkUnsupportedOperationException], condition = "UNSUPPORTED_FEATURE.ORC_TYPE_CAST", @@ -381,7 +382,7 @@ class QueryExecutionErrorsSuite } val e2 = e1.getCause.asInstanceOf[SparkException] - assert(e2.getErrorClass == "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION") + assert(e2.getCondition == "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION") checkError( exception = e2.getCause.asInstanceOf[SparkRuntimeException], @@ -767,7 +768,40 @@ class QueryExecutionErrorsSuite parameters = Map( "value1" -> "127S", "symbol" -> "+", - "value2" -> "5S"), + "value2" -> "5S", + "functionName" -> "`try_add`"), + sqlState = "22003") + } + } + + test("BINARY_ARITHMETIC_OVERFLOW: byte minus byte result overflow") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + checkError( + exception = intercept[SparkArithmeticException] { + sql(s"select -2Y - 127Y").collect() + }, + condition = "BINARY_ARITHMETIC_OVERFLOW", + parameters = Map( + "value1" -> "-2S", + "symbol" -> "-", + "value2" -> "127S", + "functionName" -> "`try_subtract`"), + sqlState = "22003") + } + } + + test("BINARY_ARITHMETIC_OVERFLOW: byte multiply byte result overflow") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + checkError( + exception = intercept[SparkArithmeticException] { + sql(s"select 127Y * 5Y").collect() + }, + condition = "BINARY_ARITHMETIC_OVERFLOW", + parameters = Map( + "value1" -> "127S", + "symbol" -> "*", + "value2" -> "5S", + "functionName" -> "`try_multiply`"), sqlState = "22003") } } @@ -887,7 +921,7 @@ class QueryExecutionErrorsSuite val e = intercept[StreamingQueryException] { query.awaitTermination() } - assert(e.getErrorClass === "STREAM_FAILED") + assert(e.getCondition === "STREAM_FAILED") assert(e.getCause.isInstanceOf[NullPointerException]) } @@ -973,6 +1007,17 @@ class 
QueryExecutionErrorsSuite sqlState = "XX000") } + test("PartitionTransformExpression error on eval") { + val expr = Years(Literal("foo")) + val e = intercept[SparkException] { + expr.eval() + } + checkError( + exception = e, + condition = "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY", + parameters = Map("expression" -> toSQLExpr(expr))) + } + test("INTERNAL_ERROR: Calling doGenCode on unresolved") { val e = intercept[SparkException] { val ctx = new CodegenContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala index da7b6e7f63c85..666f85e19c1c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala @@ -334,7 +334,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL sqlState = "42000", parameters = Map( "statement" -> "CREATE TEMPORARY FUNCTION", - "funcName" -> "`ns`.`db`.`func`"), + "name" -> "`ns`.`db`.`func`"), context = ExpectedContext( fragment = sqlText, start = 0, @@ -367,7 +367,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL sqlState = "42000", parameters = Map( "statement" -> "DROP TEMPORARY FUNCTION", - "funcName" -> "`db`.`func`"), + "name" -> "`db`.`func`"), context = ExpectedContext( fragment = sqlText, start = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index a80444feb68ae..fc1c9c6755572 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, Un import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Concat, GreaterThan, Literal, NullsFirst, SortOrder, UnresolvedWindowExpression, UnspecifiedFrame, WindowSpecDefinition, WindowSpecReference} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, PROJECT, UNRESOLVED_RELATION} +import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.connector.catalog.TableCatalog import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTempViewUsing, RefreshResource} @@ -887,14 +887,74 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { // Basic selection. // Here we check that every parsed plan contains a projection and a source relation or // inline table. - def checkPipeSelect(query: String): Unit = { + def check(query: String, patterns: Seq[TreePattern]): Unit = { val plan: LogicalPlan = parser.parsePlan(query) - assert(plan.containsPattern(PROJECT)) + assert(patterns.exists(plan.containsPattern)) assert(plan.containsAnyPattern(UNRESOLVED_RELATION, LOCAL_RELATION)) } + def checkPipeSelect(query: String): Unit = check(query, Seq(PROJECT)) checkPipeSelect("TABLE t |> SELECT 1 AS X") checkPipeSelect("TABLE t |> SELECT 1 AS X, 2 AS Y |> SELECT X + Y AS Z") checkPipeSelect("VALUES (0), (1) tab(col) |> SELECT col * 2 AS result") + // Basic WHERE operators. 
+ def checkPipeWhere(query: String): Unit = check(query, Seq(FILTER)) + checkPipeWhere("TABLE t |> WHERE X = 1") + checkPipeWhere("TABLE t |> SELECT X, LENGTH(Y) AS Z |> WHERE X + LENGTH(Y) < 4") + checkPipeWhere("TABLE t |> WHERE X = 1 AND Y = 2 |> WHERE X + Y = 3") + checkPipeWhere("VALUES (0), (1) tab(col) |> WHERE col < 1") + // PIVOT and UNPIVOT operations + def checkPivotUnpivot(query: String): Unit = check(query, Seq(PIVOT, UNPIVOT)) + checkPivotUnpivot( + """ + |SELECT * FROM VALUES + | ("dotNET", 2012, 10000), + | ("Java", 2012, 20000), + | ("dotNET", 2012, 5000), + | ("dotNET", 2013, 48000), + | ("Java", 2013, 30000) + | AS courseSales(course, year, earnings) + ||> PIVOT ( + | SUM(earnings) + | FOR course IN ('dotNET', 'Java') + |) + |""".stripMargin) + checkPivotUnpivot( + """ + |SELECT * FROM VALUES + | ("dotNET", 15000, 48000, 22500), + | ("Java", 20000, 30000, NULL) + | AS courseEarnings(course, `2012`, `2013`, `2014`) + ||> UNPIVOT ( + | earningsYear FOR year IN (`2012`, `2013`, `2014`) + |) + |""".stripMargin) + // Sampling operations + def checkSample(query: String): Unit = { + val plan: LogicalPlan = parser.parsePlan(query) + assert(plan.collectFirst(_.isInstanceOf[Sample]).nonEmpty) + assert(plan.containsAnyPattern(UNRESOLVED_RELATION, LOCAL_RELATION)) + } + checkSample("TABLE t |> TABLESAMPLE (50 PERCENT)") + checkSample("TABLE t |> TABLESAMPLE (5 ROWS)") + checkSample("TABLE t |> TABLESAMPLE (BUCKET 4 OUT OF 10)") + // Joins. + def checkPipeJoin(query: String): Unit = check(query, Seq(JOIN)) + Seq("", "INNER", "LEFT", "LEFT OUTER", "SEMI", "LEFT SEMI", "RIGHT", "RIGHT OUTER", "FULL", + "FULL OUTER", "ANTI", "LEFT ANTI", "CROSS").foreach { joinType => + checkPipeJoin(s"TABLE t |> $joinType JOIN other ON (t.x = other.x)") + } + // Set operations + def checkDistinct(query: String): Unit = check(query, Seq(DISTINCT_LIKE)) + def checkExcept(query: String): Unit = check(query, Seq(EXCEPT)) + def checkIntersect(query: String): Unit = check(query, Seq(INTERSECT)) + def checkUnion(query: String): Unit = check(query, Seq(UNION)) + checkDistinct("TABLE t |> UNION DISTINCT TABLE t") + checkExcept("TABLE t |> EXCEPT ALL TABLE t") + checkExcept("TABLE t |> EXCEPT DISTINCT TABLE t") + checkExcept("TABLE t |> MINUS ALL TABLE t") + checkExcept("TABLE t |> MINUS DISTINCT TABLE t") + checkIntersect("TABLE t |> INTERSECT ALL TABLE t") + checkUnion("TABLE t |> UNION ALL TABLE t") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 75f016d050de9..c5e64c96b2c8a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -904,7 +904,7 @@ class AdaptiveQueryExecSuite val error = intercept[SparkException] { aggregated.count() } - assert(error.getErrorClass === "INVALID_BUCKET_FILE") + assert(error.getCondition === "INVALID_BUCKET_FILE") assert(error.getMessage contains "Invalid bucket file") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 275b35947182c..c90b1d3ca5978 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -1217,8 +1217,8 @@ class ArrowConvertersSuite extends SharedSparkSession { val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") val tempFile2 = new File(tempDataPath, "testData2-ints-part2.json") - Files.write(json1, tempFile1, StandardCharsets.UTF_8) - Files.write(json2, tempFile2, StandardCharsets.UTF_8) + Files.asCharSink(tempFile1, StandardCharsets.UTF_8).write(json1) + Files.asCharSink(tempFile2, StandardCharsets.UTF_8).write(json2) validateConversion(schema, arrowBatches(0), tempFile1) validateConversion(schema, arrowBatches(1), tempFile2) @@ -1501,7 +1501,7 @@ class ArrowConvertersSuite extends SharedSparkSession { // NOTE: coalesce to single partition because can only load 1 batch in validator val batchBytes = df.coalesce(1).toArrowBatchRdd.collect().head val tempFile = new File(tempDataPath, file) - Files.write(json, tempFile, StandardCharsets.UTF_8) + Files.asCharSink(tempFile, StandardCharsets.UTF_8).write(json) validateConversion(df.schema, batchBytes, tempFile, timeZoneId, errorOnDuplicatedFieldNames) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala index 05ae575305299..290cfd56b8bce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -91,7 +91,7 @@ object CompressionSchemeBenchmark extends BenchmarkBase with AllCompressionSchem schemes.filter(_.supports(tpe)).foreach { scheme => val (compressFunc, compressionRatio, buf) = prepareEncodeInternal(count, tpe, scheme, input) - val label = s"${getFormattedClassName(scheme)}(${compressionRatio.formatted("%.3f")})" + val label = s"${getFormattedClassName(scheme)}(${"%.3f".format(compressionRatio)})" benchmark.addCase(label)({ i: Int => for (n <- 0L until iters) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index 176eb7c290764..8b868c0e17230 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -688,7 +688,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { checkError( exception = parseException(sql1), condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", - parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "funcName" -> "`a`.`b`"), + parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "name" -> "`a`.`b`"), context = ExpectedContext( fragment = sql1, start = 0, @@ -698,7 +698,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { checkError( exception = parseException(sql2), condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", - parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "funcName" -> "`a`.`b`"), + parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "name" -> "`a`.`b`"), context = ExpectedContext( fragment = sql2, start = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 
8307326f17fcf..e07f6406901e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.fs.permission.{AclEntry, AclStatus} import org.apache.spark.{SparkClassNotFoundException, SparkException, SparkFiles, SparkRuntimeException} import org.apache.spark.internal.config import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} -import org.apache.spark.sql.catalyst.{FullQualifiedTableName, FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.TempTableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec @@ -219,7 +219,7 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSparkSession { test("SPARK-25403 refresh the table after inserting data") { withTable("t") { val catalog = spark.sessionState.catalog - val table = FullQualifiedTableName( + val table = QualifiedTableName( CatalogManager.SESSION_CATALOG_NAME, catalog.getCurrentDatabase, "t") sql("CREATE TABLE t (a INT) USING parquet") sql("INSERT INTO TABLE t VALUES (1)") @@ -233,7 +233,7 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSparkSession { withTable("t") { withTempDir { dir => val catalog = spark.sessionState.catalog - val table = FullQualifiedTableName( + val table = QualifiedTableName( CatalogManager.SESSION_CATALOG_NAME, catalog.getCurrentDatabase, "t") val p1 = s"${dir.getCanonicalPath}/p1" val p2 = s"${dir.getCanonicalPath}/p2" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala index 348b216aeb044..40ae35bbe8aa3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.permission.{AclEntry, AclEntryScope, AclEntryType, FsAction, FsPermission} import org.apache.spark.sql.{AnalysisException, Row} -import org.apache.spark.sql.catalyst.{FullQualifiedTableName, TableIdentifier} +import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.execution.command import org.apache.spark.sql.execution.command.FakeLocalFsFileSystem @@ -148,7 +148,7 @@ trait TruncateTableSuiteBase extends command.TruncateTableSuiteBase { val catalog = spark.sessionState.catalog val qualifiedTableName = - FullQualifiedTableName(CatalogManager.SESSION_CATALOG_NAME, "ns", "tbl") + QualifiedTableName(CatalogManager.SESSION_CATALOG_NAME, "ns", "tbl") val cachedPlan = catalog.getCachedTable(qualifiedTableName) assert(cachedPlan.stats.sizeInBytes == 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index 31b7380889158..e9f78f9f598e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -566,7 +566,7 @@ 
class FileIndexSuite extends SharedSparkSession { new File(directoryPath, "part_col=1").renameTo(new File(directoryPath, "undefined")) // By default, we expect the invalid path assertion to trigger. - val ex = intercept[AssertionError] { + val ex = intercept[SparkRuntimeException] { spark.read .format("parquet") .load(directoryPath.getCanonicalPath) @@ -585,7 +585,7 @@ class FileIndexSuite extends SharedSparkSession { // Data source option override takes precedence. withSQLConf(SQLConf.IGNORE_INVALID_PARTITION_PATHS.key -> "true") { - val ex = intercept[AssertionError] { + val ex = intercept[SparkRuntimeException] { spark.read .format("parquet") .option(FileIndexOptions.IGNORE_INVALID_PARTITION_PATHS, "false") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala index deb62eb3ac234..387a2baa256bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala @@ -368,7 +368,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSparkSession { checkAnswer(readContent(), expected) } } - assert(caught.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(caught.getCondition.startsWith("FAILED_READ_FILE")) assert(caught.getCause.getMessage.contains("exceeds the max length allowed")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParsingOptionsSuite.scala new file mode 100644 index 0000000000000..8c8304503cef8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParsingOptionsSuite.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.test.SharedSparkSession + +class CSVParsingOptionsSuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + test("SPARK-49955: null string value does not mean corrupted file") { + val str = "abc" + val stringDataset = Seq(str, null).toDS() + val df = spark.read.csv(stringDataset) + // `spark.read.csv(rdd)` removes all null values at the beginning. 
+ checkAnswer(df, Seq(Row("abc"))) + val df2 = spark.read.option("mode", "failfast").csv(stringDataset) + checkAnswer(df2, Seq(Row("abc"))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index e2d1d9b05c3c2..422ae02a18322 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -85,6 +85,7 @@ abstract class CSVSuite private val badAfterGoodFile = "test-data/bad_after_good.csv" private val malformedRowFile = "test-data/malformedRow.csv" private val charFile = "test-data/char.csv" + private val moreColumnsFile = "test-data/more-columns.csv" /** Verifies data and schema. */ private def verifyCars( @@ -391,7 +392,7 @@ abstract class CSVSuite condition = "FAILED_READ_FILE.NO_HINT", parameters = Map("path" -> s".*$carsFile.*")) val e2 = e1.getCause.asInstanceOf[SparkException] - assert(e2.getErrorClass == "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION") + assert(e2.getCondition == "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION") checkError( exception = e2.getCause.asInstanceOf[SparkRuntimeException], condition = "MALFORMED_CSV_RECORD", @@ -3439,6 +3440,39 @@ abstract class CSVSuite expected) } } + + test("SPARK-49444: CSV parsing failure with more than max columns") { + val schema = new StructType() + .add("intColumn", IntegerType, nullable = true) + .add("decimalColumn", DecimalType(10, 2), nullable = true) + + val fileReadException = intercept[SparkException] { + spark + .read + .schema(schema) + .option("header", "false") + .option("maxColumns", "2") + .csv(testFile(moreColumnsFile)) + .collect() + } + + checkErrorMatchPVals( + exception = fileReadException, + condition = "FAILED_READ_FILE.NO_HINT", + parameters = Map("path" -> s".*$moreColumnsFile")) + + val malformedCSVException = fileReadException.getCause.asInstanceOf[SparkRuntimeException] + + checkError( + exception = malformedCSVException, + condition = "MALFORMED_CSV_RECORD", + parameters = Map("badRecord" -> "1,3.14,string,5,7"), + sqlState = "KD000") + + assert(malformedCSVException.getCause.isInstanceOf[TextParsingException]) + val textParsingException = malformedCSVException.getCause.asInstanceOf[TextParsingException] + assert(textParsingException.getCause.isInstanceOf[ArrayIndexOutOfBoundsException]) + } } class CSVv1Suite extends CSVSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala index 703085dca66f1..11cc0b99bbde7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.json +import org.apache.spark.SparkException import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{StringType, StructType} @@ -185,4 +186,14 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSparkSession { assert(df.first().getString(0) == "Cazen Lee") assert(df.first().getString(1) == "$10") } + + test("SPARK-49955: null string value does not mean corrupted file") { + val str = 
"{\"name\": \"someone\"}" + val stringDataset = Seq(str, null).toDS() + val df = spark.read.json(stringDataset) + checkAnswer(df, Seq(Row(null, "someone"), Row(null, null))) + + val e = intercept[SparkException](spark.read.option("mode", "failfast").json(stringDataset)) + assert(e.getCause.isInstanceOf[NullPointerException]) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index f13d66b76838f..500c0647bcb2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -708,7 +708,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { val ex = intercept[SparkException] { sql(s"select A from $tableName where A < 0").collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) assert(ex.getCause.isInstanceOf[SparkRuntimeException]) assert(ex.getCause.getMessage.contains( """Found duplicate field(s) "A": [A, a] in case-insensitive mode""")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index 2e6413d998d12..ab0d4d9bc53b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -604,14 +604,14 @@ abstract class OrcQueryTest extends OrcTest { val e1 = intercept[SparkException] { testIgnoreCorruptFiles() } - assert(e1.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(e1.getCondition.startsWith("FAILED_READ_FILE")) assert(e1.getCause.getMessage.contains("Malformed ORC file") || // Hive ORC table scan uses a different code path and has one more error stack e1.getCause.getCause.getMessage.contains("Malformed ORC file")) val e2 = intercept[SparkException] { testIgnoreCorruptFilesWithoutSchemaInfer() } - assert(e2.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(e2.getCondition.startsWith("FAILED_READ_FILE")) assert(e2.getCause.getMessage.contains("Malformed ORC file") || // Hive ORC table scan uses a different code path and has one more error stack e2.getCause.getCause.getMessage.contains("Malformed ORC file")) @@ -625,7 +625,7 @@ abstract class OrcQueryTest extends OrcTest { val e4 = intercept[SparkException] { testAllCorruptFilesWithoutSchemaInfer() } - assert(e4.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(e4.getCondition.startsWith("FAILED_READ_FILE")) assert(e4.getCause.getMessage.contains("Malformed ORC file") || // Hive ORC table scan uses a different code path and has one more error stack e4.getCause.getCause.getMessage.contains("Malformed ORC file")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 9348d10711b35..040999476ece1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -450,8 +450,8 @@ abstract class OrcSuite val ex = intercept[SparkException] { spark.read.orc(basePath).columns.length } - 
assert(ex.getErrorClass == "CANNOT_MERGE_SCHEMAS") - assert(ex.getCause.asInstanceOf[SparkException].getErrorClass === + assert(ex.getCondition == "CANNOT_MERGE_SCHEMAS") + assert(ex.getCause.asInstanceOf[SparkException].getCondition === "CANNOT_MERGE_INCOMPATIBLE_DATA_TYPE") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 5c382b1858716..903dda7f41c0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -1958,7 +1958,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared val ex = intercept[SparkException] { sql(s"select a from $tableName where b > 0").collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) assert(ex.getCause.isInstanceOf[SparkRuntimeException]) assert(ex.getCause.getMessage.contains( """Found duplicate field(s) "B": [B, b] in case-insensitive mode""")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 0afa545595c77..95fb178154929 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -1223,7 +1223,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession val m1 = intercept[SparkException] { spark.range(1).coalesce(1).write.options(extraOptions).parquet(dir.getCanonicalPath) } - assert(m1.getErrorClass == "TASK_WRITE_FAILED") + assert(m1.getCondition == "TASK_WRITE_FAILED") assert(m1.getCause.getMessage.contains("Intentional exception for testing purposes")) } @@ -1233,8 +1233,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession .coalesce(1) df.write.partitionBy("a").options(extraOptions).parquet(dir.getCanonicalPath) } - if (m2.getErrorClass != null) { - assert(m2.getErrorClass == "TASK_WRITE_FAILED") + if (m2.getCondition != null) { + assert(m2.getCondition == "TASK_WRITE_FAILED") assert(m2.getCause.getMessage.contains("Intentional exception for testing purposes")) } else { assert(m2.getMessage.contains("TASK_WRITE_FAILED")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 52d67a0954325..87a2843f34de1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -111,7 +111,7 @@ abstract class ParquetPartitionDiscoverySuite "hdfs://host:9000/path/a=10/b=20", "hdfs://host:9000/path/a=10.5/b=hello") - var exception = intercept[AssertionError] { + var exception = intercept[SparkRuntimeException] { parsePartitions( paths.map(new Path(_)), true, Set.empty[Path], None, true, true, timeZoneId, false) } @@ -173,7 +173,7 @@ abstract class ParquetPartitionDiscoverySuite 
"hdfs://host:9000/path/a=10/b=20", "hdfs://host:9000/path/path1") - exception = intercept[AssertionError] { + exception = intercept[SparkRuntimeException] { parsePartitions( paths.map(new Path(_)), true, @@ -197,7 +197,7 @@ abstract class ParquetPartitionDiscoverySuite "hdfs://host:9000/tmp/tables/nonPartitionedTable1", "hdfs://host:9000/tmp/tables/nonPartitionedTable2") - exception = intercept[AssertionError] { + exception = intercept[SparkRuntimeException] { parsePartitions( paths.map(new Path(_)), true, @@ -878,7 +878,7 @@ abstract class ParquetPartitionDiscoverySuite checkAnswer(twoPartitionsDF, df.filter("b != 3")) - intercept[AssertionError] { + intercept[SparkRuntimeException] { spark .read .parquet( @@ -1181,7 +1181,7 @@ abstract class ParquetPartitionDiscoverySuite spark.read.parquet(dir.toString) } val msg = exception.getMessage - assert(exception.getErrorClass === "CONFLICTING_PARTITION_COLUMN_NAMES") + assert(exception.getCondition === "CONFLICTING_PARTITION_COLUMN_NAMES") // Partitions inside the error message can be presented in any order assert("Partition column name list #[0-1]: col1".r.findFirstIn(msg).isDefined) assert("Partition column name list #[0-1]: col1, col2".r.findFirstIn(msg).isDefined) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 4d413efe50430..22a02447e720f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -1075,7 +1075,7 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS val e = intercept[SparkException] { readParquet("d DECIMAL(3, 2)", path).collect() } - assert(e.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(e.getCondition.startsWith("FAILED_READ_FILE")) assert(e.getCause.getMessage.contains("Please read this column/field as Spark BINARY type")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala index 6d9092391a98e..30503af0fab6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala @@ -414,7 +414,7 @@ abstract class ParquetRebaseDatetimeSuite val e = intercept[SparkException] { df.write.parquet(dir.getCanonicalPath) } - assert(e.getErrorClass == "TASK_WRITE_FAILED") + assert(e.getCondition == "TASK_WRITE_FAILED") val errMsg = e.getCause.asInstanceOf[SparkUpgradeException].getMessage assert(errMsg.contains("You may get a different result due to the upgrading")) } @@ -431,7 +431,7 @@ abstract class ParquetRebaseDatetimeSuite val e = intercept[SparkException] { df.write.parquet(dir.getCanonicalPath) } - assert(e.getErrorClass == "TASK_WRITE_FAILED") + assert(e.getCondition == "TASK_WRITE_FAILED") val errMsg = e.getCause.asInstanceOf[SparkUpgradeException].getMessage assert(errMsg.contains("You may get a different result due to the upgrading")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala index 95378d9467478..08fd8a9ecb53e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala @@ -319,7 +319,7 @@ class ParquetRowIndexSuite extends QueryTest with SharedSparkSession { .load(path.getAbsolutePath) val exception = intercept[SparkException](dfRead.collect()) - assert(exception.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(exception.getCondition.startsWith("FAILED_READ_FILE")) assert(exception.getCause.getMessage.contains( ParquetFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala index 4833b8630134c..59c0af8afd198 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala @@ -90,7 +90,7 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB .option(StateSourceOptions.CHANGE_END_BATCH_ID, 2) .load(tempDir.getAbsolutePath) } - assert(exc.getErrorClass === "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE") + assert(exc.getCondition === "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE") } } @@ -103,7 +103,7 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB .option(StateSourceOptions.CHANGE_END_BATCH_ID, 0) .load(tempDir.getAbsolutePath) } - assert(exc.getErrorClass === "STDS_INVALID_OPTION_VALUE.IS_NEGATIVE") + assert(exc.getCondition === "STDS_INVALID_OPTION_VALUE.IS_NEGATIVE") } } @@ -116,7 +116,7 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB .option(StateSourceOptions.CHANGE_END_BATCH_ID, 0) .load(tempDir.getAbsolutePath) } - assert(exc.getErrorClass === "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE") + assert(exc.getCondition === "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE") } } @@ -130,7 +130,7 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB .option(StateSourceOptions.CHANGE_END_BATCH_ID, 0) .load(tempDir.getAbsolutePath) } - assert(exc.getErrorClass === "STDS_CONFLICT_OPTIONS") + assert(exc.getCondition === "STDS_CONFLICT_OPTIONS") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala index af07707569500..300da03f73e1f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala @@ -287,6 +287,44 @@ class StateDataSourceNegativeTestSuite extends StateDataSourceTestBase { matchPVals = true) } } + + test("ERROR: trying to specify state variable name along with " + + "readRegisteredTimers should fail") { + withTempDir { tempDir => + val exc = intercept[StateDataSourceConflictOptions] { + spark.read.format("statestore") + // trick to bypass getting the last committed batch before validating operator ID + 
.option(StateSourceOptions.BATCH_ID, 0) + .option(StateSourceOptions.STATE_VAR_NAME, "test") + .option(StateSourceOptions.READ_REGISTERED_TIMERS, true) + .load(tempDir.getAbsolutePath) + } + checkError(exc, "STDS_CONFLICT_OPTIONS", "42613", + Map("options" -> + s"['${ + StateSourceOptions.READ_REGISTERED_TIMERS + }', '${StateSourceOptions.STATE_VAR_NAME}']")) + } + } + + test("ERROR: trying to specify non boolean value for " + + "flattenCollectionTypes") { + withTempDir { tempDir => + runDropDuplicatesQuery(tempDir.getAbsolutePath) + + val exc = intercept[StateDataSourceInvalidOptionValue] { + spark.read.format("statestore") + // trick to bypass getting the last committed batch before validating operator ID + .option(StateSourceOptions.BATCH_ID, 0) + .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, "test") + .load(tempDir.getAbsolutePath) + } + checkError(exc, "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", Some("42616"), + Map("optionName" -> StateSourceOptions.FLATTEN_COLLECTION_TYPES, + "message" -> ".*"), + matchPVals = true) + } + } } /** @@ -1099,7 +1137,7 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass val exc = intercept[StateStoreSnapshotPartitionNotFound] { stateDfError.show() } - assert(exc.getErrorClass === "CANNOT_LOAD_STATE_STORE.SNAPSHOT_PARTITION_ID_NOT_FOUND") + assert(exc.getCondition === "CANNOT_LOAD_STATE_STORE.SNAPSHOT_PARTITION_ID_NOT_FOUND") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala index 61091fde35e79..84c6eb54681a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala @@ -21,9 +21,9 @@ import java.time.Duration import org.apache.spark.sql.{Encoders, Row} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, TestClass} -import org.apache.spark.sql.functions.explode +import org.apache.spark.sql.functions.{explode, timestamp_seconds} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{ExpiredTimerInfo, InputMapRow, ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, OutputMode, RunningCountStatefulProcessor, StatefulProcessor, StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState} +import org.apache.spark.sql.streaming.{ExpiredTimerInfo, InputMapRow, ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, MaxEventTimeStatefulProcessor, OutputMode, RunningCountStatefulProcessor, RunningCountStatefulProcessorWithProcTimeTimerUpdates, StatefulProcessor, StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState} import org.apache.spark.sql.streaming.util.StreamManualClock /** Stateful processor of single value state var with non-primitive type */ @@ -159,7 +159,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest val resultDf = stateReaderDf.selectExpr( "key.value AS groupingKey", - "single_value.id AS valueId", "single_value.name AS valueName", + "value.id AS 
valueId", "value.name AS valueName", "partition_id") checkAnswer(resultDf, @@ -176,17 +176,59 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest assert(ex.isInstanceOf[StateDataSourceInvalidOptionValue]) assert(ex.getMessage.contains("State variable non-exist is not defined")) - // TODO: this should be removed when readChangeFeed is supported for value state + // Verify that trying to read timers in TimeMode as None fails val ex1 = intercept[Exception] { spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.READ_REGISTERED_TIMERS, true) + .load() + } + assert(ex1.isInstanceOf[StateDataSourceInvalidOptionValue]) + assert(ex1.getMessage.contains("Registered timers are not available")) + } + } + } + + testWithChangelogCheckpointingEnabled("state data source cdf integration - " + + "value state with single variable") { + withTempDir { tempDir => + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new StatefulProcessorWithSingleValueVar(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + AddData(inputData, "b"), + CheckNewAnswer(("b", "1")), + StopStream + ) + + val changeFeedDf = spark.read .format("statestore") .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) .option(StateSourceOptions.STATE_VAR_NAME, "valueState") - .option(StateSourceOptions.READ_CHANGE_FEED, "true") + .option(StateSourceOptions.READ_CHANGE_FEED, true) .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) .load() - } - assert(ex1.isInstanceOf[StateDataSourceConflictOptions]) + + val opDf = changeFeedDf.selectExpr( + "change_type", + "key.value AS groupingKey", + "value.id AS valueId", "value.name AS valueName", + "partition_id") + + checkAnswer(opDf, + Seq(Row("update", "a", 1L, "dummyKey", 0), Row("update", "b", 1L, "dummyKey", 1))) } } } @@ -222,7 +264,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest .load() val resultDf = stateReaderDf.selectExpr( - "key.value", "single_value.value", "single_value.ttlExpirationMs", "partition_id") + "key.value", "value.value", "value.ttlExpirationMs", "partition_id") var count = 0L resultDf.collect().foreach { row => @@ -235,7 +277,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest val answerDf = stateReaderDf.selectExpr( "key.value AS groupingKey", - "single_value.value.value AS valueId", "partition_id") + "value.value.value AS valueId", "partition_id") checkAnswer(answerDf, Seq(Row("a", 1L, 0), Row("b", 1L, 1))) @@ -249,19 +291,61 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest } assert(ex.isInstanceOf[StateDataSourceInvalidOptionValue]) assert(ex.getMessage.contains("State variable non-exist is not defined")) + } + } + } - // TODO: this should be removed when readChangeFeed is supported for TTL based state - // variables - val ex1 = intercept[Exception] { - spark.read - .format("statestore") - .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) - .option(StateSourceOptions.STATE_VAR_NAME, "countState") - .option(StateSourceOptions.READ_CHANGE_FEED, "true") - 
.option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) - .load() + testWithChangelogCheckpointingEnabled("state data source cdf integration - " + + "value state with single variable and TTL") { + withTempDir { tempDir => + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new StatefulProcessorWithTTL(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, "a"), + AddData(inputData, "b"), + Execute { _ => + // wait for the batch to run since we are using processing time + Thread.sleep(5000) + }, + StopStream + ) + + val stateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "countState") + .option(StateSourceOptions.READ_CHANGE_FEED, true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .load() + + val resultDf = stateReaderDf.selectExpr( + "key.value", "value.value", "value.ttlExpirationMs", "partition_id") + + var count = 0L + resultDf.collect().foreach { row => + count = count + 1 + assert(row.getLong(2) > 0) } - assert(ex1.isInstanceOf[StateDataSourceConflictOptions]) + + // verify that 2 state rows are present + assert(count === 2) + + val answerDf = stateReaderDf.selectExpr( + "change_type", + "key.value AS groupingKey", + "value.value.value AS valueId", "partition_id") + checkAnswer(answerDf, + Seq(Row("update", "a", 1L, 0), Row("update", "b", 1L, 1))) } } } @@ -290,10 +374,12 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest StopStream ) + // Verify that the state can be read in flattened/non-flattened modes val stateReaderDf = spark.read .format("statestore") .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) .option(StateSourceOptions.STATE_VAR_NAME, "groupsList") + .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, false) .load() val listStateDf = stateReaderDf @@ -307,6 +393,19 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest checkAnswer(listStateDf, Seq(Row("session1", "group1"), Row("session1", "group2"), Row("session1", "group4"), Row("session2", "group1"), Row("session3", "group7"))) + + val flattenedReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "groupsList") + .load() + + val resultDf = flattenedReaderDf.selectExpr( + "key.value AS groupingKey", + "list_element.value AS valueList") + checkAnswer(resultDf, + Seq(Row("session1", "group1"), Row("session1", "group2"), Row("session1", "group4"), + Row("session2", "group1"), Row("session3", "group7"))) } } } @@ -338,10 +437,12 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest StopStream ) + // Verify that the state can be read in flattened/non-flattened modes val stateReaderDf = spark.read .format("statestore") .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) .option(StateSourceOptions.STATE_VAR_NAME, "groupsListWithTTL") + .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, false) .load() val listStateDf = stateReaderDf @@ -368,6 +469,31 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest 
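The list- and map-state tests in this suite now read each state variable twice: once with FLATTEN_COLLECTION_TYPES set to false, which returns one row per grouping key with the whole collection nested in a single column, and once without the option, which returns one row per element through the list_element / user_map_key / user_map_value columns. A minimal sketch of the two read shapes, using only the option and column names exercised by these tests (the checkpoint path variable is illustrative):

  // Nested layout: one row per grouping key, the list arrives as an array column.
  val nestedDf = spark.read
    .format("statestore")
    .option(StateSourceOptions.PATH, checkpointDir)
    .option(StateSourceOptions.STATE_VAR_NAME, "groupsList")
    .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, false)
    .load()

  // Flattened layout (the behavior exercised when the option is omitted): one row per element.
  val flattenedDf = spark.read
    .format("statestore")
    .option(StateSourceOptions.PATH, checkpointDir)
    .option(StateSourceOptions.STATE_VAR_NAME, "groupsList")
    .load()
    .selectExpr("key.value AS groupingKey", "list_element.value AS groupId")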
checkAnswer(valuesDf, Seq(Row("session1", "group1"), Row("session1", "group2"), Row("session1", "group4"), Row("session2", "group1"), Row("session3", "group7"))) + + val flattenedStateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "groupsListWithTTL") + .load() + + val flattenedResultDf = flattenedStateReaderDf + .selectExpr("list_element.ttlExpirationMs AS ttlExpirationMs") + var flattenedCount = 0L + flattenedResultDf.collect().foreach { row => + flattenedCount = flattenedCount + 1 + assert(row.getLong(0) > 0) + } + + // verify that 5 state rows are present + assert(flattenedCount === 5) + + val outputDf = flattenedStateReaderDf + .selectExpr("key.value AS groupingKey", + "list_element.value.value AS groupId") + + checkAnswer(outputDf, + Seq(Row("session1", "group1"), Row("session1", "group2"), Row("session1", "group4"), + Row("session2", "group1"), Row("session3", "group7"))) } } } @@ -397,10 +523,12 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest StopStream ) + // Verify that the state can be read in flattened/non-flattened modes val stateReaderDf = spark.read .format("statestore") .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) .option(StateSourceOptions.STATE_VAR_NAME, "sessionState") + .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, false) .load() val resultDf = stateReaderDf.selectExpr( @@ -413,6 +541,24 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest Row("k2", Map(Row("v2") -> Row("3")))) ) + + val flattenedStateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "sessionState") + .load() + + val outputDf = flattenedStateReaderDf + .selectExpr("key.value AS groupingKey", + "user_map_key.value AS mapKey", + "user_map_value.value AS mapValue") + + checkAnswer(outputDf, + Seq( + Row("k1", "v1", "10"), + Row("k1", "v2", "5"), + Row("k2", "v2", "3")) + ) } } } @@ -463,10 +609,12 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest StopStream ) + // Verify that the state can be read in flattened/non-flattened modes val stateReaderDf = spark.read .format("statestore") .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) .option(StateSourceOptions.STATE_VAR_NAME, "mapState") + .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, false) .load() val resultDf = stateReaderDf.selectExpr( @@ -478,6 +626,114 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest Map(Row("key2") -> Row(Row(2), 61000L), Row("key1") -> Row(Row(1), 61000L)))) ) + + val flattenedStateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "mapState") + .load() + + val outputDf = flattenedStateReaderDf + .selectExpr("key.value AS groupingKey", + "user_map_key.value AS mapKey", + "user_map_value.value.value AS mapValue", + "user_map_value.ttlExpirationMs AS ttlTimestamp") + + checkAnswer(outputDf, + Seq( + Row("k1", "key1", 1, 61000L), + Row("k1", "key2", 2, 61000L)) + ) + } + } + } + + test("state data source - processing-time timers integration") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { tempDir => + val clock = new 
StreamManualClock + + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState( + new RunningCountStatefulProcessorWithProcTimeTimerUpdates(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock, + checkpointLocation = tempDir.getCanonicalPath), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")), // at batch 0, ts = 1, timer = "a" -> [6] (= 1 + 5) + AddData(inputData, "a"), + AdvanceManualClock(2 * 1000), + CheckNewAnswer(("a", "2")), // at batch 1, ts = 3, timer = "a" -> [10.5] (3 + 7.5) + StopStream) + + val stateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.READ_REGISTERED_TIMERS, true) + .load() + + val resultDf = stateReaderDf.selectExpr( + "key.value AS groupingKey", + "expiration_timestamp_ms AS expiryTimestamp", + "partition_id") + + checkAnswer(resultDf, + Seq(Row("a", 10500L, 0))) + } + } + } + + test("state data source - event-time timers integration") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { tempDir => + val inputData = MemoryStream[(String, Int)] + val result = + inputData.toDS() + .select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime")) + .withWatermark("eventTime", "10 seconds") + .as[(String, Long)] + .groupByKey(_._1) + .transformWithState( + new MaxEventTimeStatefulProcessor(), + TimeMode.EventTime(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = tempDir.getCanonicalPath), + + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. + CheckNewAnswer(("a", 15)), // Output = max event time of a + + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" + CheckNewAnswer(), // No output as data should get filtered by watermark + StopStream) + + val stateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.READ_REGISTERED_TIMERS, true) + .load() + + val resultDf = stateReaderDf.selectExpr( + "key.value AS groupingKey", + "expiration_timestamp_ms AS expiryTimestamp", + "partition_id") + + checkAnswer(resultDf, + Seq(Row("a", 20000L, 0))) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SingleJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SingleJoinSuite.scala new file mode 100644 index 0000000000000..a318769af6871 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SingleJoinSuite.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.SparkRuntimeException +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.BuildRight +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, Project} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} + +class SingleJoinSuite extends SparkPlanTest with SharedSparkSession { + import testImplicits.toRichColumn + + private val EnsureRequirements = new EnsureRequirements() + + private lazy val left = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(1, 2.0), + Row(2, 1.0), + Row(2, 1.0), + Row(3, 3.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + // (a > c && a != 6) + + private lazy val right = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(4, 2.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val singleConditionEQ = EqualTo(left.col("a").expr, right.col("c").expr) + + private lazy val nonEqualityCond = And(GreaterThan(left.col("a").expr, right.col("c").expr), + Not(EqualTo(left.col("a").expr, Literal(6)))) + + + + private def testSingleJoin( + testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: => Option[Expression], + expectedAnswer: Seq[Row], + expectError: Boolean = false): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, + Inner, condition, JoinHint.NONE) + ExtractEquiJoinKeys.unapply(join) + } + + def checkSingleJoinError(planFunction: (SparkPlan, SparkPlan) => SparkPlan): Unit = { + val outputPlan = planFunction(leftRows.queryExecution.sparkPlan, + rightRows.queryExecution.sparkPlan) + checkError( + exception = intercept[SparkRuntimeException] { + SparkPlanTest.executePlan(outputPlan, spark.sqlContext) + }, + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS", + parameters = Map.empty + ) + } + + testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin") { _ => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _, _) => + val planFunction = (left: SparkPlan, right: SparkPlan) => + EnsureRequirements.apply(BroadcastHashJoinExec( + leftKeys, rightKeys, LeftSingle, BuildRight, boundCondition, left, right)) + if (expectError) { + checkSingleJoinError(planFunction) + } else { + checkAnswer2(leftRows, rightRows, planFunction, + expectedAnswer, + sortAnswers = true) + } + } + } + testWithWholeStageCodegenOnAndOff(s"$testName using ShuffledHashJoin") { _ => + 
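// LeftSingle behaves like a left outer join that admits at most one match per left row:
// unmatched left rows are padded with nulls on the right, and a left row with more than
// one match is expected to fail with SCALAR_SUBQUERY_TOO_MANY_ROWS (the expectError cases
// handled by checkSingleJoinError above). The same expectations are exercised against
// BroadcastHashJoin, ShuffledHashJoin and BroadcastNestedLoopJoin.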
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _, _) => + val planFunction = (left: SparkPlan, right: SparkPlan) => + EnsureRequirements.apply( + ShuffledHashJoinExec( + leftKeys, rightKeys, LeftSingle, BuildRight, boundCondition, left, right)) + if (expectError) { + checkSingleJoinError(planFunction) + } else { + checkAnswer2(leftRows, rightRows, planFunction, + expectedAnswer, + sortAnswers = true) + } + } + } + + testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastNestedLoopJoin") { _ => + val planFunction = (left: SparkPlan, right: SparkPlan) => + EnsureRequirements.apply( + BroadcastNestedLoopJoinExec(left, right, BuildRight, LeftSingle, condition)) + if (expectError) { + checkSingleJoinError(planFunction) + } else { + checkAnswer2(leftRows, rightRows, planFunction, + expectedAnswer, + sortAnswers = true) + } + } + } + + testSingleJoin( + "test single condition (equal) for a left single join", + left, + Project(Seq(right.col("c").expr.asInstanceOf[NamedExpression]), right.logicalPlan), + Some(singleConditionEQ), + Seq(Row(1, 2.0, null), + Row(1, 2.0, null), + Row(2, 1.0, 2), + Row(2, 1.0, 2), + Row(3, 3.0, 3), + Row(6, null, 6), + Row(null, 5.0, null), + Row(null, null, null))) + + testSingleJoin( + "test single condition (equal) for a left single join -- multiple matches", + left, + Project(Seq(right.col("d").expr.asInstanceOf[NamedExpression]), right.logicalPlan), + Some(EqualTo(left.col("b").expr, right.col("d").expr)), + Seq.empty, true) + + testSingleJoin( + "test non-equality for a left single join", + left, + Project(Seq(right.col("c").expr.asInstanceOf[NamedExpression]), right.logicalPlan), + Some(nonEqualityCond), + Seq(Row(1, 2.0, null), + Row(1, 2.0, null), + Row(2, 1.0, null), + Row(2, 1.0, null), + Row(3, 3.0, 2), + Row(6, null, null), + Row(null, 5.0, null), + Row(null, null, null))) + + testSingleJoin( + "test non-equality for a left single join -- multiple matches", + left, + Project(Seq(right.col("c").expr.asInstanceOf[NamedExpression]), right.logicalPlan), + Some(GreaterThan(left.col("a").expr, right.col("c").expr)), + Seq.empty, expectError = true) + + private lazy val emptyFrame = spark.createDataFrame( + spark.sparkContext.emptyRDD[Row], new StructType().add("c", IntegerType).add("d", DoubleType)) + + testSingleJoin( + "empty inner (right) side", + left, + Project(Seq(emptyFrame.col("c").expr.asInstanceOf[NamedExpression]), emptyFrame.logicalPlan), + Some(GreaterThan(left.col("a").expr, emptyFrame.col("c").expr)), + Seq(Row(1, 2.0, null), + Row(1, 2.0, null), + Row(2, 1.0, null), + Row(2, 1.0, null), + Row(3, 3.0, null), + Row(6, null, null), + Row(null, 5.0, null), + Row(null, null, null))) + + testSingleJoin( + "empty outer (left) side", + Project(Seq(emptyFrame.col("c").expr.asInstanceOf[NamedExpression]), emptyFrame.logicalPlan), + right, + Some(EqualTo(emptyFrame.col("c").expr, right.col("c").expr)), + Seq.empty) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala index dcebece29037f..1f2be12058eb7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala @@ -330,7 +330,7 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase { val err = intercept[AnalysisException] { 
spark.read.format(dataSourceName).schema(schema).load().collect() } - assert(err.getErrorClass == "PYTHON_DATA_SOURCE_ERROR") + assert(err.getCondition == "PYTHON_DATA_SOURCE_ERROR") assert(err.getMessage.contains("PySparkNotImplementedError")) } @@ -350,7 +350,7 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase { val err = intercept[AnalysisException] { spark.read.format(dataSourceName).schema(schema).load().collect() } - assert(err.getErrorClass == "PYTHON_DATA_SOURCE_ERROR") + assert(err.getCondition == "PYTHON_DATA_SOURCE_ERROR") assert(err.getMessage.contains("error creating reader")) } @@ -369,7 +369,7 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase { val err = intercept[AnalysisException] { spark.read.format(dataSourceName).schema(schema).load().collect() } - assert(err.getErrorClass == "PYTHON_DATA_SOURCE_ERROR") + assert(err.getCondition == "PYTHON_DATA_SOURCE_ERROR") assert(err.getMessage.contains("DATA_SOURCE_TYPE_MISMATCH")) assert(err.getMessage.contains("PySparkAssertionError")) } @@ -480,7 +480,7 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase { spark.dataSource.registerPython(dataSourceName, dataSource) val err = intercept[AnalysisException]( spark.read.format(dataSourceName).load().collect()) - assert(err.getErrorClass == "PYTHON_DATA_SOURCE_ERROR") + assert(err.getCondition == "PYTHON_DATA_SOURCE_ERROR") assert(err.getMessage.contains("partitions")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala index 3a8ce569d1ba9..a2d3318361837 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala @@ -99,7 +99,7 @@ class PythonForeachWriterSuite extends SparkFunSuite with Eventually with Mockit } private val iterator = buffer.iterator private val outputBuffer = new ArrayBuffer[Int] - private val testTimeout = timeout(20.seconds) + private val testTimeout = timeout(60.seconds) private val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) private val thread = new Thread() { override def run(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala index 8d0e1c5f578fa..3d91a045907fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala @@ -574,7 +574,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase { val q = spark.readStream.format(dataSourceName).load().writeStream.format("console").start() q.awaitTermination() } - assert(err.getErrorClass == "STREAM_FAILED") + assert(err.getCondition == "STREAM_FAILED") assert(err.getMessage.contains("error creating stream reader")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala index 615e1e89f30b8..776772bb51ca8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala @@ -32,32 +32,59 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.execution.streaming.{StatefulProcessorHandleImpl, StatefulProcessorHandleState} import org.apache.spark.sql.execution.streaming.state.StateMessage -import org.apache.spark.sql.execution.streaming.state.StateMessage.{Clear, Exists, Get, HandleState, SetHandleState, StateCallCommand, StatefulProcessorCall, ValueStateCall, ValueStateUpdate} -import org.apache.spark.sql.streaming.{TTLConfig, ValueState} +import org.apache.spark.sql.execution.streaming.state.StateMessage.{AppendList, AppendValue, Clear, Exists, Get, HandleState, ListStateCall, ListStateGet, ListStatePut, SetHandleState, StateCallCommand, StatefulProcessorCall, ValueStateCall, ValueStateUpdate} +import org.apache.spark.sql.streaming.{ListState, TTLConfig, ValueState} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with BeforeAndAfterEach { - val valueStateName = "test" - var statefulProcessorHandle: StatefulProcessorHandleImpl = _ + val stateName = "test" + val iteratorId = "testId" + val serverSocket: ServerSocket = mock(classOf[ServerSocket]) + val groupingKeySchema: StructType = StructType(Seq()) + val stateSchema: StructType = StructType(Array(StructField("value", IntegerType))) + // Below byte array is a serialized row with a single integer value 1. + val byteArray: Array[Byte] = Array(0x80.toByte, 0x05.toByte, 0x95.toByte, 0x05.toByte, + 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, + 'K'.toByte, 0x01.toByte, 0x85.toByte, 0x94.toByte, '.'.toByte + ) + + var statefulProcessorHandle: StatefulProcessorHandleImpl = + mock(classOf[StatefulProcessorHandleImpl]) var outputStream: DataOutputStream = _ var valueState: ValueState[Row] = _ + var listState: ListState[Row] = _ var stateServer: TransformWithStateInPandasStateServer = _ - var valueSchema: StructType = _ - var valueDeserializer: ExpressionEncoder.Deserializer[Row] = _ + var stateDeserializer: ExpressionEncoder.Deserializer[Row] = _ + var stateSerializer: ExpressionEncoder.Serializer[Row] = _ + var transformWithStateInPandasDeserializer: TransformWithStateInPandasDeserializer = _ + var arrowStreamWriter: BaseStreamingArrowWriter = _ + var valueStateMap: mutable.HashMap[String, ValueStateInfo] = mutable.HashMap() + var listStateMap: mutable.HashMap[String, ListStateInfo] = mutable.HashMap() override def beforeEach(): Unit = { - val serverSocket = mock(classOf[ServerSocket]) statefulProcessorHandle = mock(classOf[StatefulProcessorHandleImpl]) - val groupingKeySchema = StructType(Seq()) outputStream = mock(classOf[DataOutputStream]) valueState = mock(classOf[ValueState[Row]]) - valueSchema = StructType(Array(StructField("value", IntegerType))) - valueDeserializer = ExpressionEncoder(valueSchema).resolveAndBind().createDeserializer() - val valueStateMap = mutable.HashMap[String, - (ValueState[Row], StructType, ExpressionEncoder.Deserializer[Row])](valueStateName -> - (valueState, valueSchema, valueDeserializer)) + listState = mock(classOf[ListState[Row]]) + stateDeserializer = ExpressionEncoder(stateSchema).resolveAndBind().createDeserializer() + stateSerializer = ExpressionEncoder(stateSchema).resolveAndBind().createSerializer() + valueStateMap = 
mutable.HashMap[String, ValueStateInfo](stateName -> + ValueStateInfo(valueState, stateSchema, stateDeserializer)) + listStateMap = mutable.HashMap[String, ListStateInfo](stateName -> + ListStateInfo(listState, stateSchema, stateDeserializer, stateSerializer)) + // Iterator map for list state. Please note that `handleImplicitGroupingKeyRequest` would + // reset the iterator map to empty so be careful to call it if you want to access the iterator + // map later. + val listStateIteratorMap = mutable.HashMap[String, Iterator[Row]](iteratorId -> + Iterator(new GenericRowWithSchema(Array(1), stateSchema))) + transformWithStateInPandasDeserializer = mock(classOf[TransformWithStateInPandasDeserializer]) + arrowStreamWriter = mock(classOf[BaseStreamingArrowWriter]) stateServer = new TransformWithStateInPandasStateServer(serverSocket, - statefulProcessorHandle, groupingKeySchema, outputStream, valueStateMap) + statefulProcessorHandle, groupingKeySchema, "", false, false, 2, + outputStream, valueStateMap, transformWithStateInPandasDeserializer, arrowStreamWriter, + listStateMap, listStateIteratorMap) + when(transformWithStateInPandasDeserializer.readArrowBatches(any)) + .thenReturn(Seq(new GenericRowWithSchema(Array(1), stateSchema))) } test("set handle state") { @@ -91,15 +118,38 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo } } + Seq(true, false).foreach { useTTL => + test(s"get list state, useTTL=$useTTL") { + val stateCallCommandBuilder = StateCallCommand.newBuilder() + .setStateName("newName") + .setSchema("StructType(List(StructField(value,IntegerType,true)))") + if (useTTL) { + stateCallCommandBuilder.setTtl(StateMessage.TTLConfig.newBuilder().setDurationMs(1000)) + } + val message = StatefulProcessorCall + .newBuilder() + .setGetListState(stateCallCommandBuilder.build()) + .build() + stateServer.handleStatefulProcessorCall(message) + if (useTTL) { + verify(statefulProcessorHandle) + .getListState[Row](any[String], any[Encoder[Row]], any[TTLConfig]) + } else { + verify(statefulProcessorHandle).getListState[Row](any[String], any[Encoder[Row]]) + } + verify(outputStream).writeInt(0) + } + } + test("value state exists") { - val message = ValueStateCall.newBuilder().setStateName(valueStateName) + val message = ValueStateCall.newBuilder().setStateName(stateName) .setExists(Exists.newBuilder().build()).build() stateServer.handleValueStateRequest(message) verify(valueState).exists() } test("value state get") { - val message = ValueStateCall.newBuilder().setStateName(valueStateName) + val message = ValueStateCall.newBuilder().setStateName(stateName) .setGet(Get.newBuilder().build()).build() val schema = new StructType().add("value", "int") when(valueState.getOption()).thenReturn(Some(new GenericRowWithSchema(Array(1), schema))) @@ -109,7 +159,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo } test("value state get - not exist") { - val message = ValueStateCall.newBuilder().setStateName(valueStateName) + val message = ValueStateCall.newBuilder().setStateName(stateName) .setGet(Get.newBuilder().build()).build() when(valueState.getOption()).thenReturn(None) stateServer.handleValueStateRequest(message) @@ -127,7 +177,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo } test("value state clear") { - val message = ValueStateCall.newBuilder().setStateName(valueStateName) + val message = ValueStateCall.newBuilder().setStateName(stateName) .setClear(Clear.newBuilder().build()).build() 
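// Every request in these server tests follows the same round trip: build a protobuf message
// (StatefulProcessorCall / ValueStateCall / ListStateCall), dispatch it through the matching
// handle*Request method, and verify the effect on the mocked state handle plus, where the
// call succeeds, the 0 status written to the mocked output stream.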
stateServer.handleValueStateRequest(message) verify(valueState).clear() @@ -135,16 +185,98 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo } test("value state update") { - // Below byte array is a serialized row with a single integer value 1. - val byteArray: Array[Byte] = Array(0x80.toByte, 0x05.toByte, 0x95.toByte, 0x05.toByte, - 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, - 'K'.toByte, 0x01.toByte, 0x85.toByte, 0x94.toByte, '.'.toByte - ) val byteString: ByteString = ByteString.copyFrom(byteArray) - val message = ValueStateCall.newBuilder().setStateName(valueStateName) + val message = ValueStateCall.newBuilder().setStateName(stateName) .setValueStateUpdate(ValueStateUpdate.newBuilder().setValue(byteString).build()).build() stateServer.handleValueStateRequest(message) verify(valueState).update(any[Row]) verify(outputStream).writeInt(0) } + + test("list state exists") { + val message = ListStateCall.newBuilder().setStateName(stateName) + .setExists(Exists.newBuilder().build()).build() + stateServer.handleListStateRequest(message) + verify(listState).exists() + } + + test("list state get - iterator in map") { + val message = ListStateCall.newBuilder().setStateName(stateName) + .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build() + stateServer.handleListStateRequest(message) + verify(listState, times(0)).get() + verify(arrowStreamWriter).writeRow(any) + verify(arrowStreamWriter).finalizeCurrentArrowBatch() + } + + test("list state get - iterator in map with multiple batches") { + val maxRecordsPerBatch = 2 + val message = ListStateCall.newBuilder().setStateName(stateName) + .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build() + val iteratorMap = mutable.HashMap[String, Iterator[Row]](iteratorId -> + Iterator(new GenericRowWithSchema(Array(1), stateSchema), + new GenericRowWithSchema(Array(2), stateSchema), + new GenericRowWithSchema(Array(3), stateSchema), + new GenericRowWithSchema(Array(4), stateSchema))) + stateServer = new TransformWithStateInPandasStateServer(serverSocket, + statefulProcessorHandle, groupingKeySchema, "", false, false, + maxRecordsPerBatch, outputStream, valueStateMap, + transformWithStateInPandasDeserializer, arrowStreamWriter, listStateMap, iteratorMap) + // First call should send 2 records. + stateServer.handleListStateRequest(message) + verify(listState, times(0)).get() + verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any) + verify(arrowStreamWriter).finalizeCurrentArrowBatch() + // Second call should send the remaining 2 records. + stateServer.handleListStateRequest(message) + verify(listState, times(0)).get() + // Since Mockito's verify counts the total number of calls, the expected number of writeRow call + // should be 2 * maxRecordsPerBatch. 
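// Concretely: maxRecordsPerBatch = 2 and handleListStateRequest has now run twice, so the
// cumulative expectation is 4 writeRow calls and 2 finalizeCurrentArrowBatch calls.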
+ verify(arrowStreamWriter, times(2 * maxRecordsPerBatch)).writeRow(any) + verify(arrowStreamWriter, times(2)).finalizeCurrentArrowBatch() + } + + test("list state get - iterator not in map") { + val maxRecordsPerBatch = 2 + val message = ListStateCall.newBuilder().setStateName(stateName) + .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build() + val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap() + stateServer = new TransformWithStateInPandasStateServer(serverSocket, + statefulProcessorHandle, groupingKeySchema, "", false, false, + maxRecordsPerBatch, outputStream, valueStateMap, + transformWithStateInPandasDeserializer, arrowStreamWriter, listStateMap, iteratorMap) + when(listState.get()).thenReturn(Iterator(new GenericRowWithSchema(Array(1), stateSchema), + new GenericRowWithSchema(Array(2), stateSchema), + new GenericRowWithSchema(Array(3), stateSchema))) + stateServer.handleListStateRequest(message) + verify(listState).get() + // Verify that only maxRecordsPerBatch (2) rows are written to the output stream while still + // having 1 row left in the iterator. + verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any) + verify(arrowStreamWriter).finalizeCurrentArrowBatch() + } + + test("list state put") { + val message = ListStateCall.newBuilder().setStateName(stateName) + .setListStatePut(ListStatePut.newBuilder().build()).build() + stateServer.handleListStateRequest(message) + verify(transformWithStateInPandasDeserializer).readArrowBatches(any) + verify(listState).put(any) + } + + test("list state append value") { + val byteString: ByteString = ByteString.copyFrom(byteArray) + val message = ListStateCall.newBuilder().setStateName(stateName) + .setAppendValue(AppendValue.newBuilder().setValue(byteString).build()).build() + stateServer.handleListStateRequest(message) + verify(listState).appendValue(any[Row]) + } + + test("list state append list") { + val message = ListStateCall.newBuilder().setStateName(stateName) + .setAppendList(AppendList.newBuilder().build()).build() + stateServer.handleListStateRequest(message) + verify(transformWithStateInPandasDeserializer).readArrowBatches(any) + verify(listState).appendList(any) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala index 97b95eb402b7e..b5f23853fd5b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala @@ -102,7 +102,6 @@ class ConsoleWriteSupportSuite extends StreamTest { || 2| |+-----+ |only showing top 2 rows - | |""".stripMargin) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index 691f18451af22..9fcd2001cce50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -526,12 +526,12 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared val conf = RocksDBConf().copy(compression = "zstd") withDB(remoteDir, conf = conf, useColumnFamilies = colFamiliesEnabled) { db => - 
assert(db.columnFamilyOptions.compressionType() == CompressionType.ZSTD_COMPRESSION) + assert(db.rocksDbOptions.compressionType() == CompressionType.ZSTD_COMPRESSION) } // Test the default is LZ4 withDB(remoteDir, conf = RocksDBConf().copy(), useColumnFamilies = colFamiliesEnabled) { db => - assert(db.columnFamilyOptions.compressionType() == CompressionType.LZ4_COMPRESSION) + assert(db.rocksDbOptions.compressionType() == CompressionType.LZ4_COMPRESSION) } } @@ -811,6 +811,47 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared } } + testWithChangelogCheckpointingEnabled("RocksDB: ensure that changelog files are written " + + "and snapshots uploaded optionally with changelog format v2") { + withTempDir { dir => + val remoteDir = Utils.createTempDir().toString + val conf = dbConf.copy(minDeltasForSnapshot = 5, compactOnCommit = false) + new File(remoteDir).delete() // to make sure that the directory gets created + withDB(remoteDir, conf = conf, useColumnFamilies = true) { db => + db.createColFamilyIfAbsent("test") + db.load(0) + db.put("a", "1") + db.put("b", "2") + db.commit() + assert(changelogVersionsPresent(remoteDir) == Seq(1)) + assert(snapshotVersionsPresent(remoteDir) == Seq(1)) + + db.load(1) + db.put("a", "3") + db.put("c", "4") + db.commit() + + assert(changelogVersionsPresent(remoteDir) == Seq(1, 2)) + assert(snapshotVersionsPresent(remoteDir) == Seq(1)) + + db.removeColFamilyIfExists("test") + db.load(2) + db.remove("a") + db.put("d", "5") + db.commit() + assert(changelogVersionsPresent(remoteDir) == Seq(1, 2, 3)) + assert(snapshotVersionsPresent(remoteDir) == Seq(1, 3)) + + db.load(3) + db.put("e", "6") + db.remove("b") + db.commit() + assert(changelogVersionsPresent(remoteDir) == Seq(1, 2, 3, 4)) + assert(snapshotVersionsPresent(remoteDir) == Seq(1, 3)) + } + } + } + test("RocksDB: ensure merge operation correctness") { withTempDir { dir => val remoteDir = Utils.createTempDir().toString diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala index 38533825ece90..99483bc0ee8dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala @@ -423,14 +423,14 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { // collation checks are also performed in this path. so we need to check for them explicitly. 
if (keyCollationChecks) { assert(ex.getMessage.contains("Binary inequality column is not supported")) - assert(ex.getErrorClass === "STATE_STORE_UNSUPPORTED_OPERATION_BINARY_INEQUALITY") + assert(ex.getCondition === "STATE_STORE_UNSUPPORTED_OPERATION_BINARY_INEQUALITY") } else { if (ignoreValueSchema) { // if value schema is ignored, the mismatch has to be on the key schema - assert(ex.getErrorClass === "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE") + assert(ex.getCondition === "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE") } else { - assert(ex.getErrorClass === "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE" || - ex.getErrorClass === "STATE_STORE_VALUE_SCHEMA_NOT_COMPATIBLE") + assert(ex.getCondition === "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE" || + ex.getCondition === "STATE_STORE_VALUE_SCHEMA_NOT_COMPATIBLE") } assert(ex.getMessage.contains("does not match existing")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 8bbc7a31760d9..2a9944a81cb2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -1373,7 +1373,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] put(store, "a", 0, 0) val e = intercept[SparkException](quietly { store.commit() } ) - assert(e.getErrorClass == "CANNOT_WRITE_STATE_STORE.CANNOT_COMMIT") + assert(e.getCondition == "CANNOT_WRITE_STATE_STORE.CANNOT_COMMIT") if (store.getClass.getName contains ROCKSDB_STATE_STORE) { assert(e.getMessage contains "RocksDBStateStore[id=(op=0,part=0)") } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala index c993aa8e52031..76fcdfc380950 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala @@ -324,7 +324,7 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { a.asInstanceOf[AgnosticEncoder[Any]] test("udf") { - val int2LongSum = new aggregate.TypedSumLong[Int]((i: Int) => i.toLong) + val int2LongSum = new TypedSumLong[Int]((i: Int) => i.toLong) val bufferEncoder = encoderFor(int2LongSum.bufferEncoder) val outputEncoder = encoderFor(int2LongSum.outputEncoder) val bufferAttrs = bufferEncoder.namedExpressions.map { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 82795e551b6bf..2b58440baf852 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -233,8 +233,8 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { // static sql configs checkError( exception = intercept[AnalysisException](sql(s"RESET ${StaticSQLConf.WAREHOUSE_PATH.key}")), - condition = "_LEGACY_ERROR_TEMP_1325", - parameters = Map("key" -> "spark.sql.warehouse.dir")) + condition = "CANNOT_MODIFY_CONFIG", + parameters = Map("key" -> "\"spark.sql.warehouse.dir\"", "docroot" -> SPARK_DOC_ROOT)) } @@ -315,10 +315,16 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { } 
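As elsewhere in this diff, the assertions in this suite move from message-substring checks and legacy error classes to named error conditions: `getErrorClass` becomes `getCondition`, and `checkError` pins both the condition and its message parameters. A condensed before/after of the static-config case changed in the hunk that follows:

  // Before: match on a fragment of the error message.
  val e1 = intercept[AnalysisException](sql(s"SET ${GLOBAL_TEMP_DATABASE.key}=10"))
  assert(e1.message.contains("Cannot modify the value of a static config"))

  // After: assert the named condition and its parameters (getCondition supersedes getErrorClass).
  checkError(
    exception = intercept[AnalysisException](sql(s"SET ${GLOBAL_TEMP_DATABASE.key}=10")),
    condition = "CANNOT_MODIFY_CONFIG",
    parameters = Map("key" -> "\"spark.sql.globalTempDatabase\"", "docroot" -> SPARK_DOC_ROOT))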
test("cannot set/unset static SQL conf") { - val e1 = intercept[AnalysisException](sql(s"SET ${GLOBAL_TEMP_DATABASE.key}=10")) - assert(e1.message.contains("Cannot modify the value of a static config")) - val e2 = intercept[AnalysisException](spark.conf.unset(GLOBAL_TEMP_DATABASE.key)) - assert(e2.message.contains("Cannot modify the value of a static config")) + checkError( + exception = intercept[AnalysisException](sql(s"SET ${GLOBAL_TEMP_DATABASE.key}=10")), + condition = "CANNOT_MODIFY_CONFIG", + parameters = Map("key" -> "\"spark.sql.globalTempDatabase\"", "docroot" -> SPARK_DOC_ROOT) + ) + checkError( + exception = intercept[AnalysisException](spark.conf.unset(GLOBAL_TEMP_DATABASE.key)), + condition = "CANNOT_MODIFY_CONFIG", + parameters = Map("key" -> "\"spark.sql.globalTempDatabase\"", "docroot" -> SPARK_DOC_ROOT) + ) } test("SPARK-36643: Show migration guide when attempting SparkConf") { @@ -486,8 +492,8 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { val sqlText = "set time zone interval 19 hours" checkError( exception = intercept[ParseException](sql(sqlText)), - condition = "_LEGACY_ERROR_TEMP_0044", - parameters = Map.empty, + condition = "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + parameters = Map("input" -> "19"), context = ExpectedContext(sqlText, 0, 30)) } @@ -517,6 +523,13 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { "confName" -> "spark.sql.session.collation.default", "proposals" -> "UNICODE" )) + + withSQLConf(SQLConf.TRIM_COLLATION_ENABLED.key -> "false") { + checkError( + exception = intercept[AnalysisException](sql(s"SET COLLATION UNICODE_CI_RTRIM")), + condition = "UNSUPPORTED_FEATURE.TRIM_COLLATION" + ) + } } test("SPARK-43028: config not found error") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 054c7e644ff55..0550fae3805d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -2688,7 +2688,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df = sql("SELECT SUM(2147483647 + DEPT) FROM h2.test.employee") checkAggregateRemoved(df, ansiMode) val expectedPlanFragment = if (ansiMode) { - "PushedAggregates: [SUM(2147483647 + DEPT)], " + + "PushedAggregates: [SUM(DEPT + 2147483647)], " + "PushedFilters: [], " + "PushedGroupByExpressions: []" } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 83d8191d01ec1..baad5702f4f22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -97,6 +97,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi case TestLeafStatement(testVal) => testVal case TestIfElseCondition(_, description) => description case TestLoopCondition(_, _, description) => description + case loopStmt: LoopStatementExec => loopStmt.label.get case leaveStmt: LeaveStatementExec => leaveStmt.label case iterateStmt: IterateStatementExec => iterateStmt.label case _ => fail("Unexpected statement type") @@ -669,4 +670,20 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi val statements = 
iter.map(extractStatementValue).toSeq assert(statements === Seq("con1", "con2")) } + + test("loop statement with leave") { + val iter = new CompoundBodyExec( + statements = Seq( + new LoopStatementExec( + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new LeaveStatementExec("lbl")) + ), + label = Some("lbl") + ) + ) + ).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("body1", "lbl")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index bc2adec5be3d5..3551608a1ee84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.{SparkException, SparkNumberFormatException} import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row} import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.exceptions.SqlScriptingException +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession /** @@ -701,8 +702,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(commands, expected) } - // This is disabled because it fails in non-ANSI mode - ignore("simple case mismatched types") { + test("simple case mismatched types") { val commands = """ |BEGIN @@ -712,18 +712,26 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { | END CASE; |END |""".stripMargin - - checkError( - exception = intercept[SparkNumberFormatException] ( - runSqlScript(commands) - ), - condition = "CAST_INVALID_INPUT", - parameters = Map( - "expression" -> "'one'", - "sourceType" -> "\"STRING\"", - "targetType" -> "\"BIGINT\""), - context = ExpectedContext(fragment = "\"one\"", start = 23, stop = 27) - ) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + checkError( + exception = intercept[SparkNumberFormatException]( + runSqlScript(commands) + ), + condition = "CAST_INVALID_INPUT", + parameters = Map( + "expression" -> "'one'", + "sourceType" -> "\"STRING\"", + "targetType" -> "\"BIGINT\""), + context = ExpectedContext(fragment = "\"one\"", start = 23, stop = 27)) + } + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkError( + exception = intercept[SqlScriptingException]( + runSqlScript(commands) + ), + condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW", + parameters = Map("invalidStatement" -> "\"ONE\"")) + } } test("simple case compare with null") { @@ -1375,4 +1383,156 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { ) verifySqlScriptResult(sqlScriptText, expected) } + + test("loop statement with leave") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: LOOP + | SET x = x + 1; + | SELECT x; + | IF x > 2 + | THEN + | LEAVE lbl; + | END IF; + | END LOOP; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare + Seq.empty[Row], // set x = 0 + Seq.empty[Row], // set x = 1 + Seq(Row(1)), // select x + Seq.empty[Row], // set x = 2 + Seq(Row(2)), // select x + Seq.empty[Row], // set x = 3 + Seq(Row(3)), // select x + Seq(Row(3)), // select x + Seq.empty[Row] // drop + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("nested loop statement with leave") { + val 
commands = + """ + |BEGIN + | DECLARE x = 0; + | DECLARE y = 0; + | lbl1: LOOP + | SET VAR y = 0; + | lbl2: LOOP + | SELECT x, y; + | SET VAR y = y + 1; + | IF y >= 2 THEN + | LEAVE lbl2; + | END IF; + | END LOOP; + | SET VAR x = x + 1; + | IF x >= 2 THEN + | LEAVE lbl1; + | END IF; + | END LOOP; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare x + Seq.empty[Row], // declare y + Seq.empty[Row], // set y to 0 + Seq(Row(0, 0)), // select x, y + Seq.empty[Row], // increase y + Seq(Row(0, 1)), // select x, y + Seq.empty[Row], // increase y + Seq.empty[Row], // increase x + Seq.empty[Row], // set y to 0 + Seq(Row(1, 0)), // select x, y + Seq.empty[Row], // increase y + Seq(Row(1, 1)), // select x, y + Seq.empty[Row], // increase y + Seq.empty[Row], // increase x + Seq.empty[Row], // drop y + Seq.empty[Row] // drop x + ) + verifySqlScriptResult(commands, expected) + } + + test("iterate loop statement") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: LOOP + | SET x = x + 1; + | IF x > 1 THEN + | LEAVE lbl; + | END IF; + | ITERATE lbl; + | SET x = x + 2; + | END LOOP; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare + Seq.empty[Row], // set x = 0 + Seq.empty[Row], // set x = 1 + Seq.empty[Row], // set x = 2 + Seq(Row(2)), // select x + Seq.empty[Row] // drop + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("leave outer loop from nested loop statement") { + val sqlScriptText = + """ + |BEGIN + | lbl: LOOP + | lbl2: LOOP + | SELECT 1; + | LEAVE lbl; + | END LOOP; + | END LOOP; + |END""".stripMargin + val expected = Seq( + Seq(Row(1)) // select 1 + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("iterate outer loop from nested loop statement") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: LOOP + | SET x = x + 1; + | IF x > 2 THEN + | LEAVE lbl; + | END IF; + | lbl2: LOOP + | SELECT 1; + | ITERATE lbl; + | SET x = 10; + | END LOOP; + | END LOOP; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare + Seq.empty[Row], // set x = 0 + Seq.empty[Row], // set x = 1 + Seq(Row(1)), // select 1 + Seq.empty[Row], // set x = 2 + Seq(Row(1)), // select 1 + Seq.empty[Row], // set x = 3 + Seq(Row(3)), // select x + Seq.empty[Row] // drop + ) + verifySqlScriptResult(sqlScriptText, expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 57655a58a694d..baf99798965da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -956,7 +956,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { val msg = intercept[SparkRuntimeException] { sql("INSERT INTO TABLE test_table SELECT 2, null") } - assert(msg.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + assert(msg.getCondition == "NOT_NULL_ASSERT_VIOLATION") } } @@ -1998,7 +1998,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(s"create table t(a string default 'abc') using parquet") }, - condition = "_LEGACY_ERROR_TEMP_1345", + condition = "DEFAULT_UNSUPPORTED", parameters = Map("statementType" -> "CREATE TABLE", "dataSource" -> "parquet")) withTable("t") { sql(s"create table t(a string, b int) using parquet") @@ -2006,7 +2006,7 @@ class InsertSuite extends 
DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("alter table t add column s bigint default 42") }, - condition = "_LEGACY_ERROR_TEMP_1345", + condition = "DEFAULT_UNSUPPORTED", parameters = Map( "statementType" -> "ALTER TABLE ADD COLUMNS", "dataSource" -> "parquet")) @@ -2314,7 +2314,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { // provider is now in the denylist. sql(s"alter table t1 add column (b string default 'abc')") }, - condition = "_LEGACY_ERROR_TEMP_1346", + condition = "ADD_DEFAULT_UNSUPPORTED", parameters = Map( "statementType" -> "ALTER TABLE ADD COLUMNS", "dataSource" -> provider)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index d9ce8002d285b..a0eea14e54eed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -296,7 +296,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { val exception = SparkException.internalError("testpurpose") testSerialization( new QueryTerminatedEvent(UUID.randomUUID, UUID.randomUUID, - Some(exception.getMessage), Some(exception.getErrorClass))) + Some(exception.getMessage), Some(exception.getCondition))) } test("only one progress event per interval when no data") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala index 782badaef924f..f651bfb7f3c72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala @@ -21,7 +21,7 @@ import java.sql.Timestamp import org.apache.spark.sql.Row import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.functions.{expr, lit, window} +import org.apache.spark.sql.functions.{count, expr, lit, timestamp_seconds, window} import org.apache.spark.sql.internal.SQLConf /** @@ -524,4 +524,66 @@ class StreamingQueryOptimizationCorrectnessSuite extends StreamTest { doTest(numExpectedStatefulOperatorsForOneEmptySource = 1) } } + + test("SPARK-49699: observe node is not pruned out from PruneFilters") { + val input1 = MemoryStream[Int] + val df = input1.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .observe("observation", count(lit(1)).as("rows")) + // Enforce PruneFilters to come into play and prune subtree. We could do the same + // with the reproducer of SPARK-48267, but let's just be simpler. + .filter(expr("false")) + + testStream(df)( + AddData(input1, 1, 2, 3), + CheckNewAnswer(), + Execute { qe => + val observeRow = qe.lastExecution.observedMetrics.get("observation") + assert(observeRow.get.getAs[Long]("rows") == 3L) + } + ) + } + + test("SPARK-49699: watermark node is not pruned out from PruneFilters") { + // NOTE: The test actually passes without SPARK-49699, because of the trickiness of + // filter pushdown and PruneFilters. Unlike observe node, the `false` filter is pushed down + // below to watermark node, hence PruneFilters rule does not prune out watermark node even + // before SPARK-49699. 
PropagateEmptyRelation also does not propagate emptiness into the + watermark node, so the node is retained. The test is added to prevent regressions. + + val input1 = MemoryStream[Int] + val df = input1.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "0 second") + // Enforce PruneFilters to come into play and prune subtree. We could do the same + // with the reproducer of SPARK-48267, but let's just be simpler. + .filter(expr("false")) + + testStream(df)( + AddData(input1, 1, 2, 3), + CheckNewAnswer(), + Execute { qe => + // If the watermark node is pruned out, this would be null. + assert(qe.lastProgress.eventTime.get("watermark") != null) + } + ) + } + + test("SPARK-49699: stateful operator node is not pruned out from PruneFilters") { + val input1 = MemoryStream[Int] + val df = input1.toDF() + .groupBy("value") + .count() + // Enforce PruneFilters to come into play and prune subtree. We could do the same + // with the reproducer of SPARK-48267, but let's just be simpler. + .filter(expr("false")) + + testStream(df, OutputMode.Complete())( + AddData(input1, 1, 2, 3), + CheckNewAnswer(), + Execute { qe => + assert(qe.lastProgress.stateOperators.length == 1) + } + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 8471995cb1e50..c12846d7512d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Complete import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit} -import org.apache.spark.sql.execution.exchange.ReusedExchangeExec +import org.apache.spark.sql.execution.exchange.{REQUIRED_BY_STATEFUL_OPERATOR, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.{MemorySink, TestForeachWriter} import org.apache.spark.sql.functions._ @@ -1448,6 +1448,28 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + test("SPARK-49905 shuffle added by stateful operator should use the shuffle origin " + + "`REQUIRED_BY_STATEFUL_OPERATOR`") { + val inputData = MemoryStream[Int] + + // Use the streaming aggregation as an example - all stateful operators are using the same + // distribution, named `StatefulOpClusteredDistribution`.
+ val df = inputData.toDF().groupBy("value").count() + + testStream(df, OutputMode.Update())( + AddData(inputData, 1, 2, 3, 1, 2, 3), + CheckAnswer((1, 2), (2, 2), (3, 2)), + Execute { qe => + val shuffleOpt = qe.lastExecution.executedPlan.collect { + case s: ShuffleExchangeExec => s + } + + assert(shuffleOpt.nonEmpty, "No shuffle exchange found in the query plan") + assert(shuffleOpt.head.shuffleOrigin === REQUIRED_BY_STATEFUL_OPERATOR) + } + ) + } + private def checkAppendOutputModeException(df: DataFrame): Unit = { withTempDir { outputDir => withTempDir { checkpointDir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index d0e255bb30499..0c02fbf97820b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -1623,7 +1623,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest val batch1AnsDf = batch1Df.selectExpr( "key.value AS groupingKey", - "single_value.value AS valueId") + "value.value AS valueId") checkAnswer(batch1AnsDf, Seq(Row("a", 2L))) @@ -1636,7 +1636,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest val batch3AnsDf = batch3Df.selectExpr( "key.value AS groupingKey", - "single_value.value AS valueId") + "value.value AS valueId") checkAnswer(batch3AnsDf, Seq(Row("a", 1L))) } } @@ -1731,7 +1731,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest val countStateAnsDf = countStateDf.selectExpr( "key.value AS groupingKey", - "single_value.value AS valueId") + "value.value AS valueId") checkAnswer(countStateAnsDf, Seq(Row("a", 5L))) val mostRecentDf = spark.read @@ -1743,7 +1743,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest val mostRecentAnsDf = mostRecentDf.selectExpr( "key.value AS groupingKey", - "single_value.value") + "value.value") checkAnswer(mostRecentAnsDf, Seq(Row("a", "str1"))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala index 45056d104e84e..1fbeaeb817bd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoders -import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, ListStateImplWithTTL, MapStateImplWithTTL, MemoryStream, ValueStateImpl, ValueStateImplWithTTL} +import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, ListStateImplWithTTL, MapStateImplWithTTL, MemoryStream, TimerStateUtils, ValueStateImpl, ValueStateImplWithTTL} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -265,7 +265,16 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf) val keySchema = new StructType().add("value", StringType) + val schemaForKeyRow: StructType = new StructType() + .add("key", new StructType(keySchema.fields)) + .add("expiryTimestampMs", LongType, nullable = false) + val schemaForValueRow: 
StructType = StructType(Array(StructField("__dummy__", NullType))) val schema0 = StateStoreColFamilySchema( + TimerStateUtils.getTimerStateVarName(TimeMode.ProcessingTime().toString), + schemaForKeyRow, + schemaForValueRow, + Some(PrefixKeyScanStateEncoderSpec(schemaForKeyRow, 1))) + val schema1 = StateStoreColFamilySchema( "valueStateTTL", keySchema, new StructType().add("value", @@ -275,14 +284,14 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { Some(NoPrefixKeyStateEncoderSpec(keySchema)), None ) - val schema1 = StateStoreColFamilySchema( + val schema2 = StateStoreColFamilySchema( "valueState", keySchema, new StructType().add("value", IntegerType, false), Some(NoPrefixKeyStateEncoderSpec(keySchema)), None ) - val schema2 = StateStoreColFamilySchema( + val schema3 = StateStoreColFamilySchema( "listState", keySchema, new StructType().add("value", @@ -300,7 +309,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { val compositeKeySchema = new StructType() .add("key", new StructType().add("value", StringType)) .add("userKey", userKeySchema) - val schema3 = StateStoreColFamilySchema( + val schema4 = StateStoreColFamilySchema( "mapState", compositeKeySchema, new StructType().add("value", @@ -351,9 +360,9 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { q.lastProgress.stateOperators.head.customMetrics .get("numMapStateWithTTLVars").toInt) - assert(colFamilySeq.length == 4) + assert(colFamilySeq.length == 5) assert(colFamilySeq.map(_.toString).toSet == Set( - schema0, schema1, schema2, schema3 + schema0, schema1, schema2, schema3, schema4 ).map(_.toString)) }, StopStream diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index d7c00b68828c4..90432dea3a017 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -35,7 +35,7 @@ private[sql] trait SQLTestData { self => // Helper object to import SQL implicits without a concrete SparkSession private object internalImplicits extends SQLImplicits { - protected override def session: SparkSession = self.spark + override protected def session: SparkSession = self.spark } import internalImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 54d6840eb5775..fe5a0f8ee257a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -240,7 +240,7 @@ private[sql] trait SQLTestUtilsBase * but the implicits import is needed in the constructor. 
*/ protected object testImplicits extends SQLImplicits { - protected override def session: SparkSession = self.spark + override protected def session: SparkSession = self.spark implicit def toRichColumn(c: Column): SparkSession#RichColumn = session.RichColumn(c) } diff --git a/sql/gen-sql-api-docs.py b/sql/gen-sql-api-docs.py index 17631a7352a02..3d19da01b3938 100644 --- a/sql/gen-sql-api-docs.py +++ b/sql/gen-sql-api-docs.py @@ -69,19 +69,6 @@ note="", since="1.0.0", deprecated=""), - ExpressionInfo( - className="", - name="between", - usage="expr1 [NOT] BETWEEN expr2 AND expr3 - " + - "evaluate if `expr1` is [not] in between `expr2` and `expr3`.", - arguments="", - examples="\n Examples:\n " + - "> SELECT col1 FROM VALUES 1, 3, 5, 7 WHERE col1 BETWEEN 2 AND 5;\n " + - " 3\n " + - " 5", - note="", - since="1.0.0", - deprecated=""), ExpressionInfo( className="", name="case", diff --git a/sql/gen-sql-functions-docs.py b/sql/gen-sql-functions-docs.py index bb813cffb0128..a1facbaaf7e3b 100644 --- a/sql/gen-sql-functions-docs.py +++ b/sql/gen-sql-functions-docs.py @@ -36,9 +36,14 @@ "bitwise_funcs", "conversion_funcs", "csv_funcs", "xml_funcs", "lambda_funcs", "collection_funcs", "url_funcs", "hash_funcs", "struct_funcs", + "table_funcs", "variant_funcs" } +def _print_red(text): + print('\033[31m' + text + '\033[0m') + + def _list_grouped_function_infos(jvm): """ Returns a list of function information grouped by each group value via JVM. @@ -126,7 +131,13 @@ def _make_pretty_usage(infos): func_name = "\\" + func_name elif (info.name == "when"): func_name = "CASE WHEN" - usages = iter(re.split(r"(.*%s.*) - " % func_name, info.usage.strip())[1:]) + expr_usages = re.split(r"(.*%s.*) - " % func_name, info.usage.strip()) + if len(expr_usages) <= 1: + _print_red("\nThe `usage` of %s is not standardized, please correct it. " + "Refer to: `AesDecrypt`" % (func_name)) + os._exit(-1) + usages = iter(expr_usages[1:]) + for (sig, description) in zip(usages, usages): result.append(" ") result.append(" %s" % sig) diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 6a352f8a530d7..d066e235ebeab 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -138,6 +138,16 @@ mockito-core test
+ + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + net.sf.jpam jpam diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServerErrors.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServerErrors.scala index 8a8bdd4d38ee3..59d1b61f2f8e7 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServerErrors.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServerErrors.scala @@ -38,7 +38,7 @@ object HiveThriftServerErrors { def runningQueryError(e: Throwable, format: ErrorMessageFormat.Value): Throwable = e match { case st: SparkThrowable if format == ErrorMessageFormat.PRETTY => - val errorClassPrefix = Option(st.getErrorClass).map(e => s"[$e] ").getOrElse("") + val errorClassPrefix = Option(st.getCondition).map(e => s"[$e] ").getOrElse("") new HiveSQLException( s"Error running query: $errorClassPrefix${st.toString}", st.getSqlState, st) case st: SparkThrowable with Throwable => diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 4575549005f33..43030f68e5dac 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -1062,7 +1062,7 @@ class SingleSessionSuite extends HiveThriftServer2TestBase { statement.executeQuery("SET spark.sql.hive.thriftServer.singleSession=false") }.getMessage assert(e.contains( - "Cannot modify the value of a static config: spark.sql.hive.thriftServer.singleSession")) + "CANNOT_MODIFY_CONFIG")) } } @@ -1222,7 +1222,7 @@ abstract class HiveThriftServer2TestBase extends SparkFunSuite with BeforeAndAft // overrides all other potential log4j configurations contained in other dependency jar files. 
val tempLog4jConf = Utils.createTempDir().getCanonicalPath - Files.write( + Files.asCharSink(new File(s"$tempLog4jConf/log4j2.properties"), StandardCharsets.UTF_8).write( """rootLogger.level = info |rootLogger.appenderRef.stdout.ref = console |appender.console.type = Console @@ -1230,9 +1230,7 @@ abstract class HiveThriftServer2TestBase extends SparkFunSuite with BeforeAndAft |appender.console.target = SYSTEM_ERR |appender.console.layout.type = PatternLayout |appender.console.layout.pattern = %d{HH:mm:ss.SSS} %p %c: %maxLen{%m}{512}%n%ex{8}%n - """.stripMargin, - new File(s"$tempLog4jConf/log4j2.properties"), - StandardCharsets.UTF_8) + """.stripMargin) tempLog4jConf } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index 4bc4116a23da7..60c49619552e7 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { val sessionHandle = client.openSession(user, "") val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS) // scalastyle:off line.size.limit - assert(infoValue.getStringValue == "ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERT
IES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") + assert(infoValue.getStringValue == "ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,LOOP,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") // scalastyle:on line.size.limit } } diff --git 
a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala index 2b2cbec41d643..8d4a9886a2b25 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -75,7 +75,7 @@ class UISeleniumSuite // overrides all other potential log4j configurations contained in other dependency jar files. val tempLog4jConf = org.apache.spark.util.Utils.createTempDir().getCanonicalPath - Files.write( + Files.asCharSink(new File(s"$tempLog4jConf/log4j2.properties"), StandardCharsets.UTF_8).write( """rootLogger.level = info |rootLogger.appenderRef.file.ref = console |appender.console.type = Console @@ -83,9 +83,7 @@ class UISeleniumSuite |appender.console.target = SYSTEM_ERR |appender.console.layout.type = PatternLayout |appender.console.layout.pattern = %d{HH:mm:ss.SSS} %p %c: %maxLen{%m}{512}%n%ex{8}%n - """.stripMargin, - new File(s"$tempLog4jConf/log4j2.properties"), - StandardCharsets.UTF_8) + """.stripMargin) tempLog4jConf } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 7873c36222da0..1f87db31ffa52 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -28,7 +28,7 @@ import org.apache.spark.SparkException import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.{FullQualifiedTableName, TableIdentifier} +import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types.DataTypeUtils @@ -56,7 +56,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log private val tableCreationLocks = Striped.lazyWeakLock(100) /** Acquires a lock on the table cache for the duration of `f`. 
*/ - private def withTableCreationLock[A](tableName: FullQualifiedTableName, f: => A): A = { + private def withTableCreationLock[A](tableName: QualifiedTableName, f: => A): A = { val lock = tableCreationLocks.get(tableName) lock.lock() try f finally { @@ -66,7 +66,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log // For testing only private[hive] def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = { - val key = FullQualifiedTableName( + val key = QualifiedTableName( // scalastyle:off caselocale table.catalog.getOrElse(CatalogManager.SESSION_CATALOG_NAME).toLowerCase, table.database.getOrElse(sessionState.catalog.getCurrentDatabase).toLowerCase, @@ -76,7 +76,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } private def getCached( - tableIdentifier: FullQualifiedTableName, + tableIdentifier: QualifiedTableName, pathsInMetastore: Seq[Path], schemaInMetastore: StructType, expectedFileFormat: Class[_ <: FileFormat], @@ -120,7 +120,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } private def logWarningUnexpectedFileFormat( - tableIdentifier: FullQualifiedTableName, + tableIdentifier: QualifiedTableName, expectedFileFormat: Class[_ <: FileFormat], actualFileFormat: String): Unit = { logWarning(log"Table ${MDC(TABLE_NAME, tableIdentifier)} should be stored as " + @@ -201,7 +201,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log fileType: String, isWrite: Boolean): LogicalRelation = { val metastoreSchema = relation.tableMeta.schema - val tableIdentifier = FullQualifiedTableName(relation.tableMeta.identifier.catalog.get, + val tableIdentifier = QualifiedTableName(relation.tableMeta.identifier.catalog.get, relation.tableMeta.database, relation.tableMeta.identifier.table) val lazyPruningEnabled = sparkSession.sessionState.conf.manageFilesourcePartitions diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 44c1ecd6902ce..dbeb8607facc2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose} import org.apache.spark.sql.catalyst.catalog.{ExternalCatalogWithListener, InvalidUDFClassException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -95,6 +95,7 @@ class HiveSessionStateBuilder( new EvalSubqueriesForTimeTravel +: new DetermineTableStats(session) +: new ResolveTranspose(session) +: + new InvokeProcedures(session) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 594c097de2c7d..83d70b2e19109 
100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -23,10 +23,10 @@ import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import java.util.{Locale, Set} -import com.google.common.io.Files +import com.google.common.io.{Files, FileWriteMode} import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.{SparkException, TestUtils} +import org.apache.spark.{SPARK_DOC_ROOT, SparkException, TestUtils} import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog.{CatalogTableType, CatalogUtils, HiveTableRelation} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLConf import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.execution.{SparkPlanInfo, TestUncaughtExceptionHandler} import org.apache.spark.sql.execution.adaptive.{DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} @@ -246,7 +247,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi checkKeywordsExist(sql("describe function `between`"), "Function: between", - "Usage: input [NOT] BETWEEN lower AND upper - " + + "input [NOT] between lower AND upper - " + "evaluate if `input` is [not] in between `lower` and `upper`") checkKeywordsExist(sql("describe function `case`"), @@ -1947,10 +1948,10 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi val path = dir.toURI.toString.stripSuffix("/") val dirPath = dir.getAbsoluteFile for (i <- 1 to 3) { - Files.write(s"$i", new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8) + Files.asCharSink(new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8).write(s"$i") } for (i <- 5 to 7) { - Files.write(s"$i", new File(dirPath, s"part-s-0000$i"), StandardCharsets.UTF_8) + Files.asCharSink(new File(dirPath, s"part-s-0000$i"), StandardCharsets.UTF_8).write(s"$i") } withTable("load_t") { @@ -1971,7 +1972,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi val path = dir.toURI.toString.stripSuffix("/") val dirPath = dir.getAbsoluteFile for (i <- 1 to 3) { - Files.write(s"$i", new File(dirPath, s"part-r-0000 $i"), StandardCharsets.UTF_8) + Files.asCharSink(new File(dirPath, s"part-r-0000 $i"), StandardCharsets.UTF_8).write(s"$i") } withTable("load_t") { sql("CREATE TABLE load_t (a STRING) USING hive") @@ -1986,7 +1987,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi val path = dir.toURI.toString.stripSuffix("/") val dirPath = dir.getAbsoluteFile for (i <- 1 to 3) { - Files.write(s"$i", new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8) + Files.asCharSink(new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8).write(s"$i") } withTable("load_t") { sql("CREATE TABLE load_t (a STRING) USING hive") @@ -2010,7 +2011,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi val path = dir.toURI.toString.stripSuffix("/") val dirPath = dir.getAbsoluteFile for (i <- 1 to 3) { - Files.write(s"$i", new File(dirPath, s"part-r-0000$i"), 
StandardCharsets.UTF_8) + Files.asCharSink(new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8).write(s"$i") } withTable("load_t1") { sql("CREATE TABLE load_t1 (a STRING) USING hive") @@ -2025,7 +2026,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi val path = dir.toURI.toString.stripSuffix("/") val dirPath = dir.getAbsoluteFile for (i <- 1 to 3) { - Files.write(s"$i", new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8) + Files.asCharSink(new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8).write(s"$i") } withTable("load_t2") { sql("CREATE TABLE load_t2 (a STRING) USING hive") @@ -2039,7 +2040,8 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi withTempDir { dir => val path = dir.toURI.toString.stripSuffix("/") val dirPath = dir.getAbsoluteFile - Files.append("1", new File(dirPath, "part-r-000011"), StandardCharsets.UTF_8) + Files.asCharSink( + new File(dirPath, "part-r-000011"), StandardCharsets.UTF_8, FileWriteMode.APPEND).write("1") withTable("part_table") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { sql( @@ -2460,8 +2462,12 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi "spark.sql.hive.metastore.jars", "spark.sql.hive.metastore.sharedPrefixes", "spark.sql.hive.metastore.barrierPrefixes").foreach { key => - val e = intercept[AnalysisException](sql(s"set $key=abc")) - assert(e.getMessage.contains("Cannot modify the value of a static config")) + checkError( + exception = intercept[AnalysisException](sql(s"set $key=abc")), + condition = "CANNOT_MODIFY_CONFIG", + parameters = Map( + "key" -> toSQLConf(key), "docroot" -> SPARK_DOC_ROOT) + ) } } diff --git a/streaming/pom.xml b/streaming/pom.xml index 85a4d268d2a25..704886bfdd1f6 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -121,6 +121,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + target/scala-${scala.binary.version}/classes diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java index f8d961fa8dd8e..73c2e89f3729a 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java @@ -1641,7 +1641,7 @@ public void testRawSocketStream() { private static List> fileTestPrepare(File testDir) throws IOException { File existingFile = new File(testDir, "0"); - Files.write("0\n", existingFile, StandardCharsets.UTF_8); + Files.asCharSink(existingFile, StandardCharsets.UTF_8).write("0\n"); Assertions.assertTrue(existingFile.setLastModified(1000)); Assertions.assertEquals(1000, existingFile.lastModified()); return Arrays.asList(Arrays.asList("0")); diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 43b0835df7cbf..4aeb0e043a973 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -649,7 +649,7 @@ class CheckpointSuite extends TestSuiteBase with LocalStreamingContext with DStr */ def writeFile(i: Int, clock: Clock): Unit = { val file = new File(testDir, i.toString) - Files.write(s"$i\n", file, StandardCharsets.UTF_8) + Files.asCharSink(file, StandardCharsets.UTF_8).write(s"$i\n") 
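// A minimal standalone sketch of the Guava file-writing migration repeated throughout these test
// suites: the old Files.write(CharSequence, File, Charset) overload is swapped for the
// CharSink-based API, and FileWriteMode.APPEND covers the former Files.append call. The file
// path below is hypothetical and only illustrates the call shape, not any path used by the tests.
import java.io.File
import java.nio.charset.StandardCharsets
import com.google.common.io.{FileWriteMode, Files}

val sample = new File("/tmp/guava-charsink-sketch.txt")  // hypothetical location
Files.asCharSink(sample, StandardCharsets.UTF_8).write("0\n")                        // overwrite
Files.asCharSink(sample, StandardCharsets.UTF_8, FileWriteMode.APPEND).write("1\n")  // append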
assert(file.setLastModified(clock.getTimeMillis())) // Check that the file's modification date is actually the value we wrote, since rounding or // truncation will break the test: diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 66fd1ac7bb22e..64335a96045bf 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -132,7 +132,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val batchDuration = Seconds(2) // Create a file that exists before the StreamingContext is created: val existingFile = new File(testDir, "0") - Files.write("0\n", existingFile, StandardCharsets.UTF_8) + Files.asCharSink(existingFile, StandardCharsets.UTF_8).write("0\n") assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000) // Set up the streaming context and input streams @@ -191,7 +191,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Create a file that exists before the StreamingContext is created: val existingFile = new File(testDir, "0") - Files.write("0\n", existingFile, StandardCharsets.UTF_8) + Files.asCharSink(existingFile, StandardCharsets.UTF_8).write("0\n") assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000) val pathWithWildCard = testDir.toString + "/*/" @@ -215,7 +215,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { def createFileAndAdvanceTime(data: Int, dir: File): Unit = { val file = new File(testSubDir1, data.toString) - Files.write(s"$data\n", file, StandardCharsets.UTF_8) + Files.asCharSink(file, StandardCharsets.UTF_8).write(s"$data\n") assert(file.setLastModified(clock.getTimeMillis())) assert(file.lastModified === clock.getTimeMillis()) logInfo(s"Created file $file") @@ -478,7 +478,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val batchDuration = Seconds(2) // Create a file that exists before the StreamingContext is created: val existingFile = new File(testDir, "0") - Files.write("0\n", existingFile, StandardCharsets.UTF_8) + Files.asCharSink(existingFile, StandardCharsets.UTF_8).write("0\n") assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000) // Set up the streaming context and input streams @@ -502,7 +502,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val input = Seq(1, 2, 3, 4, 5) input.foreach { i => val file = new File(testDir, i.toString) - Files.write(s"$i\n", file, StandardCharsets.UTF_8) + Files.asCharSink(file, StandardCharsets.UTF_8).write(s"$i\n") assert(file.setLastModified(clock.getTimeMillis())) assert(file.lastModified === clock.getTimeMillis()) logInfo("Created file " + file) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index 771e65ed40b51..2dc43a231d9b8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -375,7 +375,7 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) val localFile = new File(localTestDir, (i + 1).toString) val hadoopFile = new Path(testDir, (i + 1).toString) val tempHadoopFile = new Path(testDir, ".tmp_" + (i + 1).toString) - 
Files.write(input(i) + "\n", localFile, StandardCharsets.UTF_8) + Files.asCharSink(localFile, StandardCharsets.UTF_8).write(input(i) + "\n") var tries = 0 var done = false while (!done && tries < maxTries) {