From afe1bb09e8b13d3e6a0d27a34255377b4d2fcf4d Mon Sep 17 00:00:00 2001 From: Artem Zhmaka Date: Wed, 21 Sep 2022 21:14:13 +0200 Subject: [PATCH 1/8] zhars/implement_searchable_tokenization Implement searchable tokenization for PostgreSQL for text/binary protocols --- decryptor/postgresql/proxy.go | 3 + encryptor/searchable_query_filter.go | 197 +++++++++++++++++++ hmac/decryptor/expr.go | 57 ------ hmac/decryptor/hashQuery.go | 155 ++------------- pseudonymization/dataTokenizer.go | 2 +- pseudonymization/tokenizeQuery.go | 191 ++++++++++++++++++ pseudonymization/tokenizeQuery_test.go | 115 +++++++++++ tests/ee_searchable_tokenization_config.yaml | 27 +++ tests/test.py | 74 ++++++- 9 files changed, 622 insertions(+), 199 deletions(-) create mode 100644 encryptor/searchable_query_filter.go delete mode 100644 hmac/decryptor/expr.go create mode 100644 pseudonymization/tokenizeQuery.go create mode 100644 pseudonymization/tokenizeQuery_test.go create mode 100644 tests/ee_searchable_tokenization_config.yaml diff --git a/decryptor/postgresql/proxy.go b/decryptor/postgresql/proxy.go index 492f84800..0589560ef 100644 --- a/decryptor/postgresql/proxy.go +++ b/decryptor/postgresql/proxy.go @@ -120,6 +120,9 @@ func (factory *proxyFactory) New(clientID []byte, clientSession base.ClientSessi return nil, err } chainEncryptors = append(chainEncryptors, tokenEncryptor) + + acraBlockStructTokenEncryptor := pseudonymization.NewPostgresqlTokenizeQuery(schemaStore, tokenEncryptor) + proxy.AddQueryObserver(acraBlockStructTokenEncryptor) } chainEncryptors = append(chainEncryptors, crypto.NewEncryptHandler(registryHandler)) diff --git a/encryptor/searchable_query_filter.go b/encryptor/searchable_query_filter.go new file mode 100644 index 000000000..147f195ff --- /dev/null +++ b/encryptor/searchable_query_filter.go @@ -0,0 +1,197 @@ +/* +Copyright 2018, Cossack Labs Limited + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance 
with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package encryptor + +import ( + "github.com/cossacklabs/acra/encryptor/config" + "github.com/cossacklabs/acra/sqlparser" + "github.com/sirupsen/logrus" +) + +// SearchableQueryFilterMode represents the working mode of SearchableQueryFilter +type SearchableQueryFilterMode int + +// QueryFilterModeSearchableEncryption list of supported modes for filtering comparisons for searchable and tokenized values +const ( + QueryFilterModeSearchableEncryption SearchableQueryFilterMode = iota + QueryFilterModeConsistentTokenization +) + +// SearchableExprItem represent the filtered value found by SearchableQueryFilter +type SearchableExprItem struct { + Expr *sqlparser.ComparisonExpr + Setting config.ColumnEncryptionSetting +} + +// SearchableQueryFilter filters equality comparisons on columns configured for searchable encryption or consistent tokenization +type SearchableQueryFilter struct { + mode SearchableQueryFilterMode + schemaStore config.TableSchemaStore +} + +// NewSearchableQueryFilter create new SearchableQueryFilter from schemaStore and SearchableQueryFilterMode +func NewSearchableQueryFilter(schemaStore config.TableSchemaStore, mode SearchableQueryFilterMode) *SearchableQueryFilter { + return &SearchableQueryFilter{ + schemaStore: schemaStore, + mode: mode, + } +} + +// FilterSearchableComparisons filter search comparisons from statement +func (filter *SearchableQueryFilter) FilterSearchableComparisons(statement sqlparser.Statement) []SearchableExprItem { + // We are interested only in SELECT statements which access at least one encryptable table. + // If that's not the case, we have nothing to do here.
+ defaultTable, aliasedTables := filter.filterInterestingTables(statement) + if len(aliasedTables) == 0 { + logrus.Debugln("No encryptable tables in search query") + return nil + } + // Now take a closer look at WHERE clauses of the statement. We need only expressions + // which are simple equality comparisons, like "WHERE column = value". + exprs := filter.filterComparisonExprs(statement) + if len(exprs) == 0 { + logrus.Debugln("No eligible comparisons in search query") + return nil + } + // And among those expressions, not all may refer to columns with searchable encryption + // enabled for them. Leave only those expressions which are searchable. + searchableExprs := filter.filterComparisons(exprs, defaultTable, aliasedTables) + if len(searchableExprs) == 0 { + logrus.Debugln("No searchable comparisons in search query") + return nil + } + return searchableExprs +} + +func (filter *SearchableQueryFilter) filterInterestingTables(statement sqlparser.Statement) (*AliasedTableName, AliasToTableMap) { + // We are interested only in SELECT statements. + selectStatement, ok := statement.(*sqlparser.Select) + if !ok { + return nil, nil + } + // Not all SELECT statements refer to tables at all. + tables := GetTablesWithAliases(selectStatement.From) + if len(tables) == 0 { + return nil, nil + } + // And even then, we can work only with tables that we have an encryption schema for. + var encryptableTables []*AliasedTableName + for _, table := range tables { + if v := filter.schemaStore.GetTableSchema(table.TableName.Name.String()); v != nil { + encryptableTables = append(encryptableTables, table) + } + } + if len(encryptableTables) == 0 { + return nil, nil + } + return tables[0], NewAliasToTableMapFromTables(encryptableTables) +} + +func (filter *SearchableQueryFilter) filterComparisonExprs(statement sqlparser.Statement) []*sqlparser.ComparisonExpr { + // Walk through WHERE clauses of a SELECT statements...
+ whereExprs, err := getWhereStatements(statement) + if err != nil { + logrus.WithError(err).Debugln("Failed to extract WHERE clauses") + return nil + } + // ...and find all eligible comparison expressions in them. + var exprs []*sqlparser.ComparisonExpr + for _, whereExpr := range whereExprs { + comparisonExprs, err := getEqualComparisonExprs(whereExpr) + if err != nil { + logrus.WithError(err).Debugln("Failed to extract comparison expressions") + return nil + } + exprs = append(exprs, comparisonExprs...) + } + return exprs +} + +func (filter *SearchableQueryFilter) filterComparisons(exprs []*sqlparser.ComparisonExpr, defaultTable *AliasedTableName, aliasedTables AliasToTableMap) []SearchableExprItem { + filtered := make([]SearchableExprItem, 0, len(exprs)) + for _, expr := range exprs { + // Leave out comparisons of columns which do not have a schema after alias resolution. + column := expr.Left.(*sqlparser.ColName) + schema := filter.getTableSchemaOfColumn(column, defaultTable, aliasedTables) + if schema == nil { + continue + } + // Also leave out those columns which are not searchable. 
+ columnName := column.Name.String() + encryptionSetting := schema.GetColumnEncryptionSettings(columnName) + + if encryptionSetting == nil { + continue + } + + isComparableSetting := encryptionSetting.IsSearchable() + if filter.mode == QueryFilterModeConsistentTokenization { + isComparableSetting = encryptionSetting.IsConsistentTokenization() + } + + if isComparableSetting { + filtered = append(filtered, SearchableExprItem{Expr: expr, Setting: encryptionSetting}) + } + } + return filtered +} + +func (filter *SearchableQueryFilter) getTableSchemaOfColumn(column *sqlparser.ColName, defaultTable *AliasedTableName, aliasedTables AliasToTableMap) config.TableSchema { + if column.Qualifier.Qualifier.IsEmpty() { + return filter.schemaStore.GetTableSchema(defaultTable.TableName.Name.String()) + } + tableName := aliasedTables[column.Qualifier.Name.String()] + return filter.schemaStore.GetTableSchema(tableName) +} + +func getWhereStatements(stmt sqlparser.Statement) ([]*sqlparser.Where, error) { + var whereStatements []*sqlparser.Where + err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + where, ok := node.(*sqlparser.Where) + if ok { + whereStatements = append(whereStatements, where) + } + return true, nil + }, stmt) + return whereStatements, err +} + +func isSupportedSQLVal(val *sqlparser.SQLVal) bool { + switch val.Type { + case sqlparser.PgEscapeString, sqlparser.HexVal, sqlparser.StrVal, sqlparser.PgPlaceholder, sqlparser.ValArg, sqlparser.IntVal: + return true + } + return false +} + +// getEqualComparisonExprs return only = or != or <=> expressions +func getEqualComparisonExprs(stmt sqlparser.SQLNode) ([]*sqlparser.ComparisonExpr, error) { + var exprs []*sqlparser.ComparisonExpr + err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + if comparisonExpr, ok := node.(*sqlparser.ComparisonExpr); ok { + if sqlVal, ok := comparisonExpr.Right.(*sqlparser.SQLVal); ok && isSupportedSQLVal(sqlVal) { + if 
comparisonExpr.Operator == sqlparser.EqualStr || comparisonExpr.Operator == sqlparser.NotEqualStr || comparisonExpr.Operator == sqlparser.NullSafeEqualStr { + if _, ok := comparisonExpr.Left.(*sqlparser.ColName); ok { + exprs = append(exprs, comparisonExpr) + } + } + } + } + return true, nil + }, stmt) + return exprs, err +} diff --git a/hmac/decryptor/expr.go b/hmac/decryptor/expr.go deleted file mode 100644 index 0ca0678d6..000000000 --- a/hmac/decryptor/expr.go +++ /dev/null @@ -1,57 +0,0 @@ -/* -Copyright 2018, Cossack Labs Limited - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package decryptor - -import "github.com/cossacklabs/acra/sqlparser" - -func getWhereStatements(stmt sqlparser.Statement) ([]*sqlparser.Where, error) { - var whereStatements []*sqlparser.Where - err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { - where, ok := node.(*sqlparser.Where) - if ok { - whereStatements = append(whereStatements, where) - } - return true, nil - }, stmt) - return whereStatements, err -} - -func isSupportedSQLVal(val *sqlparser.SQLVal) bool { - switch val.Type { - case sqlparser.PgEscapeString, sqlparser.HexVal, sqlparser.StrVal, sqlparser.PgPlaceholder, sqlparser.ValArg: - return true - } - return false -} - -// getEqualComparisonExprs return only = or != or <=> expressions -func getEqualComparisonExprs(stmt sqlparser.SQLNode) ([]*sqlparser.ComparisonExpr, error) { - var exprs []*sqlparser.ComparisonExpr - err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { - if comparisonExpr, ok := node.(*sqlparser.ComparisonExpr); ok { - if sqlVal, ok := comparisonExpr.Right.(*sqlparser.SQLVal); ok && isSupportedSQLVal(sqlVal) { - if comparisonExpr.Operator == sqlparser.EqualStr || comparisonExpr.Operator == sqlparser.NotEqualStr || comparisonExpr.Operator == sqlparser.NullSafeEqualStr { - if _, ok := comparisonExpr.Left.(*sqlparser.ColName); ok { - exprs = append(exprs, comparisonExpr) - } - } - } - } - return true, nil - }, stmt) - return exprs, err -} diff --git a/hmac/decryptor/hashQuery.go b/hmac/decryptor/hashQuery.go index c5449b329..f20983010 100644 --- a/hmac/decryptor/hashQuery.go +++ b/hmac/decryptor/hashQuery.go @@ -1,24 +1,9 @@ -/* -Copyright 2018, Cossack Labs Limited - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - package decryptor import ( "context" "fmt" + "github.com/cossacklabs/acra/decryptor/base" queryEncryptor "github.com/cossacklabs/acra/encryptor" "github.com/cossacklabs/acra/encryptor/config" @@ -37,21 +22,23 @@ type HashDecryptStore interface { // HashQuery calculate hmac for data inside AcraStruct and change WHERE conditions to support searchable encryption type HashQuery struct { - keystore HashDecryptStore - coder queryEncryptor.DBDataCoder - schemaStore config.TableSchemaStore - decryptor base.ExtendedDataProcessor - parser *sqlparser.Parser + keystore HashDecryptStore + searchableQueryFilter *queryEncryptor.SearchableQueryFilter + coder queryEncryptor.DBDataCoder + decryptor base.ExtendedDataProcessor + parser *sqlparser.Parser } // NewPostgresqlHashQuery return HashQuery with coder for postgresql -func NewPostgresqlHashQuery(keystore HashDecryptStore, schemaStore config.TableSchemaStore, decryptor base.ExtendedDataProcessor) *HashQuery { - return &HashQuery{keystore: keystore, coder: &queryEncryptor.PostgresqlDBDataCoder{}, schemaStore: schemaStore, decryptor: decryptor} +func NewPostgresqlHashQuery(keystore HashDecryptStore, schemaStore config.TableSchemaStore, processor base.ExtendedDataProcessor) *HashQuery { + searchableQueryFilter := queryEncryptor.NewSearchableQueryFilter(schemaStore, queryEncryptor.QueryFilterModeSearchableEncryption) + return &HashQuery{keystore: keystore, coder: &queryEncryptor.PostgresqlDBDataCoder{}, searchableQueryFilter: searchableQueryFilter, decryptor: processor} } // NewMysqlHashQuery return HashQuery with coder for mysql -func 
NewMysqlHashQuery(keystore HashDecryptStore, schemaStore config.TableSchemaStore, decryptor base.ExtendedDataProcessor) *HashQuery { - return &HashQuery{keystore: keystore, coder: &queryEncryptor.MysqlDBDataCoder{}, schemaStore: schemaStore, decryptor: decryptor} +func NewMysqlHashQuery(keystore HashDecryptStore, schemaStore config.TableSchemaStore, processor base.ExtendedDataProcessor) *HashQuery { + searchableQueryFilter := queryEncryptor.NewSearchableQueryFilter(schemaStore, queryEncryptor.QueryFilterModeSearchableEncryption) + return &HashQuery{keystore: keystore, coder: &queryEncryptor.MysqlDBDataCoder{}, searchableQueryFilter: searchableQueryFilter, decryptor: processor} } // ID returns name of this QueryObserver. @@ -59,108 +46,6 @@ func (encryptor *HashQuery) ID() string { return "HashQuery" } -func (encryptor *HashQuery) filterSearchableComparisons(statement sqlparser.Statement) []searchableExprItem { - // We are interested only in SELECT statements which access at least one encryptable table. - // If that's not the case, we have nothing to do here. - defaultTable, aliasedTables := encryptor.filterInterestingTables(statement) - if len(aliasedTables) == 0 { - logrus.Debugln("No encryptable tables in search query") - return nil - } - // Now take a closer look at WHERE clauses of the statement. We need only expressions - // which are simple equality comparisons, like "WHERE column = value". - exprs := encryptor.filterComparisonExprs(statement) - if len(exprs) == 0 { - logrus.Debugln("No eligible comparisons in search query") - return nil - } - // And among those expressions, not all may refer to columns with searchable encryption - // enabled for them. Leave only those expressions which are searchable. 
- searchableExprs := encryptor.filterSerchableComparisons(exprs, defaultTable, aliasedTables) - if len(exprs) == 0 { - logrus.Debugln("No searchable comparisons in search query") - return nil - } - return searchableExprs -} - -func (encryptor *HashQuery) filterInterestingTables(statement sqlparser.Statement) (*queryEncryptor.AliasedTableName, queryEncryptor.AliasToTableMap) { - // We are interested only in SELECT statements. - selectStatement, ok := statement.(*sqlparser.Select) - if !ok { - return nil, nil - } - // Not all SELECT statements refer to tables at all. - tables := queryEncryptor.GetTablesWithAliases(selectStatement.From) - if len(tables) == 0 { - return nil, nil - } - // And even then, we can work only with tables that we have an encryption schema for. - var encryptableTables []*queryEncryptor.AliasedTableName - for _, table := range tables { - if v := encryptor.schemaStore.GetTableSchema(table.TableName.Name.String()); v != nil { - encryptableTables = append(encryptableTables, table) - } - } - if len(encryptableTables) == 0 { - return nil, nil - } - return tables[0], queryEncryptor.NewAliasToTableMapFromTables(encryptableTables) -} - -func (encryptor *HashQuery) filterComparisonExprs(statement sqlparser.Statement) []*sqlparser.ComparisonExpr { - // Walk through WHERE clauses of a SELECT statements... - whereExprs, err := getWhereStatements(statement) - if err != nil { - logrus.WithError(err).Debugln("Failed to extract WHERE clauses") - return nil - } - // ...and find all eligible comparison expressions in them. - var exprs []*sqlparser.ComparisonExpr - for _, whereExpr := range whereExprs { - comparisonExprs, err := getEqualComparisonExprs(whereExpr) - if err != nil { - logrus.WithError(err).Debugln("Failed to extract comparison expressions") - return nil - } - exprs = append(exprs, comparisonExprs...) 
- } - return exprs -} - -type searchableExprItem struct { - expr *sqlparser.ComparisonExpr - setting config.ColumnEncryptionSetting -} - -func (encryptor *HashQuery) filterSerchableComparisons(exprs []*sqlparser.ComparisonExpr, defaultTable *queryEncryptor.AliasedTableName, aliasedTables queryEncryptor.AliasToTableMap) []searchableExprItem { - filtered := make([]searchableExprItem, 0, len(exprs)) - for _, expr := range exprs { - // Leave out comparisons of columns which do not have a schema after alias resolution. - column := expr.Left.(*sqlparser.ColName) - schema := encryptor.getTableSchemaOfColumn(column, defaultTable, aliasedTables) - if schema == nil { - continue - } - // Also leave out those columns which are not searchable. - columnName := column.Name.String() - encryptionSetting := schema.GetColumnEncryptionSettings(columnName) - if encryptionSetting == nil || !encryptionSetting.IsSearchable() { - continue - } - filtered = append(filtered, searchableExprItem{expr: expr, setting: encryptionSetting}) - } - return filtered -} - -func (encryptor *HashQuery) getTableSchemaOfColumn(column *sqlparser.ColName, defaultTable *queryEncryptor.AliasedTableName, aliasedTables queryEncryptor.AliasToTableMap) config.TableSchema { - if column.Qualifier.Qualifier.IsEmpty() { - return encryptor.schemaStore.GetTableSchema(defaultTable.TableName.Name.String()) - } - tableName := aliasedTables[column.Qualifier.Name.String()] - return encryptor.schemaStore.GetTableSchema(tableName) -} - // OnQuery processes query text before database sees it. // // Searchable encryption rewrites WHERE clauses with equality comparisons like this: @@ -182,7 +67,7 @@ func (encryptor *HashQuery) OnQuery(ctx context.Context, query base.OnQueryObjec // Extract the subexpressions that we are interested in for searchable encryption. // The list might be empty for non-SELECT queries or for non-eligible SELECTs. // In that case we don't have any more work to do here. 
- items := encryptor.filterSearchableComparisons(stmt) + items := encryptor.searchableQueryFilter.FilterSearchableComparisons(stmt) if len(items) == 0 { return query, false, nil } @@ -192,20 +77,20 @@ func (encryptor *HashQuery) OnQuery(ctx context.Context, query base.OnQueryObjec hashSize := []byte(fmt.Sprintf("%d", hmac.GetDefaultHashSize())) for _, item := range items { // column = 'value' ===> substring(column, 1, ) = 'value' - item.expr.Left = &sqlparser.SubstrExpr{ - Name: item.expr.Left.(*sqlparser.ColName), + item.Expr.Left = &sqlparser.SubstrExpr{ + Name: item.Expr.Left.(*sqlparser.ColName), From: sqlparser.NewIntVal([]byte{'1'}), To: sqlparser.NewIntVal(hashSize), } // substring(column, 1, ) = 'value' ===> substring(column, 1, ) = // substring(column, 1, ) = $1 ===> no changes - err := queryEncryptor.UpdateExpressionValue(ctx, item.expr.Right, encryptor.coder, encryptor.calculateHmac) + err := queryEncryptor.UpdateExpressionValue(ctx, item.Expr.Right, encryptor.coder, encryptor.calculateHmac) if err != nil { logrus.WithError(err).Debugln("Failed to update expression") return query, false, err } - sqlVal, ok := item.expr.Right.(*sqlparser.SQLVal) + sqlVal, ok := item.Expr.Right.(*sqlparser.SQLVal) if !ok { continue } @@ -215,7 +100,7 @@ func (encryptor *HashQuery) OnQuery(ctx context.Context, query base.OnQueryObjec } else if err != nil { return query, false, err } - bindSettings[placeholderIndex] = item.setting + bindSettings[placeholderIndex] = item.Setting } logrus.Debugln("HashQuery.OnQuery changed query") return base.NewOnQueryObjectFromStatement(stmt, encryptor.parser), true, nil @@ -238,7 +123,7 @@ func (encryptor *HashQuery) OnBind(ctx context.Context, statement sqlparser.Stat // Extract the subexpressions that we are interested in for searchable encryption. // The list might be empty for non-SELECT queries or for non-eligible SELECTs. // In that case we don't have any more work to do here. 
- items := encryptor.filterSearchableComparisons(statement) + items := encryptor.searchableQueryFilter.FilterSearchableComparisons(statement) if len(items) == 0 { return values, false, nil } @@ -246,7 +131,7 @@ func (encryptor *HashQuery) OnBind(ctx context.Context, statement sqlparser.Stat // and map them onto values that we need to update. indexes := make([]int, 0, len(values)) for _, item := range items { - switch value := item.expr.Right.(type) { + switch value := item.Expr.Right.(type) { case *sqlparser.SQLVal: var err error index, err := queryEncryptor.ParsePlaceholderIndex(value) diff --git a/pseudonymization/dataTokenizer.go b/pseudonymization/dataTokenizer.go index 228735332..9e6e64e59 100644 --- a/pseudonymization/dataTokenizer.go +++ b/pseudonymization/dataTokenizer.go @@ -17,9 +17,9 @@ package pseudonymization import ( - "github.com/cossacklabs/acra/encryptor/config" "strconv" + "github.com/cossacklabs/acra/encryptor/config" "github.com/cossacklabs/acra/pseudonymization/common" "github.com/sirupsen/logrus" ) diff --git a/pseudonymization/tokenizeQuery.go b/pseudonymization/tokenizeQuery.go new file mode 100644 index 000000000..131aa0af6 --- /dev/null +++ b/pseudonymization/tokenizeQuery.go @@ -0,0 +1,192 @@ +package pseudonymization + +import ( + "context" + + "github.com/cossacklabs/acra/decryptor/base" + queryEncryptor "github.com/cossacklabs/acra/encryptor" + "github.com/cossacklabs/acra/encryptor/config" + "github.com/cossacklabs/acra/sqlparser" + "github.com/sirupsen/logrus" +) + +// TokenizeQuery replace tokenized data inside AcraStruct/AcraBlocks and change WHERE conditions to support searchable tokenization +type TokenizeQuery struct { + coder queryEncryptor.DBDataCoder + tokenEncryptor *TokenEncryptor + searchableQueryFilter *queryEncryptor.SearchableQueryFilter +} + +// NewPostgresqlTokenizeQuery return TokenizeQuery with coder for postgresql +func NewPostgresqlTokenizeQuery(schemaStore config.TableSchemaStore, tokenEncryptor *TokenEncryptor) 
*TokenizeQuery { + return &TokenizeQuery{ + searchableQueryFilter: queryEncryptor.NewSearchableQueryFilter(schemaStore, queryEncryptor.QueryFilterModeConsistentTokenization), + tokenEncryptor: tokenEncryptor, + coder: &queryEncryptor.PostgresqlDBDataCoder{}, + } +} + +// NewMySQLTokenizeQuery return TokenizeQuery with coder for mysql +func NewMySQLTokenizeQuery(schemaStore config.TableSchemaStore, tokenEncryptor *TokenEncryptor) *TokenizeQuery { + return &TokenizeQuery{ + searchableQueryFilter: queryEncryptor.NewSearchableQueryFilter(schemaStore, queryEncryptor.QueryFilterModeConsistentTokenization), + tokenEncryptor: tokenEncryptor, + coder: &queryEncryptor.MysqlDBDataCoder{}, + } +} + +// ID returns name of this QueryObserver. +func (encryptor *TokenizeQuery) ID() string { + return "TokenizeQuery" +} + +// OnQuery processes query text before database sees it. +// +// Tokenized searchable encryption rewrites WHERE clauses with equality comparisons like this: +// +// WHERE column = 'value' ===> WHERE column = tokenize('value') +// +// If the query is a parameterized prepared query then OnQuery() rewriting yields this: +// +// WHERE column = $1 ===> WHERE column = tokenize($1) +// +// and actual "value" is passed via parameters later. See OnBind() for details. +func (encryptor *TokenizeQuery) OnQuery(ctx context.Context, query base.OnQueryObject) (base.OnQueryObject, bool, error) { + logrus.Debugln("TokenizeQuery.OnQuery") + stmt, err := query.Statement() + if err != nil { + logrus.WithError(err).Debugln("Can't parse SQL statement") + return query, false, err + } + + // Extract the subexpressions that we are interested in for encryption. + // The list might be empty for non-SELECT queries or for non-eligible SELECTs. + // In that case we don't have any more work to do here.
+ items := encryptor.searchableQueryFilter.FilterSearchableComparisons(stmt) + if len(items) == 0 { + return query, false, nil + } + clientSession := base.ClientSessionFromContext(ctx) + bindSettings := queryEncryptor.PlaceholderSettingsFromClientSession(clientSession) + for _, item := range items { + rightVal, ok := item.Expr.Right.(*sqlparser.SQLVal) + if !ok { + logrus.Debugln("expect SQLVal as Right expression for searchable consistent tokenization") + continue + } + + err = queryEncryptor.UpdateExpressionValue(ctx, item.Expr.Right, encryptor.coder, encryptor.tokenizerDataWithSetting(item.Setting)) + if err != nil { + logrus.WithError(err).Debugln("Failed to update expression") + return query, false, err + } + + placeholderIndex, err := queryEncryptor.ParsePlaceholderIndex(rightVal) + if err == queryEncryptor.ErrInvalidPlaceholder { + continue + } else if err != nil { + return query, false, err + } + bindSettings[placeholderIndex] = item.Setting + } + logrus.Debugln("TokenizeQuery.OnQuery changed query") + return base.NewOnQueryObjectFromStatement(stmt, nil), true, nil +} + +// OnBind processes bound values for prepared statements. +// +// Searchable tokenization rewrites WHERE clauses with equality comparisons like this: +// +// WHERE column = 'value' ===> WHERE column = tokenize('value') +// +// If the query is a parameterized prepared query then OnQuery() rewriting yields this: +// +// WHERE column = $1 ===> WHERE column = tokenize($1) +// +// and actual "value" is passed via parameters, visible here in OnBind(). +func (encryptor *TokenizeQuery) OnBind(ctx context.Context, statement sqlparser.Statement, values []base.BoundValue) ([]base.BoundValue, bool, error) { + logrus.Debugln("TokenizeQuery.OnBind") + // Extract the subexpressions that we are interested in for searchable encryption. + // The list might be empty for non-SELECT queries or for non-eligible SELECTs. + // In that case we don't have any more work to do here. 
+ items := encryptor.searchableQueryFilter.FilterSearchableComparisons(statement) + if len(items) == 0 { + return values, false, nil + } + // Now that we have expressions, analyze them to look for involved placeholders + // and map them onto values that we need to update. + indexes := make([]int, 0, len(values)) + for _, item := range items { + switch value := item.Expr.Right.(type) { + case *sqlparser.SQLVal: + var err error + index, err := queryEncryptor.ParsePlaceholderIndex(value) + if err != nil { + return values, false, err + } + if index >= len(values) { + logrus.WithFields(logrus.Fields{"placeholder": value.Val, "index": index, "values": len(values)}). + Warning("Invalid placeholder index") + return values, false, queryEncryptor.ErrInvalidPlaceholder + } + indexes = append(indexes, index) + } + } + // Finally, once we know which values to replace with tokenized values, do this replacement. + return encryptor.replaceValuesWithTokenizedData(ctx, values, indexes) +} + +func (encryptor *TokenizeQuery) replaceValuesWithTokenizedData(ctx context.Context, values []base.BoundValue, placeholders []int) ([]base.BoundValue, bool, error) { + // If there are no interesting placeholder positions then we don't have to process anything. + if len(placeholders) == 0 { + return values, false, nil + } + // Otherwise, tokenize values at positions indicated by placeholders and replace them with their tokens.
+ newValues := make([]base.BoundValue, len(values)) + copy(newValues, values) + clientSession := base.ClientSessionFromContext(ctx) + bindData := queryEncryptor.PlaceholderSettingsFromClientSession(clientSession) + + for _, valueIndex := range placeholders { + var encryptionSetting config.ColumnEncryptionSetting = nil + if bindData != nil { + setting, ok := bindData[valueIndex] + if ok { + encryptionSetting = setting + } + } + + data, err := values[valueIndex].GetData(encryptionSetting) + if err != nil { + return values, false, err + } + + tokenize := encryptor.tokenizerDataWithSetting(encryptionSetting) + + tokenized, err := tokenize(ctx, data) + if err != nil { + logrus.WithError(err).WithField("index", valueIndex).Debug("Failed to encrypt column") + return values, false, err + } + // it is ok to ignore the error if not column setting provided + _ = newValues[valueIndex].SetData(tokenized, encryptionSetting) + } + return newValues, true, nil +} + +func (encryptor *TokenizeQuery) tokenizerDataWithSetting(setting config.ColumnEncryptionSetting) func(ctx context.Context, dataToTokenize []byte) (tokenized []byte, err error) { + return func(ctx context.Context, dataToTokenize []byte) (tokenized []byte, err error) { + accessContext := base.AccessContextFromContext(ctx) + + if accessContext.IsWithZone() { + tokenized, err = encryptor.tokenEncryptor.EncryptWithZoneID(accessContext.GetZoneID(), dataToTokenize, setting) + } else { + tokenized, err = encryptor.tokenEncryptor.EncryptWithClientID(accessContext.GetClientID(), dataToTokenize, setting) + } + if err != nil { + logrus.WithError(err).Debugln("Failed to tokenize value") + return nil, err + } + return + } +} diff --git a/pseudonymization/tokenizeQuery_test.go b/pseudonymization/tokenizeQuery_test.go new file mode 100644 index 000000000..48e80b9c3 --- /dev/null +++ b/pseudonymization/tokenizeQuery_test.go @@ -0,0 +1,115 @@ +package pseudonymization + +import ( + "context" + "fmt" + "testing" + + 
"github.com/cossacklabs/acra/decryptor/base" + "github.com/cossacklabs/acra/decryptor/base/mocks" + encryptor2 "github.com/cossacklabs/acra/encryptor" + "github.com/cossacklabs/acra/encryptor/config" + "github.com/cossacklabs/acra/pseudonymization/common" + "github.com/cossacklabs/acra/pseudonymization/storage" + "github.com/cossacklabs/acra/sqlparser" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// TestSearchableTokenizationWithTextFormat process searchable SELECT query with placeholder for prepared statement +// and use binding values in text format +func TestSearchableTokenizationWithTextFormat(t *testing.T) { + schemaConfigTemplate := ` +schemas: + - table: test_table + columns: + - data1 + encrypted: + - column: data1 + token_type: %s + consistent_tokenization: true +` + tokenStorage, err := storage.NewMemoryTokenStorage() + assert.NoError(t, err) + + anonymizer, err := NewPseudoanonymizer(tokenStorage) + assert.NoError(t, err) + + tokenizer, err := NewDataTokenizer(anonymizer) + assert.NoError(t, err) + + tokenEncryptor, err := NewTokenEncryptor(tokenizer) + assert.NoError(t, err) + + clientSession := &mocks.ClientSession{} + sessionData := make(map[string]interface{}, 2) + clientSession.On("GetData", mock.Anything).Return(func(key string) interface{} { + return sessionData[key] + }, func(key string) bool { + _, ok := sessionData[key] + return ok + }) + clientSession.On("DeleteData", mock.Anything).Run(func(args mock.Arguments) { + delete(sessionData, args[0].(string)) + }) + clientSession.On("SetData", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + sessionData[args[0].(string)] = args[1] + }) + + ctx := base.SetClientSessionToContext(context.Background(), clientSession) + + accessContext := base.NewAccessContext(base.WithClientID([]byte("client-id"))) + ctx = base.SetAccessContextToContext(ctx, accessContext) + + parser := sqlparser.New(sqlparser.ModeDefault) + + randomBytes := make([]byte, 10) + 
randomRead(randomBytes) + + type testcase struct { + Value []byte + Type common.TokenType + TokenType string + Query string + } + testcases := []testcase{ + {Value: []byte("somedata"), Type: common.TokenType_String, TokenType: "str", Query: "select data1 from test_table where data1='somedata'"}, + {Value: []byte("333"), Type: common.TokenType_Int32, TokenType: "int32", Query: "select data1 from test_table where data1=333"}, + {Value: []byte("33333333333333333"), Type: common.TokenType_Int64, TokenType: "int64", Query: "select data1 from test_table where data1=33333333333333333"}, + {Value: []byte("test@gmail.com"), Type: common.TokenType_Email, TokenType: "email", Query: "select data1 from test_table where data1='test@gmail.com'"}, + {Value: randomBytes, Type: common.TokenType_Bytes, TokenType: "bytes", Query: fmt.Sprintf("select data1 from test_table where data1='%s'", encryptor2.PgEncodeToHexString(randomBytes))}, + } + + for _, tcase := range testcases { + schema, err := config.MapTableSchemaStoreFromConfig([]byte(fmt.Sprintf(schemaConfigTemplate, tcase.TokenType))) + assert.NoError(t, err) + + encryptor := NewPostgresqlTokenizeQuery(schema, tokenEncryptor) + + setting := config.BasicColumnEncryptionSetting{ + TokenType: tcase.TokenType, + ConsistentTokenization: true, + } + anonimized, err := tokenizer.Tokenize(tcase.Value, common.TokenContext{ClientID: []byte("client-id")}, &setting) + assert.NoError(t, err) + + newQuery, ok, err := encryptor.OnQuery(ctx, base.NewOnQueryObjectFromQuery(tcase.Query, parser)) + assert.NoError(t, err) + assert.True(t, ok) + + stmt, err := newQuery.Statement() + assert.NoError(t, err) + + selectQuery := stmt.(*sqlparser.Select) + + whereExpr := selectQuery.Where.Expr.(*sqlparser.ComparisonExpr) + rightExpr := whereExpr.Right.(*sqlparser.SQLVal) + + if tcase.Type == common.TokenType_Bytes { + assert.Equal(t, rightExpr.Val, encryptor2.PgEncodeToHexString(anonimized)) + continue + } + + assert.Equal(t, rightExpr.Val, anonimized) + } 
+} diff --git a/tests/ee_searchable_tokenization_config.yaml b/tests/ee_searchable_tokenization_config.yaml new file mode 100644 index 000000000..7ec9eb549 --- /dev/null +++ b/tests/ee_searchable_tokenization_config.yaml @@ -0,0 +1,27 @@ +schemas: + - table: test_tokenization_default_client_id + columns: + - id + - nullable + - empty + - token_i32 + - token_i64 + - token_str + - token_bytes + - token_email + encrypted: + - column: token_i32 + token_type: int32 + consistent_tokenization: true + - column: token_i64 + token_type: int64 + consistent_tokenization: true + - column: token_str + token_type: str + consistent_tokenization: true + - column: token_bytes + token_type: bytes + consistent_tokenization: true + - column: token_email + token_type: email + consistent_tokenization: true \ No newline at end of file diff --git a/tests/test.py b/tests/test.py index 83ffb61e9..4b55cb534 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1987,6 +1987,11 @@ def compileQuery(self, query, parameters={}, literal_binds=False): # SQLAlchemy default dialect has placeholders of form ":name". # PostgreSQL syntax is "$n", with 1-based sequential parameters. saPlaceholder = ':' + placeholder + # SQLAlchemy has placeholders of form ":name_1" for literal value + # https://docs.sqlalchemy.org/en/14/core/tutorial.html#operators + saPlaceholderIndex = saPlaceholder + '_' + str(len(values) + 1) + if saPlaceholderIndex in query: + saPlaceholder = saPlaceholderIndex pgPlaceholder = '$' + str(len(values) + 1) # Replace and keep values only for those placeholders which # are actually used in the query. 
@@ -7218,7 +7223,7 @@ def insert_via_1_bulk(self, query, values): """Execute SQLAlchemy Bulk INSERT query via AcraServer with "TEST_TLS_CLIENT_CERT".""" self.engine1.execute(query.values(values)) - def fetch_from_1(self, query): + def fetch_from_1(self, query, parameters={}, literal_binds=True): """Execute SQLAlchemy SELECT query via AcraServer with "TEST_TLS_CLIENT_CERT".""" return self.engine1.execute(query).fetchall() @@ -7270,8 +7275,8 @@ def insert_via_1_bulk(self, query, values): query, parameters = self.compileBulkInsertQuery(query.values(values), values) return self.executor1.execute_prepared_statement_no_result(query, parameters) - def fetch_from_1(self, query): - query, parameters = self.compileQuery(query, literal_binds=True) + def fetch_from_1(self, query, parameters={}, literal_binds=True): + query, parameters = self.compileQuery(query, parameters=parameters, literal_binds=literal_binds) return self.executor1.execute_prepared_statement(query, parameters) def fetch_from_2(self, query): @@ -7297,8 +7302,8 @@ def insert_via_1_bulk(self, query, values): query, parameters = self.compileBulkInsertQuery(query.values(values), values) return self.executor1.execute_prepared_statement(query, parameters) - def fetch_from_1(self, query): - query, parameters = self.compileQuery(query, literal_binds=True) + def fetch_from_1(self, query, parameters={}, literal_binds=True): + query, parameters = self.compileQuery(query, parameters=parameters, literal_binds=literal_binds) return self.executor1.execute_prepared_statement(query, parameters) def fetch_from_2(self, query): @@ -7362,6 +7367,59 @@ def execute(self, query, ssl_key, ssl_cert): return result +class TestSearchableTokenizationWithoutZone(BaseTokenization): + ZONE = False + ENCRYPTOR_CONFIG = get_encryptor_config('tests/ee_searchable_tokenization_config.yaml') + + def testSearchableTokenizationDefaultClientID(self): + default_client_id_table = sa.Table( + 'test_tokenization_default_client_id', metadata, + 
sa.Column('id', sa.Integer, primary_key=True), + sa.Column('nullable_column', sa.Text, nullable=True), + sa.Column('empty', sa.LargeBinary(length=COLUMN_DATA_SIZE), nullable=False, default=b''), + sa.Column('token_i32', sa.Integer()), + sa.Column('token_i64', sa.BigInteger()), + sa.Column('token_str', sa.Text), + sa.Column('token_bytes', sa.LargeBinary(length=COLUMN_DATA_SIZE), nullable=False, default=b''), + sa.Column('token_email', sa.Text), + extend_existing=True, + ) + metadata.create_all(self.engine_raw, [default_client_id_table]) + self.engine1.execute(default_client_id_table.delete()) + data = { + 'id': 1, + 'nullable_column': None, + 'empty': b'', + 'token_i32': random_int32(), + 'token_i64': random_int64(), + 'token_str': random_str(), + 'token_bytes': random_bytes(), + 'token_email': random_email(), + } + + # insert data + self.insert_via_1(default_client_id_table.insert(), data) + + columns = { + 'token_i32': default_client_id_table.c.token_i32, + 'token_i64': default_client_id_table.c.token_i64, + 'token_str': default_client_id_table.c.token_str, + 'token_bytes': default_client_id_table.c.token_bytes, + 'token_email': default_client_id_table.c.token_email, + } + # data owner take source data + for key in columns: + parameters = {key: data[key]} + query = sa.select(default_client_id_table).where(columns[key] == data[key]) + + source_data = self.fetch_from_1(query, parameters, literal_binds=False) + for k in ('token_i32', 'token_i64', 'token_str', 'token_bytes', 'token_email'): + if isinstance(source_data[0][k], (bytearray, bytes)) and isinstance(data[k], str): + self.assertEqual(source_data[0][k], data[k].encode('utf-8')) + else: + self.assertEqual(source_data[0][k], data[k]) + + class TestTokenizationWithoutZone(BaseTokenization): ZONE = False @@ -7889,6 +7947,10 @@ class TestTokenizationWithoutZoneBinaryPostgreSQL(BaseTokenizationWithBinaryPost pass +class TestSearchableTokenizationWithoutZoneBinaryPostgreSQL(BaseTokenizationWithBinaryPostgreSQL,
TestSearchableTokenizationWithoutZone): + pass + + class TestTokenizationWithoutZoneBinaryPostgreSQLWithAWSKMSMaterKeyLoading(AWSKMSMasterKeyLoaderMixin, BaseTokenizationWithBinaryPostgreSQL, TestTokenizationWithoutZone): pass @@ -10369,7 +10431,7 @@ def executor_with_ssl(ssl_key, ssl_cert, port=self.ACRASERVER_PORT): def testPreparedStatementIsNotAborted(self): """ - Test that connection is not closed in case of "encoding error" when we + Test that connection is not closed in case of "encoding error" when we use prepared statements. """ async def test(): From 7a918300f9a0ccff2d46241f23ce8e826959f693 Mon Sep 17 00:00:00 2001 From: Artem Zhmaka Date: Wed, 21 Sep 2022 21:25:16 +0200 Subject: [PATCH 2/8] zhars/implement_searchable_tokenization Updated CHANGELOG_DEV.md file --- CHANGELOG_DEV.md | 3 +++ hmac/decryptor/hashQuery.go | 13 +++++++++++++ 2 files changed, 16 insertions(+) diff --git a/CHANGELOG_DEV.md b/CHANGELOG_DEV.md index 1d7be6813..ac7eb3f53 100644 --- a/CHANGELOG_DEV.md +++ b/CHANGELOG_DEV.md @@ -1,3 +1,6 @@ +# 0.94.0 - 2022-09-21 +- Implemented searchable tokenization for PostgreSQL for text/binary protocols + # 0.94.0 - 2022-08-25 - Add support of Hashicorp Consul for `encryptor_config loading`. - Introduce new Hashicorp Consul flags: `consul_connection_api_string` and `consul_kv_config_path` and corresponded `consul` TLS configuration flags. diff --git a/hmac/decryptor/hashQuery.go b/hmac/decryptor/hashQuery.go index f20983010..bf7ed5df2 100644 --- a/hmac/decryptor/hashQuery.go +++ b/hmac/decryptor/hashQuery.go @@ -1,3 +1,16 @@ +/* +Copyright 2018, Cossack Labs Limited +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package decryptor import ( From 9fa050c30f640417a0c6b5eca2a3f69c91e8e968 Mon Sep 17 00:00:00 2001 From: Artem Zhmaka Date: Wed, 21 Sep 2022 21:27:03 +0200 Subject: [PATCH 3/8] zhars/implement_searchable_tokenization Fixed licence comment --- hmac/decryptor/hashQuery.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/hmac/decryptor/hashQuery.go b/hmac/decryptor/hashQuery.go index bf7ed5df2..b7349995e 100644 --- a/hmac/decryptor/hashQuery.go +++ b/hmac/decryptor/hashQuery.go @@ -1,9 +1,12 @@ /* Copyright 2018, Cossack Labs Limited + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
From 4277bf9e6d03868d266fddd8105fcf3f667fb2ef Mon Sep 17 00:00:00 2001 From: Artem Zhmaka Date: Thu, 22 Sep 2022 22:43:39 +0200 Subject: [PATCH 4/8] zhars/implement_searchable_tokenization Added MySQL support --- decryptor/mysql/proxy.go | 3 +++ pseudonymization/tokenizeQuery.go | 1 + tests/test.py | 7 +++++++ 3 files changed, 11 insertions(+) diff --git a/decryptor/mysql/proxy.go b/decryptor/mysql/proxy.go index a120e1735..bcfee394c 100644 --- a/decryptor/mysql/proxy.go +++ b/decryptor/mysql/proxy.go @@ -111,6 +111,9 @@ func (factory *proxyFactory) New(clientID []byte, clientSession base.ClientSessi return nil, err } chainEncryptors = append(chainEncryptors, tokenEncryptor) + + acraBlockStructTokenEncryptor := pseudonymization.NewMySQLTokenizeQuery(schemaStore, tokenEncryptor) + proxy.AddQueryObserver(acraBlockStructTokenEncryptor) } chainEncryptors = append(chainEncryptors, crypto.NewEncryptHandler(registryHandler)) diff --git a/pseudonymization/tokenizeQuery.go b/pseudonymization/tokenizeQuery.go index 131aa0af6..f97023db9 100644 --- a/pseudonymization/tokenizeQuery.go +++ b/pseudonymization/tokenizeQuery.go @@ -31,6 +31,7 @@ func NewMySQLTokenizeQuery(schemaStore config.TableSchemaStore, tokenEncryptor * return &TokenizeQuery{ searchableQueryFilter: queryEncryptor.NewSearchableQueryFilter(schemaStore, queryEncryptor.QueryFilterModeConsistentTokenization), tokenEncryptor: tokenEncryptor, + coder: &queryEncryptor.MysqlDBDataCoder{}, } } diff --git a/tests/test.py b/tests/test.py index 0b4934590..976a4c56c 100644 --- a/tests/test.py +++ b/tests/test.py @@ -2126,6 +2126,9 @@ def compileQuery(self, query, parameters={}, literal_binds=False): for placeholder in res: # parameters map contain values where keys without ':' so we need trim the placeholder before key = placeholder.lstrip(':') + index_suffix = '_' + str(len(values) + 1) + if index_suffix in key: + key = key.rstrip(index_suffix) values.append(parameters[key]) query = query.replace(placeholder, '?') return 
query, tuple(values) @@ -7965,6 +7968,10 @@ class TestTokenizationWithZoneBinaryBindMySQL(BaseTokenizationWithBinaryBindMySQ pass +class TestSearchableTokenizationWithoutZoneBinaryBindMySQL(BaseTokenizationWithBinaryBindMySQL, TestSearchableTokenizationWithoutZone): + pass + + class BaseMasking(BaseTokenization): WHOLECELL_MODE = False ENCRYPTOR_CONFIG = get_encryptor_config('tests/ee_masking_config.yaml') From 8230fee4b07015fd1449fc436760e2fd786ee008 Mon Sep 17 00:00:00 2001 From: Artem Zhmaka Date: Tue, 27 Sep 2022 17:30:15 +0200 Subject: [PATCH 5/8] zhars/implement_searchable_tokenization Extend SQL syntax for searchable tokenization --- CHANGELOG_DEV.md | 2 +- encryptor/searchable_query_filter.go | 53 +++++++++++++---- pseudonymization/tokenizeQuery.go | 13 +++-- pseudonymization/tokenizeQuery_test.go | 46 ++++++++++++--- tests/test.py | 79 ++++++++++++++++++++++++-- 5 files changed, 164 insertions(+), 29 deletions(-) diff --git a/CHANGELOG_DEV.md b/CHANGELOG_DEV.md index bfa164905..ec833e6a9 100644 --- a/CHANGELOG_DEV.md +++ b/CHANGELOG_DEV.md @@ -1,5 +1,5 @@ # 0.94.0 - 2022-09-21 -- Implemented searchable tokenization for PostgreSQL for text/binary protocols +- Implemented searchable tokenization for PostgreSQL/MySQL for text/binary protocols # 0.94.0 - 2022-09-19 - Accumulate packets in a queue and handle paired packets in the correct order. Fixes issue with incorrectly linked Bind packet to inappropriate Parse packet and nil dereferences. diff --git a/encryptor/searchable_query_filter.go b/encryptor/searchable_query_filter.go index 147f195ff..acc7fe8c5 100644 --- a/encryptor/searchable_query_filter.go +++ b/encryptor/searchable_query_filter.go @@ -17,11 +17,15 @@ limitations under the License. 
package encryptor import ( + "errors" + "github.com/cossacklabs/acra/encryptor/config" "github.com/cossacklabs/acra/sqlparser" "github.com/sirupsen/logrus" ) +var ErrUnsupportedQueryType = errors.New("unsupported Query type") + // SearchableQueryFilterMode represent the mode work of SearchableQueryFilter type SearchableQueryFilterMode int @@ -37,7 +41,7 @@ type SearchableExprItem struct { Setting config.ColumnEncryptionSetting } -// SearchableQueryFilter filter se +// SearchableQueryFilter filter searchable expression based on SearchableQueryFilterMode type SearchableQueryFilter struct { mode SearchableQueryFilterMode schemaStore config.TableSchemaStore @@ -53,9 +57,13 @@ func NewSearchableQueryFilter(schemaStore config.TableSchemaStore, mode Searchab // FilterSearchableComparisons filter search comparisons from statement func (filter *SearchableQueryFilter) FilterSearchableComparisons(statement sqlparser.Statement) []SearchableExprItem { - // We are interested only in SELECT statements which access at least one encryptable table. - // If that's not the case, we have nothing to do here. - defaultTable, aliasedTables := filter.filterInterestingTables(statement) + tableExps, err := filter.filterTableExpressions(statement) + if err != nil { + logrus.Debugln("Unsupported search query") + return nil + } + + defaultTable, aliasedTables := filter.filterInterestingTables(tableExps) if len(aliasedTables) == 0 { logrus.Debugln("No encryptable tables in search query") return nil @@ -77,14 +85,9 @@ func (filter *SearchableQueryFilter) FilterSearchableComparisons(statement sqlpa return searchableExprs } -func (filter *SearchableQueryFilter) filterInterestingTables(statement sqlparser.Statement) (*AliasedTableName, AliasToTableMap) { - // We are interested only in SELECT statements. 
- selectStatement, ok := statement.(*sqlparser.Select) - if !ok { - return nil, nil - } +func (filter *SearchableQueryFilter) filterInterestingTables(fromExp sqlparser.TableExprs) (*AliasedTableName, AliasToTableMap) { // Not all SELECT statements refer to tables at all. - tables := GetTablesWithAliases(selectStatement.From) + tables := GetTablesWithAliases(fromExp) if len(tables) == 0 { return nil, nil } @@ -101,6 +104,34 @@ func (filter *SearchableQueryFilter) filterInterestingTables(statement sqlparser return tables[0], NewAliasToTableMapFromTables(encryptableTables) } +func (filter *SearchableQueryFilter) filterTableExpressions(statement sqlparser.Statement) (sqlparser.TableExprs, error) { + if filter.mode == QueryFilterModeConsistentTokenization { + switch query := statement.(type) { + case *sqlparser.Select: + return query.From, nil + case *sqlparser.Update: + return query.TableExprs, nil + case *sqlparser.Delete: + return query.TableExprs, nil + case *sqlparser.Insert: + // only support INSERT INTO table2 SELECT * FROM test_table WHERE data1='somedata' syntax for INSERTs + if selectInInsert, ok := query.Rows.(*sqlparser.Select); ok { + return selectInInsert.From, nil + } + return nil, ErrUnsupportedQueryType + default: + return nil, ErrUnsupportedQueryType + } + } + + // TODO: extend with more query types for support for searchable encryption + selectStatement, ok := statement.(*sqlparser.Select) + if ok { + return selectStatement.From, nil + } + return nil, ErrUnsupportedQueryType +} + func (filter *SearchableQueryFilter) filterComparisonExprs(statement sqlparser.Statement) []*sqlparser.ComparisonExpr { // Walk through WHERE clauses of a SELECT statements... 
whereExprs, err := getWhereStatements(statement) diff --git a/pseudonymization/tokenizeQuery.go b/pseudonymization/tokenizeQuery.go index f97023db9..14360f201 100644 --- a/pseudonymization/tokenizeQuery.go +++ b/pseudonymization/tokenizeQuery.go @@ -75,7 +75,7 @@ func (encryptor *TokenizeQuery) OnQuery(ctx context.Context, query base.OnQueryO continue } - err = queryEncryptor.UpdateExpressionValue(ctx, item.Expr.Right, encryptor.coder, encryptor.tokenizerDataWithSetting(item.Setting)) + err = queryEncryptor.UpdateExpressionValue(ctx, item.Expr.Right, encryptor.coder, encryptor.getTokenizerDataWithSetting(item.Setting)) if err != nil { logrus.WithError(err).Debugln("Failed to update expression") return query, false, err @@ -156,14 +156,17 @@ func (encryptor *TokenizeQuery) replaceValuesWithTokenizedData(ctx context.Conte } } + if encryptionSetting == nil { + continue + } + data, err := values[valueIndex].GetData(encryptionSetting) if err != nil { return values, false, err } - tokenize := encryptor.tokenizerDataWithSetting(encryptionSetting) - - tokenized, err := tokenize(ctx, data) + tokenizeFunc := encryptor.getTokenizerDataWithSetting(encryptionSetting) + tokenized, err := tokenizeFunc(ctx, data) if err != nil { logrus.WithError(err).WithField("index", valueIndex).Debug("Failed to encrypt column") return values, false, err @@ -174,7 +177,7 @@ func (encryptor *TokenizeQuery) replaceValuesWithTokenizedData(ctx context.Conte return newValues, true, nil } -func (encryptor *TokenizeQuery) tokenizerDataWithSetting(setting config.ColumnEncryptionSetting) func(ctx context.Context, dataToTokenize []byte) (tokenized []byte, err error) { +func (encryptor *TokenizeQuery) getTokenizerDataWithSetting(setting config.ColumnEncryptionSetting) func(ctx context.Context, dataToTokenize []byte) (tokenized []byte, err error) { return func(ctx context.Context, dataToTokenize []byte) (tokenized []byte, err error) { accessContext := base.AccessContextFromContext(ctx) diff --git 
a/pseudonymization/tokenizeQuery_test.go b/pseudonymization/tokenizeQuery_test.go index 48e80b9c3..9dc799acd 100644 --- a/pseudonymization/tokenizeQuery_test.go +++ b/pseudonymization/tokenizeQuery_test.go @@ -24,6 +24,7 @@ schemas: - table: test_table columns: - data1 + - data2 encrypted: - column: data1 token_type: %s @@ -56,9 +57,10 @@ schemas: sessionData[args[0].(string)] = args[1] }) + clientID := []byte("client-id") ctx := base.SetClientSessionToContext(context.Background(), clientSession) - accessContext := base.NewAccessContext(base.WithClientID([]byte("client-id"))) + accessContext := base.NewAccessContext(base.WithClientID(clientID)) ctx = base.SetAccessContextToContext(ctx, accessContext) parser := sqlparser.New(sqlparser.ModeDefault) @@ -73,7 +75,14 @@ schemas: Query string } testcases := []testcase{ + {Value: []byte("somedata"), Type: common.TokenType_String, TokenType: "str", Query: "INSERT INTO table2 SELECT * FROM test_table WHERE data1='somedata';"}, + {Value: []byte("test@gmail.com"), Type: common.TokenType_Email, TokenType: "email", Query: "INSERT INTO table2 SELECT * FROM test_table WHERE data1='test@gmail.com' and data2='ignoreddata';"}, + {Value: []byte("somedata"), Type: common.TokenType_String, TokenType: "str", Query: "UPDATE test_table SET kind = 'Dramatic' WHERE data1='somedata';"}, + {Value: []byte("4444"), Type: common.TokenType_Int32, TokenType: "int32", Query: "UPDATE test_table SET kind = 'Dramatic' WHERE data1=4444 and data2='ignoreddata';"}, + {Value: []byte("somedata"), Type: common.TokenType_String, TokenType: "str", Query: "DELETE FROM test_table WHERE data1='somedata';"}, + {Value: randomBytes, Type: common.TokenType_Bytes, TokenType: "bytes", Query: fmt.Sprintf("DELETE FROM test_table where data1='%s' or data2='ignoreddata'", encryptor2.PgEncodeToHexString(randomBytes))}, {Value: []byte("somedata"), Type: common.TokenType_String, TokenType: "str", Query: "select data1 from test_table where data1='somedata'"}, + {Value: 
[]byte("somedata"), Type: common.TokenType_String, TokenType: "str", Query: "select data1 from test_table where data1='somedata' and data2='ignoreddata'"}, {Value: []byte("333"), Type: common.TokenType_Int32, TokenType: "int32", Query: "select data1 from test_table where data1=333"}, {Value: []byte("33333333333333333"), Type: common.TokenType_Int64, TokenType: "int64", Query: "select data1 from test_table where data1=33333333333333333"}, {Value: []byte("test@gmail.com"), Type: common.TokenType_Email, TokenType: "email", Query: "select data1 from test_table where data1='test@gmail.com'"}, @@ -90,7 +99,7 @@ schemas: TokenType: tcase.TokenType, ConsistentTokenization: true, } - anonimized, err := tokenizer.Tokenize(tcase.Value, common.TokenContext{ClientID: []byte("client-id")}, &setting) + anonymized, err := tokenizer.Tokenize(tcase.Value, common.TokenContext{ClientID: clientID}, &setting) assert.NoError(t, err) newQuery, ok, err := encryptor.OnQuery(ctx, base.NewOnQueryObjectFromQuery(tcase.Query, parser)) @@ -100,16 +109,39 @@ schemas: stmt, err := newQuery.Statement() assert.NoError(t, err) - selectQuery := stmt.(*sqlparser.Select) + var whereExp *sqlparser.Where + switch query := stmt.(type) { + case *sqlparser.Select: + whereExp = query.Where + case *sqlparser.Update: + whereExp = query.Where + case *sqlparser.Delete: + whereExp = query.Where + case *sqlparser.Insert: + selectInInsert, ok := query.Rows.(*sqlparser.Select) + if !ok { + panic("expect INSERT FROM SELECT queries") + } + whereExp = selectInInsert.Where + } - whereExpr := selectQuery.Where.Expr.(*sqlparser.ComparisonExpr) - rightExpr := whereExpr.Right.(*sqlparser.SQLVal) + var rightExpr *sqlparser.SQLVal + switch expr := whereExp.Expr.(type) { + case *sqlparser.ComparisonExpr: + rightExpr = expr.Right.(*sqlparser.SQLVal) + case *sqlparser.AndExpr: + rightExpr = expr.Left.(*sqlparser.ComparisonExpr).Right.(*sqlparser.SQLVal) + assert.Equal(t, 
expr.Right.(*sqlparser.ComparisonExpr).Right.(*sqlparser.SQLVal).Val, []byte("ignoreddata")) + case *sqlparser.OrExpr: + rightExpr = expr.Left.(*sqlparser.ComparisonExpr).Right.(*sqlparser.SQLVal) + assert.Equal(t, expr.Right.(*sqlparser.ComparisonExpr).Right.(*sqlparser.SQLVal).Val, []byte("ignoreddata")) + } if tcase.Type == common.TokenType_Bytes { - assert.Equal(t, rightExpr.Val, encryptor2.PgEncodeToHexString(anonimized)) + assert.Equal(t, rightExpr.Val, encryptor2.PgEncodeToHexString(anonymized)) continue } - assert.Equal(t, rightExpr.Val, anonimized) + assert.Equal(t, rightExpr.Val, anonymized) } } diff --git a/tests/test.py b/tests/test.py index 976a4c56c..3661e0313 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1982,15 +1982,17 @@ def compileQuery(self, query, parameters={}, literal_binds=False): compile_kwargs = {"literal_binds": literal_binds} query = str(query.compile(compile_kwargs=compile_kwargs)) values = [] + param_counter = 1 for placeholder, value in parameters.items(): # SQLAlchemy default dialect has placeholders of form ":name". # PostgreSQL syntax is "$n", with 1-based sequential parameters. saPlaceholder = ':' + placeholder # SQLAlchemy has placeholders of form ":name_1" for literal value # https://docs.sqlalchemy.org/en/14/core/tutorial.html#operators - saPlaceholderIndex = saPlaceholder + '_' + str(len(values) + 1) + saPlaceholderIndex = saPlaceholder + '_' + str(param_counter) if saPlaceholderIndex in query: saPlaceholder = saPlaceholderIndex + param_counter += 1 pgPlaceholder = '$' + str(len(values) + 1) # Replace and keep values only for those placeholders which # are actually used in the query. 
@@ -2122,13 +2124,16 @@ def compileQuery(self, query, parameters={}, literal_binds=False): # parse all parameters like `:id` in the query pattern_string = r'(:\w+)' res = re.findall(pattern_string, query, re.IGNORECASE | re.DOTALL) + param_counter = 1 if len(res) > 0: for placeholder in res: # parameters map contain values where keys without ':' so we need trim the placeholder before key = placeholder.lstrip(':') - index_suffix = '_' + str(len(values) + 1) - if index_suffix in key: - key = key.rstrip(index_suffix) + if key not in parameters.keys(): + index_suffix = '_' + str(param_counter) + if index_suffix in key: + key = key.rstrip(index_suffix) + param_counter += 1 values.append(parameters[key]) query = query.replace(placeholder, '?') return query, tuple(values) @@ -7221,6 +7226,10 @@ def insert_via_1(self, query, values): """Execute SQLAlchemy INSERT query via AcraServer with "TEST_TLS_CLIENT_CERT".""" return self.engine1.execute(query, values) + def execute_via_1(self, query, values): + """Execute SQLAlchemy execute query via AcraServer with "TEST_TLS_CLIENT_CERT".""" + return self.engine1.execute(query, values) + def insert_via_1_bulk(self, query, values): """Execute SQLAlchemy Bulk INSERT query via AcraServer with "TEST_TLS_CLIENT_CERT".""" self.engine1.execute(query.values(values)) @@ -7277,6 +7286,10 @@ def insert_via_1_bulk(self, query, values): query, parameters = self.compileBulkInsertQuery(query.values(values), values) return self.executor1.execute_prepared_statement_no_result(query, parameters) + def execute_via_1(self, query, values): + query, parameters = self.compileQuery(query, values) + self.executor1.execute_prepared_statement_no_result(query, parameters) + def fetch_from_1(self, query, parameters={}, literal_binds=True): query, parameters = self.compileQuery(query, parameters=parameters, literal_binds=literal_binds) return self.executor1.execute_prepared_statement(query, parameters) @@ -7304,6 +7317,10 @@ def insert_via_1_bulk(self, query, 
values): query, parameters = self.compileBulkInsertQuery(query.values(values), values) return self.executor1.execute_prepared_statement(query, parameters) + def execute_via_1(self, query, values): + query, parameters = self.compileQuery(query, values) + self.executor1.execute_prepared_statement(query, parameters) + def fetch_from_1(self, query, parameters={}, literal_binds=True): query, parameters = self.compileQuery(query, parameters=parameters, literal_binds=literal_binds) return self.executor1.execute_prepared_statement(query, parameters) @@ -7388,8 +7405,10 @@ def testSearchableTokenizationDefaultClientID(self): ) metadata.create_all(self.engine_raw, [default_client_id_table]) self.engine1.execute(default_client_id_table.delete()) + + row_id = 1 data = { - 'id': 1, + 'id': row_id, 'nullable_column': None, 'empty': b'', 'token_i32': random_int32(), @@ -7403,6 +7422,7 @@ def testSearchableTokenizationDefaultClientID(self): self.insert_via_1(default_client_id_table.insert(), data) columns = { + 'id': default_client_id_table.c.id, 'token_i32': default_client_id_table.c.token_i32, 'token_i64': default_client_id_table.c.token_i64, 'token_str': default_client_id_table.c.token_str, @@ -7421,6 +7441,55 @@ def testSearchableTokenizationDefaultClientID(self): else: self.assertEqual(source_data[0][k], data[k]) + new_token_str = random_str() + update_data = { + 'token_str': new_token_str, + 'token_i32': data['token_i32'] + } + + # test searchable tokenization in update where statements + query = sa.update(default_client_id_table).where(columns['token_i32'] == data['token_i32']).values(token_str=new_token_str) + self.execute_via_1(query, update_data) + + parameters = {key: data[key]} + query = sa.select(default_client_id_table).where(columns[key] == data[key]) + source_data = self.fetch_from_1(query, parameters, literal_binds=False) + + if isinstance(source_data[0][k], (bytearray, bytes)) and isinstance(data[k], str): + self.assertEqual(source_data[0]['token_str'], 
new_token_str.encode('utf-8')) + else: + self.assertEqual(source_data[0]['token_str'], new_token_str) + + row_id += 1 + insert_data = { + 'param_1': row_id, + 'token_i32': data['token_i32'] + } + select_columns = ['id', 'nullable_column', 'empty', 'token_i32', 'token_i64', 'token_str', 'token_bytes', 'token_email'] + select_query = sa.select(sa.literal(row_id).label('id'), sa.column('nullable_column'), sa.column('empty'), columns['token_i32'], columns['token_i64'], columns['token_str'], columns['token_bytes'], columns['token_email']).\ + where(columns['token_i32'] == data['token_i32']) + + query = sa.insert(default_client_id_table).from_select(select_columns, select_query) + self.execute_via_1(query, insert_data) + + # expect that data was encrypted with client_id which used to insert (client_id==keypair1) + source_data = self.fetch_from_1( + sa.select([default_client_id_table]) + .where(default_client_id_table.c.id == row_id)) + + for k in ('token_i32', 'token_i64', 'token_bytes', 'token_email'): + if isinstance(source_data[0][k], (bytearray, bytes)) and isinstance(data[k], str): + self.assertEqual(source_data[0][k], data[k].encode('utf-8')) + else: + self.assertEqual(source_data[0][k], data[k]) + + # test searchable tokenization in delete where statements + query = sa.delete(default_client_id_table).where(columns['token_str'] == update_data['token_str']) + self.execute_via_1(query, update_data) + + source_data = self.fetch_from_1(sa.select([default_client_id_table])) + self.assertEqual(0, len(source_data)) + class TestTokenizationWithoutZone(BaseTokenization): ZONE = False From 3e61850ae8db642e6cc8060fdbea8f3b0a053683 Mon Sep 17 00:00:00 2001 From: Artem Zhmaka Date: Thu, 29 Sep 2022 18:07:34 +0200 Subject: [PATCH 6/8] zhars/extend_syntax_for_searchable_encryption Extend syntax for searchable encryption to support INSERT/UPDATE/DELETE --- decryptor/postgresql/pending_packets.go | 20 +++--- encryptor/searchable_query_filter.go | 37 ++++------- 
hmac/decryptor/hashQuery_test.go | 73 +++++++++++--------- tests/test.py | 88 +++++++++++++++++++++++++ 4 files changed, 156 insertions(+), 62 deletions(-) diff --git a/decryptor/postgresql/pending_packets.go b/decryptor/postgresql/pending_packets.go index 174572779..bfcf35b9b 100644 --- a/decryptor/postgresql/pending_packets.go +++ b/decryptor/postgresql/pending_packets.go @@ -43,12 +43,13 @@ var ErrRemoveFromEmptyPendingList = errors.New("removing from empty pending list func (packets *pendingPacketsList) Add(packet interface{}) error { switch packet.(type) { case *ParsePacket, *BindPacket, *ExecutePacket, *pgproto3.RowDescription, *pgproto3.ParameterDescription: - packetList, ok := packets.lists[reflect.TypeOf(packet)] + packetType := reflect.TypeOf(packet) + packetList, ok := packets.lists[packetType] if !ok { packetList = list.New() packets.lists[reflect.TypeOf(packet)] = packetList } - log.WithField("packet", packet).Debugln("Add pending packet") + log.WithField("packet", packetType).Debugln("Add pending packet") packetList.PushBack(packet) return nil } @@ -59,7 +60,8 @@ func (packets *pendingPacketsList) Add(packet interface{}) error { func (packets *pendingPacketsList) RemoveNextPendingPacket(packet interface{}) error { switch packet.(type) { case *ParsePacket, *BindPacket, *ExecutePacket, *pgproto3.RowDescription, *pgproto3.ParameterDescription: - packetList, ok := packets.lists[reflect.TypeOf(packet)] + packetType := reflect.TypeOf(packet) + packetList, ok := packets.lists[packetType] if !ok { return ErrRemoveFromEmptyPendingList } @@ -67,7 +69,7 @@ func (packets *pendingPacketsList) RemoveNextPendingPacket(packet interface{}) e if currentElement == nil { return nil } - log.WithField("packet", currentElement.Value).Debugln("Remove pending packet") + log.WithField("packet", packetType).Debugln("Remove pending packet") packetList.Remove(currentElement) return nil } @@ -93,7 +95,8 @@ func (packets *pendingPacketsList) RemoveAll(packet interface{}) error { 
func (packets *pendingPacketsList) GetPendingPacket(packet interface{}) (interface{}, error) { switch packet.(type) { case *ParsePacket, *BindPacket, *ExecutePacket, *pgproto3.RowDescription, *pgproto3.ParameterDescription: - packetList, ok := packets.lists[reflect.TypeOf(packet)] + packetType := reflect.TypeOf(packet) + packetList, ok := packets.lists[packetType] if !ok { return nil, nil } @@ -101,7 +104,7 @@ func (packets *pendingPacketsList) GetPendingPacket(packet interface{}) (interfa if currentElement == nil { return nil, nil } - log.WithField("packet", currentElement.Value).Debugln("Return pending packet") + log.WithField("packet", packetType).Debugln("Return pending packet") return currentElement.Value, nil } return nil, ErrUnsupportedPendingPacketType @@ -111,7 +114,8 @@ func (packets *pendingPacketsList) GetPendingPacket(packet interface{}) (interfa func (packets *pendingPacketsList) GetLastPending(packet interface{}) (interface{}, error) { switch packet.(type) { case *ParsePacket, *BindPacket, *ExecutePacket, *pgproto3.RowDescription, *pgproto3.ParameterDescription: - packetList, ok := packets.lists[reflect.TypeOf(packet)] + packetType := reflect.TypeOf(packet) + packetList, ok := packets.lists[packetType] if !ok { return nil, nil } @@ -119,7 +123,7 @@ func (packets *pendingPacketsList) GetLastPending(packet interface{}) (interface if currentElement == nil { return nil, nil } - log.WithField("packet", currentElement.Value).Debugln("Return last added packet") + log.WithField("packet", packetType).Debugln("Return last added packet") return currentElement.Value, nil } return nil, ErrUnsupportedPendingPacketType diff --git a/encryptor/searchable_query_filter.go b/encryptor/searchable_query_filter.go index acc7fe8c5..d475e1181 100644 --- a/encryptor/searchable_query_filter.go +++ b/encryptor/searchable_query_filter.go @@ -105,31 +105,22 @@ func (filter *SearchableQueryFilter) filterInterestingTables(fromExp sqlparser.T } func (filter *SearchableQueryFilter) 
filterTableExpressions(statement sqlparser.Statement) (sqlparser.TableExprs, error) { - if filter.mode == QueryFilterModeConsistentTokenization { - switch query := statement.(type) { - case *sqlparser.Select: - return query.From, nil - case *sqlparser.Update: - return query.TableExprs, nil - case *sqlparser.Delete: - return query.TableExprs, nil - case *sqlparser.Insert: - // only support INSERT INTO table2 SELECT * FROM test_table WHERE data1='somedata' syntax for INSERTs - if selectInInsert, ok := query.Rows.(*sqlparser.Select); ok { - return selectInInsert.From, nil - } - return nil, ErrUnsupportedQueryType - default: - return nil, ErrUnsupportedQueryType + switch query := statement.(type) { + case *sqlparser.Select: + return query.From, nil + case *sqlparser.Update: + return query.TableExprs, nil + case *sqlparser.Delete: + return query.TableExprs, nil + case *sqlparser.Insert: + // only support INSERT INTO table2 SELECT * FROM test_table WHERE data1='somedata' syntax for INSERTs + if selectInInsert, ok := query.Rows.(*sqlparser.Select); ok { + return selectInInsert.From, nil } + return nil, ErrUnsupportedQueryType + default: + return nil, ErrUnsupportedQueryType } - - // TODO: extend with more query types for support for searchable encryption - selectStatement, ok := statement.(*sqlparser.Select) - if ok { - return selectStatement.From, nil - } - return nil, ErrUnsupportedQueryType } func (filter *SearchableQueryFilter) filterComparisonExprs(statement sqlparser.Statement) []*sqlparser.ComparisonExpr { diff --git a/hmac/decryptor/hashQuery_test.go b/hmac/decryptor/hashQuery_test.go index 0965fa766..17a8128f7 100644 --- a/hmac/decryptor/hashQuery_test.go +++ b/hmac/decryptor/hashQuery_test.go @@ -35,12 +35,11 @@ func TestSearchablePreparedStatementsWithTextFormat(t *testing.T) { - table: test_table columns: - data1 + - data2 encrypted: - column: data1 searchable: true` - query := `select data1 from test_table where data1=$1` - schema, err := 
config.MapTableSchemaStoreFromConfig([]byte(schemaConfig)) if err != nil { t.Fatal(err) @@ -64,36 +63,48 @@ func TestSearchablePreparedStatementsWithTextFormat(t *testing.T) { }), mock.Anything).Return(nil) _ = bindValue - queryObj := base.NewOnQueryObjectFromQuery(query, parser) - queryObj, _, err = encryptor.OnQuery(ctx, queryObj) - if err != nil { - t.Fatal(err) - } - bindPlaceholders := encryptor2.PlaceholderSettingsFromClientSession(clientSession) - if len(bindPlaceholders) != 1 { - t.Fatal("Not found expected amount of placeholders") - } - queryObj = base.NewOnQueryObjectFromQuery(query, parser) - statement, err := queryObj.Statement() - if err != nil { - t.Fatal(err) + type testcase struct { + Query string } - newVals, ok, err := encryptor.OnBind(ctx, statement, []base.BoundValue{boundValue}) - if err != nil { - t.Fatal(err) - } - if !ok { - t.Fatal("Values should be changed") - } - if len(newVals) != 1 { - t.Fatal("Invalid amount of bound values") - } - setting := schema.GetTableSchema("test_table").GetColumnEncryptionSettings("data1") - newData, err := newVals[0].GetData(setting) - if err != nil { - t.Fatal(err) + testcases := []testcase{ + {Query: "SELECT data1 from test_table WHERE data1=$1"}, + {Query: "UPDATE test_table SET kind = 'kind' WHERE data1=$1"}, + {Query: "INSERT INTO table2 SELECT * FROM test_table WHERE data1=$1 and data2=$2"}, + {Query: "DELETE FROM test_table WHERE data1=$1"}, + {Query: "DELETE FROM test_table WHERE data1=$1 OR data2=$2"}, } - if bytes.Equal(newData, sourceBindValue) { - t.Fatal("Data wasn't changed") + for _, testcase := range testcases { + queryObj := base.NewOnQueryObjectFromQuery(testcase.Query, parser) + queryObj, _, err = encryptor.OnQuery(ctx, queryObj) + if err != nil { + t.Fatal(err) + } + bindPlaceholders := encryptor2.PlaceholderSettingsFromClientSession(clientSession) + if len(bindPlaceholders) != 1 { + t.Fatal("Not found expected amount of placeholders") + } + queryObj = 
base.NewOnQueryObjectFromQuery(testcase.Query, parser) + statement, err := queryObj.Statement() + if err != nil { + t.Fatal(err) + } + newVals, ok, err := encryptor.OnBind(ctx, statement, []base.BoundValue{boundValue}) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("Values should be changed") + } + if len(newVals) != 1 { + t.Fatal("Invalid amount of bound values") + } + setting := schema.GetTableSchema("test_table").GetColumnEncryptionSettings("data1") + newData, err := newVals[0].GetData(setting) + if err != nil { + t.Fatal(err) + } + if bytes.Equal(newData, sourceBindValue) { + t.Fatal("Data wasn't changed") + } } } diff --git a/tests/test.py b/tests/test.py index 3661e0313..5692bdcbb 100644 --- a/tests/test.py +++ b/tests/test.py @@ -6778,6 +6778,9 @@ def insertDifferentRows(self, context, count, search_term=None, search_field='se self.insertRow(temp_context) count -= 1 + def execute_via_2(self, query, parameters): + return self.engine2.execute(query, parameters) + def executeSelect2(self, query, parameters): """Execute a SELECT query with parameters via AcraServer for "keypair2".""" return self.engine2.execute(query, parameters).fetchall() @@ -6792,6 +6795,10 @@ def executeSelect2(self, query, parameters): query, parameters = self.compileQuery(query, parameters) return self.executor2.execute_prepared_statement(query, parameters) + def execute_via_2(self, query, values): + query, parameters = self.compileQuery(query, values) + return self.executor2.execute_prepared_statement(query, parameters) + def executeBulkInsert(self, query, values): """Execute a Bulk Insert query with list of values via AcraServer for "TEST_TLS_CLIENT_2_CERT".""" query, parameters = self.compileBulkInsertQuery(query.values(values), values) @@ -6803,6 +6810,10 @@ def executeSelect2(self, query, parameters): query, parameters = self.compileQuery(query, parameters) return self.executor2.execute_prepared_statement(query, parameters) + def execute_via_2(self, query, parameters): + 
query, parameters = self.compileQuery(query, parameters) + return self.executor2.execute_prepared_statement_no_result(query, parameters) + def executeBulkInsert(self, query, values): """Execute a Bulk Insert query with list of values via AcraServer for "TEST_TLS_CLIENT_2_CERT".""" query, parameters = self.compileBulkInsertQuery(query.values(values), values) @@ -6828,6 +6839,83 @@ def testSearch(self): self.checkDefaultIdEncryption(**context) self.assertEqual(rows[0]['searchable'], search_term) + def testExtendedSyntaxSearch(self): + context = self.get_context_data() + search_term = context['searchable'] + + # Insert searchable data and some additional different rows + self.insertRow(context) + self.insertDifferentRows(context, count=5) + + rows = self.executeSelect2( + sa.select([self.encryptor_table]) + .where(self.encryptor_table.c.searchable == sa.bindparam('searchable')), + {'searchable': search_term}, + ) + self.assertEqual(len(rows), 1) + + self.checkDefaultIdEncryption(**context) + self.assertEqual(rows[0]['searchable'], search_term) + + new_token_i32 = random.randint(0, 2 ** 16) + update_data = { + 'token_i32': new_token_i32, + 'b_searchable': search_term + } + + # test searchable tokenization in update where statements + query = sa.update(self.encryptor_table).where(self.encryptor_table.c.searchable == sa.bindparam('b_searchable')).values(token_i32=new_token_i32) + self.execute_via_2(query, update_data) + + rows = self.executeSelect2( + sa.select([self.encryptor_table]) + .where(self.encryptor_table.c.searchable == sa.bindparam('searchable')), + {'searchable': search_term}, + ) + self.assertEqual(len(rows), 1) + + self.checkDefaultIdEncryption(**context) + self.assertEqual(rows[0]['searchable'], search_term) + self.assertEqual(rows[0]['token_i32'], new_token_i32) + + + row_id = get_random_id() + insert_data = { + 'param_1': row_id, + 'b_searchable': search_term + } + + select_columns = ['id', 'default_client_id', 'number', 'zone_id', 'specified_client_id', 
'raw_data', 'searchable', 'searchable_acrablock', 'empty', + 'nullable', 'masking', 'token_bytes', 'token_email', 'token_str', 'token_i32', 'token_i64'] + + select_query = sa.select( + sa.literal(row_id).label('id'), sa.column('default_client_id'), sa.column('number'), sa.column('zone_id'), sa.column('specified_client_id'), + sa.column('raw_data'), sa.column('searchable'), sa.column('searchable_acrablock'), sa.column('empty'), sa.column('nullable'), + sa.column('masking'), sa.column('token_bytes'), sa.column('token_email'), sa.column('token_str'), sa.column('token_i32'), sa.column('token_i64')). \ + where(self.encryptor_table.c.searchable == sa.bindparam('b_searchable')) + + query = sa.insert(self.encryptor_table).from_select(select_columns, select_query) + self.execute_via_2(query, insert_data) + + # after insert there 2 rows should be present in DB + rows = self.executeSelect2( + sa.select([self.encryptor_table]) + .where(self.encryptor_table.c.searchable == sa.bindparam('searchable')), + {'searchable': search_term}, + ) + self.assertEqual(len(rows), 2) + + # test searchable encryption in delete statements + query = sa.delete(self.encryptor_table).where(self.encryptor_table.c.searchable == sa.bindparam('b_searchable')) + self.execute_via_2(query, update_data) + + rows = self.executeSelect2( + sa.select([self.encryptor_table]) + .where(self.encryptor_table.c.searchable == sa.bindparam('searchable')), + {'searchable': search_term}, + ) + self.assertEqual(len(rows), 0) + def testHashValidation(self): context = self.get_context_data() search_term = context['searchable'] From 897ac7680f92677b1f983903044ae86d2808c139 Mon Sep 17 00:00:00 2001 From: Artem Zhmaka Date: Fri, 30 Sep 2022 09:46:09 +0200 Subject: [PATCH 7/8] zhars/extend_syntax_for_searchable_encryption update CHANGELOG_DEV.md file --- CHANGELOG_DEV.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG_DEV.md b/CHANGELOG_DEV.md index ec833e6a9..5403a67fa 100644 --- a/CHANGELOG_DEV.md +++ 
b/CHANGELOG_DEV.md @@ -1,3 +1,6 @@ +# 0.94.0 - 2022-09-21 +- Extended SQL syntax for searchable encryption for PostgreSQL/MySQL with UPDATE/DELETE/INSERT queries. + # 0.94.0 - 2022-09-21 - Implemented searchable tokenization for PostgreSQL/MySQL for text/binary protocols From d0adee4c69093706f77c7b8a748302b4fcd3545e Mon Sep 17 00:00:00 2001 From: Artem Zhmaka Date: Tue, 4 Oct 2022 08:50:17 +0100 Subject: [PATCH 8/8] zhars/extend_syntax_for_searchable_encryption Added unit test for searchable encryption extended syntax --- hmac/decryptor/hashQuery_test.go | 96 ++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/hmac/decryptor/hashQuery_test.go b/hmac/decryptor/hashQuery_test.go index 17a8128f7..0689b8912 100644 --- a/hmac/decryptor/hashQuery_test.go +++ b/hmac/decryptor/hashQuery_test.go @@ -3,6 +3,7 @@ package decryptor import ( "bytes" "context" + "fmt" "github.com/cossacklabs/acra/crypto" "github.com/cossacklabs/acra/decryptor/base" "github.com/cossacklabs/acra/decryptor/base/mocks" @@ -10,6 +11,7 @@ import ( "github.com/cossacklabs/acra/encryptor/config" mocks2 "github.com/cossacklabs/acra/keystore/mocks" "github.com/cossacklabs/acra/sqlparser" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "testing" ) @@ -108,3 +110,97 @@ func TestSearchablePreparedStatementsWithTextFormat(t *testing.T) { } } } + +// TestSearchableWithTextFormat processes searchable SELECT/UPDATE/INSERT/DELETE queries without placeholders in text format +func TestSearchableWithTextFormat(t *testing.T) { + clientSession := &mocks.ClientSession{} + sessionData := make(map[string]interface{}, 2) + clientSession.On("GetData", mock.Anything).Return(func(key string) interface{} { + return sessionData[key] + }, func(key string) bool { + _, ok := sessionData[key] + return ok + }) + clientSession.On("DeleteData", mock.Anything).Run(func(args mock.Arguments) { + delete(sessionData, args[0].(string)) + }) + clientSession.On("SetData", mock.Anything, 
mock.Anything).Run(func(args mock.Arguments) { + sessionData[args[0].(string)] = args[1] + }) + schemaConfig := `schemas: + - table: test_table + columns: + - data1 + - data2 + encrypted: + - column: data1 + searchable: true` + + schema, err := config.MapTableSchemaStoreFromConfig([]byte(schemaConfig)) + assert.NoError(t, err) + + ctx := base.SetClientSessionToContext(context.Background(), clientSession) + parser := sqlparser.New(sqlparser.ModeDefault) + keyStore := &mocks2.ServerKeyStore{} + keyStore.On("GetHMACSecretKey", mock.Anything).Return([]byte(`some key`), nil) + registryHandler := crypto.NewRegistryHandler(nil) + encryptor := NewPostgresqlHashQuery(keyStore, schema, registryHandler) + dataQueryPart := "test-data" + + coder := &encryptor2.PostgresqlDBDataCoder{} + + type testcase struct { + Query string + } + testcases := []testcase{ + {Query: "SELECT data1 from test_table WHERE data1='%s'"}, + {Query: "UPDATE test_table SET kind = 'kind' WHERE data1='%s'"}, + {Query: "INSERT INTO table2 SELECT * FROM test_table WHERE data1='%s' and data2='other-data'"}, + {Query: "DELETE FROM test_table WHERE data1='%s'"}, + {Query: "DELETE FROM test_table WHERE data1='%s' OR data2='other-data'"}, + } + for _, testcase := range testcases { + query := fmt.Sprintf(testcase.Query, dataQueryPart) + + queryObj := base.NewOnQueryObjectFromQuery(query, parser) + queryObj, _, err = encryptor.OnQuery(ctx, queryObj) + assert.NoError(t, err) + + stmt, err := queryObj.Statement() + assert.NoError(t, err) + + var whereStatements []*sqlparser.Where + err = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + where, ok := node.(*sqlparser.Where) + if ok { + whereStatements = append(whereStatements, where) + } + return true, nil + }, stmt) + assert.NoError(t, err) + assert.True(t, len(whereStatements) > 0) + + var comparisonExpr *sqlparser.ComparisonExpr + switch node := whereStatements[0].Expr.(type) { + case *sqlparser.ComparisonExpr: + comparisonExpr = node + 
case *sqlparser.AndExpr: + comparisonExpr = node.Left.(*sqlparser.ComparisonExpr) + case *sqlparser.OrExpr: + comparisonExpr = node.Left.(*sqlparser.ComparisonExpr) + } + + _, isSubstrExpr := comparisonExpr.Left.(*sqlparser.SubstrExpr) + assert.True(t, isSubstrExpr) + + rightVal := comparisonExpr.Right.(*sqlparser.SQLVal) + assert.NotEqual(t, dataQueryPart, string(rightVal.Val)) + + hmacValue, err := encryptor.calculateHmac(ctx, []byte(dataQueryPart)) + assert.NoError(t, err) + + newData, err := coder.Encode(rightVal, hmacValue) + assert.NoError(t, err) + assert.Equal(t, len(rightVal.Val), len(newData)) + } +}