diff --git a/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java b/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java index b1ff99897627..3dd6a5ee2332 100644 --- a/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java +++ b/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java @@ -1267,7 +1267,14 @@ public List getColumnMasks(SecurityContext context, QualifiedObj .forEach(masks::add); } - return masks.build(); + // Currently the use case of multiple masks on a single column is not supported, the reason being there's no guarantee about the order + // in which masks will be applied and whether the functions from different masks are compatible with each other. + List combinedMasks = masks.build(); + if (combinedMasks.size() > 1) { + throw new TrinoException(NOT_SUPPORTED, format("Multiple masks on a single column is not supported: %s", columnName)); + } + + return combinedMasks; } private ConnectorAccessControl getConnectorAccessControl(TransactionId transactionId, String catalogName) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 2b0e420e27d8..daaf9981f38a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -305,26 +305,26 @@ private RelationPlan addColumnMasks(Table table, RelationPlan plan) PlanBuilder planBuilder = newPlanBuilder(plan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext) .withScope(analysis.getAccessControlScope(table), plan.getFieldMappings()); // The fields in the access control scope has the same layout as those for the table scope + Map assignments = new LinkedHashMap<>(); + for (Symbol symbol : planBuilder.getRoot().getOutputSymbols()) { + assignments.put(symbol, symbol.toSymbolReference()); + } + for (int i = 0; i < plan.getDescriptor().getAllFieldCount(); i++) { Field field = plan.getDescriptor().getFieldByIndex(i); for (Expression mask : columnMasks.getOrDefault(field.getName().orElseThrow(), ImmutableList.of())) { planBuilder = subqueryPlanner.handleSubqueries(planBuilder, mask, analysis.getSubqueries(mask)); - - Map assignments = new LinkedHashMap<>(); - for (Symbol symbol : planBuilder.getRoot().getOutputSymbols()) { - assignments.put(symbol, symbol.toSymbolReference()); - } assignments.put(plan.getFieldMappings().get(i), coerceIfNecessary(analysis, mask, planBuilder.rewrite(mask))); - - planBuilder = planBuilder - .withNewRoot(new ProjectNode( - idAllocator.getNextId(), - planBuilder.getRoot(), - Assignments.copyOf(assignments))); } } + planBuilder = planBuilder + .withNewRoot(new ProjectNode( + idAllocator.getNextId(), + planBuilder.getRoot(), + Assignments.copyOf(assignments))); + return new RelationPlan(planBuilder.getRoot(), plan.getScope(), plan.getFieldMappings(), outerContext); } diff --git a/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java b/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java index 8b3964568014..ed570c36628e 100644 --- a/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java +++ b/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java @@ -209,6 +209,7 @@ public void testDenyTableFunctionCatalogAccessControl() } } + // TODO: need to properly handle the ordering of multiple masks as they are not allowed currently @Test public void testColumnMaskOrdering() { @@ -259,16 +260,16 @@ public void checkCanShowCreateTable(ConnectorSecurityContext context, SchemaTabl } }))); - transaction(transactionManager, accessControlManager) + assertThatThrownBy(() -> transaction(transactionManager, accessControlManager) .execute(transactionId -> { - List masks = accessControlManager.getColumnMasks( + accessControlManager.getColumnMasks( context(transactionId), new QualifiedObjectName(TEST_CATALOG_NAME, "schema", "table"), "column", BIGINT); - assertEquals(masks.get(0).getExpression(), "connector mask"); - assertEquals(masks.get(1).getExpression(), "system mask"); - }); + })) + .isInstanceOf(TrinoException.class) + .hasMessageMatching("Multiple masks on a single column is not supported: column"); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java b/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java index 9b7078b05f7f..b83f81cc65f1 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java @@ -212,7 +212,8 @@ public void testMultipleMasksOnSameColumn() USER, new ViewExpression(USER, Optional.empty(), Optional.empty(), "custkey * 2")); - assertThat(assertions.query("SELECT custkey FROM orders WHERE orderkey = 1")).matches("VALUES BIGINT '-740'"); + // When there are multiple masks on the same column, the latter one overrides the previous ones + assertThat(assertions.query("SELECT custkey FROM orders WHERE orderkey = 1")).matches("VALUES BIGINT '740'"); } @Test @@ -842,12 +843,13 @@ public void testMultipleMasksUsingOtherMaskedColumns() // Mask "comment" and "orderstatus" using "clerk" ("clerk" appears between "orderstatus" and "comment" in table definition) // "comment" and "orderstatus" are masked as the condition on "clerk" is satisfied + // This is to showcase that the three maskings are done simultaneously, not in a "sequential" or "chained" manner. accessControl.reset(); accessControl.columnMask( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(Clerk#)','***#') as varchar(15))")); + new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast('###' as varchar(15))")); accessControl.columnMask( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), @@ -862,6 +864,6 @@ public void testMultipleMasksUsingOtherMaskedColumns() new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('951'), '***', comment)")); assertThat(assertions.query(query)) - .matches("VALUES (CAST('***' as varchar(79)), '*', CAST('***#000000951' as varchar(15)))"); + .matches("VALUES (CAST('***' as varchar(79)), '*', CAST('###' as varchar(15)))"); } }