Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.iceberg.common.DynConstructors
import org.apache.iceberg.spark.ExtendedParser
import org.apache.iceberg.spark.ExtendedParser.RawOrderField
import org.apache.iceberg.spark.Spark3Util
import org.apache.iceberg.spark.procedures.SparkProcedures
import org.apache.iceberg.spark.source.SparkTable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.SparkSession
Expand Down Expand Up @@ -194,8 +195,10 @@ class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserI
// Strip comments of the form /* ... */. This must come after stripping newlines so that
// comments that span multiple lines are caught.
.replaceAll("/\\*.*?\\*/", " ")
// Strip backtick then `system`.`ancestors_of` changes to system.ancestors_of
.replaceAll("`", "")
.trim()
normalized.startsWith("call") || (
isIcebergProcedure(normalized) || (
normalized.startsWith("alter table") && (
normalized.contains("add partition field") ||
normalized.contains("drop partition field") ||
Expand All @@ -209,6 +212,12 @@ class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserI
isSnapshotRefDdl(normalized)))
}

// All builtin Iceberg procedures are under the 'system' namespace
private def isIcebergProcedure(normalized: String): Boolean = {
normalized.startsWith("call") &&
SparkProcedures.names().asScala.map("system." + _).exists(normalized.contains)
}

private def isSnapshotRefDdl(normalized: String): Boolean = {
normalized.contains("create branch") ||
normalized.contains("replace branch") ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,37 @@ public static void stopSpark() {
currentSpark.stop();
}

@Test
public void testDelegateUnsupportedProcedure() {
assertThatThrownBy(() -> parser.parsePlan("CALL cat.d.t()"))
.isInstanceOf(ParseException.class)
.satisfies(
exception -> {
ParseException parseException = (ParseException) exception;
Assert.assertEquals("PARSE_SYNTAX_ERROR", parseException.getErrorClass());
Assert.assertEquals("Syntax error at or near 'CALL'", parseException.message());
});
}

@Test
public void testCallWithBackticks() throws ParseException {
CallStatement call =
(CallStatement) parser.parsePlan("CALL cat.`system`.`rollback_to_snapshot`()");
Assert.assertEquals(
ImmutableList.of("cat", "system", "rollback_to_snapshot"),
JavaConverters.seqAsJavaList(call.name()));
Assert.assertEquals(0, call.args().size());
}

@Test
public void testCallWithPositionalArgs() throws ParseException {
CallStatement call =
(CallStatement) parser.parsePlan("CALL c.n.func(1, '2', 3L, true, 1.0D, 9.0e1, 900e-1BD)");
(CallStatement)
parser.parsePlan(
"CALL c.system.rollback_to_snapshot(1, '2', 3L, true, 1.0D, 9.0e1, 900e-1BD)");
Assert.assertEquals(
ImmutableList.of("c", "n", "func"), JavaConverters.seqAsJavaList(call.name()));
ImmutableList.of("c", "system", "rollback_to_snapshot"),
JavaConverters.seqAsJavaList(call.name()));

Assert.assertEquals(7, call.args().size());

Expand All @@ -94,9 +119,12 @@ public void testCallWithPositionalArgs() throws ParseException {
@Test
public void testCallWithNamedArgs() throws ParseException {
CallStatement call =
(CallStatement) parser.parsePlan("CALL cat.system.func(c1 => 1, c2 => '2', c3 => true)");
(CallStatement)
parser.parsePlan(
"CALL cat.system.rollback_to_snapshot(c1 => 1, c2 => '2', c3 => true)");
Assert.assertEquals(
ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
ImmutableList.of("cat", "system", "rollback_to_snapshot"),
JavaConverters.seqAsJavaList(call.name()));

Assert.assertEquals(3, call.args().size());

Expand All @@ -107,9 +135,11 @@ public void testCallWithNamedArgs() throws ParseException {

@Test
public void testCallWithMixedArgs() throws ParseException {
CallStatement call = (CallStatement) parser.parsePlan("CALL cat.system.func(c1 => 1, '2')");
CallStatement call =
(CallStatement) parser.parsePlan("CALL cat.system.rollback_to_snapshot(c1 => 1, '2')");
Assert.assertEquals(
ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
ImmutableList.of("cat", "system", "rollback_to_snapshot"),
JavaConverters.seqAsJavaList(call.name()));

Assert.assertEquals(2, call.args().size());

Expand All @@ -121,9 +151,11 @@ public void testCallWithMixedArgs() throws ParseException {
public void testCallWithTimestampArg() throws ParseException {
CallStatement call =
(CallStatement)
parser.parsePlan("CALL cat.system.func(TIMESTAMP '2017-02-03T10:37:30.00Z')");
parser.parsePlan(
"CALL cat.system.rollback_to_snapshot(TIMESTAMP '2017-02-03T10:37:30.00Z')");
Assert.assertEquals(
ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
ImmutableList.of("cat", "system", "rollback_to_snapshot"),
JavaConverters.seqAsJavaList(call.name()));

Assert.assertEquals(1, call.args().size());

Expand All @@ -134,9 +166,11 @@ public void testCallWithTimestampArg() throws ParseException {
@Test
public void testCallWithVarSubstitution() throws ParseException {
CallStatement call =
(CallStatement) parser.parsePlan("CALL cat.system.func('${spark.extra.prop}')");
(CallStatement)
parser.parsePlan("CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')");
Assert.assertEquals(
ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
ImmutableList.of("cat", "system", "rollback_to_snapshot"),
JavaConverters.seqAsJavaList(call.name()));

Assert.assertEquals(1, call.args().size());

Expand All @@ -145,30 +179,32 @@ public void testCallWithVarSubstitution() throws ParseException {

@Test
public void testCallParseError() {
assertThatThrownBy(() -> parser.parsePlan("CALL cat.system radish kebab"))
assertThatThrownBy(() -> parser.parsePlan("CALL cat.system.rollback_to_snapshot kebab"))
.as("Should fail with a sensible parse error")
.isInstanceOf(IcebergParseException.class)
.hasMessageContaining("missing '(' at 'radish'");
.hasMessageContaining("missing '(' at 'kebab'");
}

@Test
public void testCallStripsComments() throws ParseException {
List<String> callStatementsWithComments =
Lists.newArrayList(
"/* bracketed comment */ CALL cat.system.func('${spark.extra.prop}')",
"/**/ CALL cat.system.func('${spark.extra.prop}')",
"-- single line comment \n CALL cat.system.func('${spark.extra.prop}')",
"-- multiple \n-- single line \n-- comments \n CALL cat.system.func('${spark.extra.prop}')",
"/* select * from multiline_comment \n where x like '%sql%'; */ CALL cat.system.func('${spark.extra.prop}')",
"/* bracketed comment */ CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')",
"/**/ CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')",
"-- single line comment \n CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')",
"-- multiple \n-- single line \n-- comments \n CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')",
"/* select * from multiline_comment \n where x like '%sql%'; */ CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')",
"/* {\"app\": \"dbt\", \"dbt_version\": \"1.0.1\", \"profile_name\": \"profile1\", \"target_name\": \"dev\", "
+ "\"node_id\": \"model.profile1.stg_users\"} \n*/ CALL cat.system.func('${spark.extra.prop}')",
+ "\"node_id\": \"model.profile1.stg_users\"} \n*/ CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')",
"/* Some multi-line comment \n"
+ "*/ CALL /* inline comment */ cat.system.func('${spark.extra.prop}') -- ending comment",
"CALL -- a line ending comment\n" + "cat.system.func('${spark.extra.prop}')");
+ "*/ CALL /* inline comment */ cat.system.rollback_to_snapshot('${spark.extra.prop}') -- ending comment",
"CALL -- a line ending comment\n"
+ "cat.system.rollback_to_snapshot('${spark.extra.prop}')");
for (String sqlText : callStatementsWithComments) {
CallStatement call = (CallStatement) parser.parsePlan(sqlText);
Assert.assertEquals(
ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
ImmutableList.of("cat", "system", "rollback_to_snapshot"),
JavaConverters.seqAsJavaList(call.name()));

Assert.assertEquals(1, call.args().size());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;

public class TestCherrypickSnapshotProcedure extends SparkExtensionsTestBase {
Expand Down Expand Up @@ -178,8 +179,13 @@ public void testInvalidCherrypickSnapshotCases() {

assertThatThrownBy(() -> sql("CALL %s.custom.cherrypick_snapshot('n', 't', 1L)", catalogName))
.as("Should not resolve procedures in arbitrary namespaces")
.isInstanceOf(NoSuchProcedureException.class)
.hasMessageContaining("not found");
.isInstanceOf(ParseException.class)
.satisfies(
exception -> {
ParseException parseException = (ParseException) exception;
Assert.assertEquals("PARSE_SYNTAX_ERROR", parseException.getErrorClass());
Assert.assertEquals("Syntax error at or near 'CALL'", parseException.message());
});

assertThatThrownBy(() -> sql("CALL %s.system.cherrypick_snapshot('t')", catalogName))
.as("Should reject calls without all required args")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
import org.apache.iceberg.spark.source.SimpleRecord;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
Expand Down Expand Up @@ -178,8 +178,12 @@ public void testInvalidExpireSnapshotsCases() {

assertThatThrownBy(() -> sql("CALL %s.custom.expire_snapshots('n', 't')", catalogName))
.as("Should not resolve procedures in arbitrary namespaces")
.isInstanceOf(NoSuchProcedureException.class)
.hasMessageContaining("not found");
.satisfies(
exception -> {
ParseException parseException = (ParseException) exception;
Assert.assertEquals("PARSE_SYNTAX_ERROR", parseException.getErrorClass());
Assert.assertEquals("Syntax error at or near 'CALL'", parseException.message());
});

assertThatThrownBy(() -> sql("CALL %s.system.expire_snapshots()", catalogName))
.as("Should reject calls without all required args")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
import org.apache.iceberg.Table;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;

public class TestFastForwardBranchProcedure extends SparkExtensionsTestBase {
Expand Down Expand Up @@ -176,8 +177,13 @@ public void testInvalidFastForwardBranchCases() {
assertThatThrownBy(
() ->
sql("CALL %s.custom.fast_forward('test_table', 'main', 'newBranch')", catalogName))
.isInstanceOf(NoSuchProcedureException.class)
.hasMessage("Procedure custom.fast_forward not found");
.isInstanceOf(ParseException.class)
.satisfies(
exception -> {
ParseException parseException = (ParseException) exception;
Assert.assertEquals("PARSE_SYNTAX_ERROR", parseException.getErrorClass());
Assert.assertEquals("Syntax error at or near 'CALL'", parseException.message());
});

assertThatThrownBy(() -> sql("CALL %s.system.fast_forward('test_table', 'main')", catalogName))
.isInstanceOf(AnalysisException.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;

public class TestPublishChangesProcedure extends SparkExtensionsTestBase {
Expand Down Expand Up @@ -176,8 +177,12 @@ public void testInvalidApplyWapChangesCases() {
assertThatThrownBy(
() -> sql("CALL %s.custom.publish_changes('n', 't', 'not_valid')", catalogName))
.as("Should not resolve procedures in arbitrary namespaces")
.isInstanceOf(NoSuchProcedureException.class)
.hasMessageContaining("not found");
.satisfies(
exception -> {
ParseException parseException = (ParseException) exception;
Assert.assertEquals("PARSE_SYNTAX_ERROR", parseException.getErrorClass());
Assert.assertEquals("Syntax error at or near 'CALL'", parseException.message());
});

assertThatThrownBy(() -> sql("CALL %s.system.publish_changes('t')", catalogName))
.as("Should reject calls without all required args")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.junit.After;
Expand Down Expand Up @@ -266,8 +265,12 @@ public void testInvalidRemoveOrphanFilesCases() {

assertThatThrownBy(() -> sql("CALL %s.custom.remove_orphan_files('n', 't')", catalogName))
.as("Should not resolve procedures in arbitrary namespaces")
.isInstanceOf(NoSuchProcedureException.class)
.hasMessageContaining("not found");
.satisfies(
exception -> {
ParseException parseException = (ParseException) exception;
Assert.assertEquals("PARSE_SYNTAX_ERROR", parseException.getErrorClass());
Assert.assertEquals("Syntax error at or near 'CALL'", parseException.message());
});

assertThatThrownBy(() -> sql("CALL %s.system.remove_orphan_files()", catalogName))
.as("Should reject calls without all required args")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.junit.After;
import org.junit.Assert;
import org.junit.Assume;
Expand Down Expand Up @@ -566,8 +566,12 @@ public void testInvalidCasesForRewriteDataFiles() {

assertThatThrownBy(() -> sql("CALL %s.custom.rewrite_data_files('n', 't')", catalogName))
.as("Should not resolve procedures in arbitrary namespaces")
.isInstanceOf(NoSuchProcedureException.class)
.hasMessageContaining("not found");
.satisfies(
exception -> {
ParseException parseException = (ParseException) exception;
Assert.assertEquals("PARSE_SYNTAX_ERROR", parseException.getErrorClass());
Assert.assertEquals("Syntax error at or near 'CALL'", parseException.message());
});

assertThatThrownBy(() -> sql("CALL %s.system.rewrite_data_files()", catalogName))
.as("Should reject calls without all required args")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
Expand Down Expand Up @@ -284,8 +284,12 @@ public void testInvalidRewriteManifestsCases() {

assertThatThrownBy(() -> sql("CALL %s.custom.rewrite_manifests('n', 't')", catalogName))
.as("Should not resolve procedures in arbitrary namespaces")
.isInstanceOf(NoSuchProcedureException.class)
.hasMessageContaining("not found");
.satisfies(
exception -> {
ParseException parseException = (ParseException) exception;
Assert.assertEquals("PARSE_SYNTAX_ERROR", parseException.getErrorClass());
Assert.assertEquals("Syntax error at or near 'CALL'", parseException.message());
});

assertThatThrownBy(() -> sql("CALL %s.system.rewrite_manifests()", catalogName))
.as("Should reject calls without all required args")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.junit.After;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Test;

Expand Down Expand Up @@ -261,8 +262,12 @@ public void testInvalidRollbackToSnapshotCases() {

assertThatThrownBy(() -> sql("CALL %s.custom.rollback_to_snapshot('n', 't', 1L)", catalogName))
.as("Should not resolve procedures in arbitrary namespaces")
.isInstanceOf(NoSuchProcedureException.class)
.hasMessageContaining("not found");
.satisfies(
exception -> {
ParseException parseException = (ParseException) exception;
Assert.assertEquals("PARSE_SYNTAX_ERROR", parseException.getErrorClass());
Assert.assertEquals("Syntax error at or near 'CALL'", parseException.message());
});

assertThatThrownBy(() -> sql("CALL %s.system.rollback_to_snapshot('t')", catalogName))
.as("Should reject calls without all required args")
Expand Down
Loading