diff --git a/wren-core-base/manifest-macro/src/lib.rs b/wren-core-base/manifest-macro/src/lib.rs index 6e8e8f06d..0207b5a1a 100644 --- a/wren-core-base/manifest-macro/src/lib.rs +++ b/wren-core-base/manifest-macro/src/lib.rs @@ -169,6 +169,8 @@ pub fn column(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStrea pub expression: Option, #[serde(default, with = "bool_from_int")] pub is_hidden: bool, + pub rls: Option, + pub cls: Option, } }; proc_macro::TokenStream::from(expanded) @@ -342,3 +344,146 @@ pub fn view(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream }; proc_macro::TokenStream::from(expanded) } + +#[proc_macro] +pub fn row_level_security(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(python_binding as LitBool); + let python_binding = if input.value { + quote! { + #[pyclass] + } + } else { + quote! {} + }; + let expanded = quote! { + #python_binding + #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash)] + pub struct RowLevelSecurity { + pub name: String, + pub operator: RowLevelOperator, + } + }; + proc_macro::TokenStream::from(expanded) +} + +#[proc_macro] +pub fn row_level_operator(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(python_binding as LitBool); + let python_binding = if input.value { + quote! { + #[pyclass(eq, eq_int)] + } + } else { + quote! {} + }; + let expanded = quote! { + #python_binding + #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)] + #[serde(rename_all = "SCREAMING_SNAKE_CASE")] + pub enum RowLevelOperator { + Equals, + NotEquals, + GreaterThan, + LessThan, + GreaterThanOrEquals, + LessThanOrEquals, + IN, + NotIn, + LIKE, + NotLike, + } + }; + proc_macro::TokenStream::from(expanded) +} + +#[proc_macro] +pub fn column_level_security(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(python_binding as LitBool); + let python_binding = if input.value { + quote! { + #[pyclass] + } + } else { + quote! {} + }; + let expanded = quote! { + #python_binding + #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash)] + pub struct ColumnLevelSecurity { + pub name: String, + pub operator: ColumnLevelOperator, + pub threshold: NormalizedExpr, + } + }; + proc_macro::TokenStream::from(expanded) +} + +#[proc_macro] +pub fn column_level_operator(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(python_binding as LitBool); + let python_binding = if input.value { + quote! { + #[pyclass(eq, eq_int)] + } + } else { + quote! {} + }; + let expanded = quote! { + #python_binding + #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)] + #[serde(rename_all = "SCREAMING_SNAKE_CASE")] + pub enum ColumnLevelOperator { + Equals, + NotEquals, + GreaterThan, + LessThan, + GreaterThanOrEquals, + LessThanOrEquals, + } + }; + proc_macro::TokenStream::from(expanded) +} + +#[proc_macro] +pub fn normalized_expr(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(python_binding as LitBool); + let python_binding = if input.value { + quote! { + #[pyclass] + } + } else { + quote! {} + }; + let expanded = quote! { + #python_binding + #[derive(SerializeDisplay, DeserializeFromStr, Debug, PartialEq, Eq, Hash)] + pub struct NormalizedExpr { + pub value: String, + #[serde_with(alias = "type")] + pub data_type: NormalizedExprType, + } + }; + proc_macro::TokenStream::from(expanded) +} + +#[proc_macro] +pub fn normalized_expr_type(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(python_binding as LitBool); + let python_binding = if input.value { + quote! { + #[pyclass(eq, eq_int)] + } + } else { + quote! {} + }; + let expanded = quote! { + #python_binding + #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)] + #[serde(rename_all = "SCREAMING_SNAKE_CASE")] + pub enum NormalizedExprType { + Numeric, + String, + } + }; + proc_macro::TokenStream::from(expanded) +} diff --git a/wren-core-base/src/mdl/builder.rs b/wren-core-base/src/mdl/builder.rs index d69adf52b..4be1ff49c 100644 --- a/wren-core-base/src/mdl/builder.rs +++ b/wren-core-base/src/mdl/builder.rs @@ -22,6 +22,9 @@ use crate::mdl::manifest::{ Column, DataSource, JoinType, Manifest, Metric, Model, Relationship, TimeGrain, TimeUnit, View, }; +use crate::mdl::{ + ColumnLevelOperator, ColumnLevelSecurity, NormalizedExpr, RowLevelOperator, RowLevelSecurity, +}; use std::sync::Arc; /// A builder for creating a Manifest @@ -165,6 +168,8 @@ impl ColumnBuilder { is_hidden: false, not_null: false, expression: None, + rls: None, + cls: None, }, } } @@ -202,6 +207,27 @@ impl ColumnBuilder { self } + pub fn row_level_security(mut self, name: &str, operator: RowLevelOperator) -> Self { + self.column.rls = Some(RowLevelSecurity { + name: name.to_string(), + operator, + }); + self + } + + pub fn column_level_security( + mut self, + name: &str, + operator: ColumnLevelOperator, + threshold: &str, + ) -> Self { + self.column.cls = Some(ColumnLevelSecurity { + name: name.to_string(), + operator, + threshold: NormalizedExpr::new(threshold), + }); + self + } pub fn build(self) -> Arc { Arc::new(self.column) } @@ -356,6 +382,7 @@ mod test { use crate::mdl::manifest::{ Column, DataSource, JoinType, Manifest, Metric, Model, Relationship, TimeUnit, View, }; + use crate::mdl::{ColumnLevelOperator, RowLevelOperator}; use std::fs; use std::path::PathBuf; use std::sync::Arc; @@ -368,6 +395,8 @@ mod test { .not_null(true) .hidden(true) .expression("test") + .row_level_security("SESSION_STATUS", RowLevelOperator::Equals) + .column_level_security("SESSION_LEVEL", ColumnLevelOperator::Equals, "'NORMAL'") .build(); let json_str = serde_json::to_string(&expected).unwrap(); @@ -661,6 +690,22 @@ mod test { .calculated(true) .build(), ) + .column( + ColumnBuilder::new("rls_orderkey", "integer") + .row_level_security("SESSION_STATUS", RowLevelOperator::Equals) + .expression("o_orderkey") + .build(), + ) + .column( + ColumnBuilder::new("cls_orderkey", "integer") + .column_level_security( + "SESSION_LEVEL", + ColumnLevelOperator::Equals, + "'NORMAL'", + ) + .expression("o_orderkey") + .build(), + ) .primary_key("o_orderkey") .build(), ) diff --git a/wren-core-base/src/mdl/cls.rs b/wren-core-base/src/mdl/cls.rs new file mode 100644 index 000000000..93c3f7de4 --- /dev/null +++ b/wren-core-base/src/mdl/cls.rs @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +use crate::mdl::manifest::{ColumnLevelSecurity, NormalizedExpr, NormalizedExprType}; +use crate::mdl::ColumnLevelOperator; +use std::fmt::{Display, Formatter}; +use std::str::FromStr; + +impl ColumnLevelSecurity { + /// Evaluate the input against the column level security. + /// If the type of the input is different from the type of the value, the result is always false except for NOT_EQUALS. + pub fn eval(&self, input: &str) -> bool { + let input_expr = NormalizedExpr::new(input); + match self.operator { + ColumnLevelOperator::Equals => input_expr.eq(&self.threshold), + ColumnLevelOperator::NotEquals => input_expr.neq(&self.threshold), + ColumnLevelOperator::GreaterThan => input_expr.gt(&self.threshold), + ColumnLevelOperator::LessThan => input_expr.lt(&self.threshold), + ColumnLevelOperator::GreaterThanOrEquals => input_expr.gte(&self.threshold), + ColumnLevelOperator::LessThanOrEquals => input_expr.lte(&self.threshold), + } + } +} + +impl NormalizedExpr { + pub fn new(expr: &str) -> Self { + assert!(!expr.is_empty(), "expr is null or empty"); + + if Self::is_string(expr) { + NormalizedExpr { + value: expr[1..expr.len() - 1].to_string(), + data_type: NormalizedExprType::String, + } + } else { + NormalizedExpr { + value: expr.to_string(), + data_type: NormalizedExprType::Numeric, + } + } + } + + fn is_string(expr: &str) -> bool { + expr.starts_with("'") && expr.ends_with("'") + } + + fn eq(&self, other: &Self) -> bool { + if self.data_type != other.data_type { + return false; + } + self.value == other.value + } + + fn neq(&self, other: &Self) -> bool { + !self.eq(other) + } + + fn gt(&self, other: &Self) -> bool { + if self.data_type != other.data_type { + return false; + } + match self.data_type { + NormalizedExprType::String => self.value > other.value, + NormalizedExprType::Numeric => { + self.value.parse::().unwrap() > other.value.parse::().unwrap() + } + } + } + + fn lt(&self, other: &Self) -> bool { + if self.data_type != other.data_type { + return false; + } + match self.data_type { + NormalizedExprType::String => self.value < other.value, + NormalizedExprType::Numeric => { + self.value.parse::().unwrap() < other.value.parse::().unwrap() + } + } + } + + fn gte(&self, other: &Self) -> bool { + self.gt(other) || self.eq(other) + } + + fn lte(&self, other: &Self) -> bool { + self.lt(other) || self.eq(other) + } +} + +impl Display for NormalizedExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self.data_type { + NormalizedExprType::String => write!(f, "'{}'", self.value), + NormalizedExprType::Numeric => write!(f, "{}", self.value), + } + } +} + +impl FromStr for NormalizedExpr { + type Err = String; + + fn from_str(s: &str) -> Result { + Ok(NormalizedExpr::new(s)) + } +} + +#[cfg(test)] +mod test { + use crate::mdl::{ColumnLevelOperator, ColumnLevelSecurity, NormalizedExpr}; + + #[test] + #[should_panic(expected = "expr is null or empty")] + fn test_normalized_expr_with_empty_str() { + NormalizedExpr::new(""); + } + + #[test] + #[should_panic(expected = "expr is null or empty")] + fn test_column_level_security_eval_empty_str() { + ColumnLevelSecurity { + name: "numericEquals".to_string(), + operator: ColumnLevelOperator::Equals, + threshold: NormalizedExpr::new("1"), + } + .eval(""); + } + + #[test] + fn test_numeric_column_level_security() { + let cls_name = "cls_name"; + assert!(ColumnLevelSecurity { + name: cls_name.to_string(), + operator: ColumnLevelOperator::Equals, + threshold: NormalizedExpr::new("1"), + } + .eval("1")); + assert!(ColumnLevelSecurity { + name: cls_name.to_string(), + operator: ColumnLevelOperator::NotEquals, + threshold: NormalizedExpr::new("1"), + } + .eval("2")); + assert!(ColumnLevelSecurity { + name: cls_name.to_string(), + operator: ColumnLevelOperator::GreaterThan, + threshold: NormalizedExpr::new("1"), + } + .eval("2")); + assert!(ColumnLevelSecurity { + name: cls_name.to_string(), + operator: ColumnLevelOperator::LessThan, + threshold: NormalizedExpr::new("1"), + } + .eval("-1")); + assert!(ColumnLevelSecurity { + name: cls_name.to_string(), + operator: ColumnLevelOperator::GreaterThanOrEquals, + threshold: NormalizedExpr::new("1"), + } + .eval("1")); + assert!(ColumnLevelSecurity { + name: cls_name.to_string(), + operator: ColumnLevelOperator::LessThanOrEquals, + threshold: NormalizedExpr::new("1"), + } + .eval("1")); + } + + #[test] + fn test_string_column_level_security() { + let cls_name = "cls_name"; + assert!(ColumnLevelSecurity { + name: cls_name.to_string(), + operator: ColumnLevelOperator::Equals, + threshold: NormalizedExpr::new("'b'"), + } + .eval("'b'")); + assert!(ColumnLevelSecurity { + name: cls_name.to_string(), + operator: ColumnLevelOperator::NotEquals, + threshold: NormalizedExpr::new("'b'"), + } + .eval("'B'")); + assert!(ColumnLevelSecurity { + name: cls_name.to_string(), + operator: ColumnLevelOperator::GreaterThan, + threshold: NormalizedExpr::new("'b'"), + } + .eval("'c'")); + assert!(ColumnLevelSecurity { + name: cls_name.to_string(), + operator: ColumnLevelOperator::LessThan, + threshold: NormalizedExpr::new("'b'"), + } + .eval("'a'")); + assert!(ColumnLevelSecurity { + name: cls_name.to_string(), + operator: ColumnLevelOperator::GreaterThanOrEquals, + threshold: NormalizedExpr::new("'b'"), + } + .eval("'b'")); + assert!(ColumnLevelSecurity { + name: cls_name.to_string(), + operator: ColumnLevelOperator::LessThanOrEquals, + threshold: NormalizedExpr::new("'b'"), + } + .eval("'b'")); + } + + #[test] + fn test_diff_type_column_level_security_eval() { + let cls_name = "cls_name"; + assert!(!ColumnLevelSecurity { + name: cls_name.to_string(), + operator: ColumnLevelOperator::Equals, + threshold: NormalizedExpr::new("1"), + } + .eval("'1'")); + assert!(ColumnLevelSecurity { + name: cls_name.to_string(), + operator: ColumnLevelOperator::NotEquals, + threshold: NormalizedExpr::new("1"), + } + .eval("'1'")); + assert!(!ColumnLevelSecurity { + name: cls_name.to_string(), + operator: ColumnLevelOperator::GreaterThan, + threshold: NormalizedExpr::new("1"), + } + .eval("'1'")); + assert!(!ColumnLevelSecurity { + name: cls_name.to_string(), + operator: ColumnLevelOperator::LessThan, + threshold: NormalizedExpr::new("1"), + } + .eval("'1'")); + assert!(!ColumnLevelSecurity { + name: cls_name.to_string(), + operator: ColumnLevelOperator::GreaterThanOrEquals, + threshold: NormalizedExpr::new("1"), + } + .eval("'1'")); + assert!(!ColumnLevelSecurity { + name: cls_name.to_string(), + operator: ColumnLevelOperator::LessThanOrEquals, + threshold: NormalizedExpr::new("1"), + } + .eval("'1'")); + } +} diff --git a/wren-core-base/src/mdl/manifest.rs b/wren-core-base/src/mdl/manifest.rs index c9d6bce5b..5118e7097 100644 --- a/wren-core-base/src/mdl/manifest.rs +++ b/wren-core-base/src/mdl/manifest.rs @@ -24,12 +24,15 @@ mod manifest_impl { use crate::mdl::manifest::bool_from_int; use crate::mdl::manifest::table_reference; use manifest_macro::{ - column, data_source, join_type, manifest, metric, model, relationship, time_grain, - time_unit, view, + column, column_level_operator, column_level_security, data_source, join_type, manifest, + metric, model, normalized_expr, normalized_expr_type, relationship, row_level_operator, + row_level_security, time_grain, time_unit, view, }; use serde::{Deserialize, Serialize}; use serde_with::serde_as; + use serde_with::DeserializeFromStr; use serde_with::NoneAsEmptyString; + use serde_with::SerializeDisplay; use std::sync::Arc; manifest!(false); data_source!(false); @@ -41,6 +44,12 @@ mod manifest_impl { join_type!(false); time_grain!(false); time_unit!(false); + row_level_security!(false); + row_level_operator!(false); + column_level_security!(false); + normalized_expr!(false); + normalized_expr_type!(false); + column_level_operator!(false); } #[cfg(feature = "python-binding")] @@ -48,13 +57,16 @@ mod manifest_impl { use crate::mdl::manifest::bool_from_int; use crate::mdl::manifest::table_reference; use manifest_macro::{ - column, data_source, join_type, manifest, metric, model, relationship, time_grain, - time_unit, view, + column, column_level_operator, column_level_security, data_source, join_type, manifest, + metric, model, normalized_expr, normalized_expr_type, relationship, row_level_operator, + row_level_security, time_grain, time_unit, view, }; use pyo3::pyclass; use serde::{Deserialize, Serialize}; use serde_with::serde_as; + use serde_with::DeserializeFromStr; use serde_with::NoneAsEmptyString; + use serde_with::SerializeDisplay; use std::sync::Arc; data_source!(true); @@ -67,6 +79,12 @@ mod manifest_impl { time_grain!(true); time_unit!(true); manifest!(true); + row_level_security!(true); + row_level_operator!(true); + column_level_security!(true); + normalized_expr!(true); + normalized_expr_type!(true); + column_level_operator!(true); } pub use crate::mdl::manifest::manifest_impl::*; diff --git a/wren-core-base/src/mdl/mod.rs b/wren-core-base/src/mdl/mod.rs index bb6c52514..2003099b7 100644 --- a/wren-core-base/src/mdl/mod.rs +++ b/wren-core-base/src/mdl/mod.rs @@ -18,6 +18,7 @@ */ pub mod builder; +pub mod cls; pub mod manifest; mod py_method; diff --git a/wren-core-base/tests/data/mdl.json b/wren-core-base/tests/data/mdl.json new file mode 100644 index 000000000..e588b4045 --- /dev/null +++ b/wren-core-base/tests/data/mdl.json @@ -0,0 +1,162 @@ +{ + "catalog": "test", + "schema": "test", + "models": [ + { + "name": "customer", + "tableReference": { + "catalog": "", + "schema": "", + "table": "customer" + }, + "columns": [ + { + "name": "c_custkey", + "type": "integer" + }, + { + "name": "c_name", + "type": "varchar" + }, + { + "name": "custkey_plus", + "type": "integer", + "expression": "c_custkey + 1", + "isCalculated": true + }, + { + "name": "orders", + "type": "orders", + "relationship": "CustomerOrders", + "properties": { + "description": "This is a customer orders relationship", + "maintainer": "test" + } + } + ], + "primaryKey": "c_custkey", + "properties": { + "description": "This is a customer table", + "maintainer": "test" + } + }, + { + "name": "profile", + "tableReference": { + "table": "profile" + }, + "columns": [ + { + "name": "p_custkey", + "type": "integer" + }, + { + "name": "p_phone", + "type": "varchar" + }, + { + "name": "p_sex", + "type": "varchar" + }, + { + "name": "customer", + "type": "customer", + "relationship": "CustomerProfile" + }, + { + "name": "totalcost", + "type": "integer", + "isCalculated": true, + "expression": "sum(customer.orders.o_totalprice)" + } + ], + "primaryKey": "p_custkey" + }, + { + "name": "orders", + "tableReference": { + "catalog": "", + "schema": "", + "table": "orders" + }, + "columns": [ + { + "name": "o_orderkey", + "type": "integer" + }, + { + "name": "o_custkey", + "type": "integer" + }, + { + "name": "o_totalprice", + "type": "integer" + }, + { + "name": "customer", + "type": "customer", + "relationship": "CustomerOrders" + }, + { + "name": "customer_name", + "type": "varchar", + "expression": "customer.c_name", + "isCalculated": true + }, + { + "name": "orderkey_plus_custkey", + "type": "integer", + "expression": "o_orderkey + o_custkey", + "isCalculated": true + }, + { + "name": "hash_orderkey", + "type": "varchar", + "expression": "md5(o_orderkey)", + "isCalculated": true + }, + { + "name": "rls_orderkey", + "type": "integer", + "expression": "o_orderkey", + "rls": { + "name": "SESSION_STATUS", + "operator": "EQUALS" + } + }, + { + "name": "cls_orderkey", + "type": "integer", + "expression": "o_orderkey", + "cls": { + "name": "SESSION_LEVEL", + "operator": "EQUALS", + "threshold": "'NORMAL'" + } + } + ], + "primaryKey": "o_orderkey" + } + ], + "relationships": [ + { + "name": "CustomerOrders", + "models": ["customer", "orders"], + "joinType": "ONE_TO_MANY", + "condition": "customer.c_custkey = orders.o_custkey" + }, + { + "name" : "CustomerProfile", + "models": ["customer", "profile"], + "joinType": "ONE_TO_ONE", + "condition": "customer.c_custkey = profile.p_custkey" + } + ], + "views": [ + { + "name": "customer_view", + "statement": "select * from test.test.customer" + } + ], + "dataSource": "mysql" +}