Skip to content

Commit

Permalink
Extend SQL syntax for searchable encryption (#586)
Browse files Browse the repository at this point in the history
* zhars/extend_syntax_for_searchable_encryption

Extend syntax for searchable encryption to support INSERT/UPDATE/DELETE
  • Loading branch information
Zhaars authored Oct 4, 2022
1 parent 492568a commit 0d9d499
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 59 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG_DEV.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# 0.94.0 - 2022-09-21
- Extended SQL syntax for searchable encryption for PostgreSQL/MySQL with UPDATE/DELETE/INSERTS queries.

# 0.94.0 - 2022-09-21
- Implemented searchable tokenization for PostgreSQL/MySQL for text/binary protocols

Expand Down
20 changes: 12 additions & 8 deletions decryptor/postgresql/pending_packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,13 @@ var ErrRemoveFromEmptyPendingList = errors.New("removing from empty pending list
func (packets *pendingPacketsList) Add(packet interface{}) error {
switch packet.(type) {
case *ParsePacket, *BindPacket, *ExecutePacket, *pgproto3.RowDescription, *pgproto3.ParameterDescription:
packetList, ok := packets.lists[reflect.TypeOf(packet)]
packetType := reflect.TypeOf(packet)
packetList, ok := packets.lists[packetType]
if !ok {
packetList = list.New()
packets.lists[reflect.TypeOf(packet)] = packetList
}
log.WithField("packet", packet).Debugln("Add pending packet")
log.WithField("packet", packetType).Debugln("Add pending packet")
packetList.PushBack(packet)
return nil
}
Expand All @@ -59,15 +60,16 @@ func (packets *pendingPacketsList) Add(packet interface{}) error {
func (packets *pendingPacketsList) RemoveNextPendingPacket(packet interface{}) error {
switch packet.(type) {
case *ParsePacket, *BindPacket, *ExecutePacket, *pgproto3.RowDescription, *pgproto3.ParameterDescription:
packetList, ok := packets.lists[reflect.TypeOf(packet)]
packetType := reflect.TypeOf(packet)
packetList, ok := packets.lists[packetType]
if !ok {
return ErrRemoveFromEmptyPendingList
}
currentElement := packetList.Front()
if currentElement == nil {
return nil
}
log.WithField("packet", currentElement.Value).Debugln("Remove pending packet")
log.WithField("packet", packetType).Debugln("Remove pending packet")
packetList.Remove(currentElement)
return nil
}
Expand All @@ -93,15 +95,16 @@ func (packets *pendingPacketsList) RemoveAll(packet interface{}) error {
func (packets *pendingPacketsList) GetPendingPacket(packet interface{}) (interface{}, error) {
switch packet.(type) {
case *ParsePacket, *BindPacket, *ExecutePacket, *pgproto3.RowDescription, *pgproto3.ParameterDescription:
packetList, ok := packets.lists[reflect.TypeOf(packet)]
packetType := reflect.TypeOf(packet)
packetList, ok := packets.lists[packetType]
if !ok {
return nil, nil
}
currentElement := packetList.Front()
if currentElement == nil {
return nil, nil
}
log.WithField("packet", currentElement.Value).Debugln("Return pending packet")
log.WithField("packet", packetType).Debugln("Return pending packet")
return currentElement.Value, nil
}
return nil, ErrUnsupportedPendingPacketType
Expand All @@ -111,15 +114,16 @@ func (packets *pendingPacketsList) GetPendingPacket(packet interface{}) (interfa
func (packets *pendingPacketsList) GetLastPending(packet interface{}) (interface{}, error) {
switch packet.(type) {
case *ParsePacket, *BindPacket, *ExecutePacket, *pgproto3.RowDescription, *pgproto3.ParameterDescription:
packetList, ok := packets.lists[reflect.TypeOf(packet)]
packetType := reflect.TypeOf(packet)
packetList, ok := packets.lists[packetType]
if !ok {
return nil, nil
}
currentElement := packetList.Back()
if currentElement == nil {
return nil, nil
}
log.WithField("packet", currentElement.Value).Debugln("Return last added packet")
log.WithField("packet", packetType).Debugln("Return last added packet")
return currentElement.Value, nil
}
return nil, ErrUnsupportedPendingPacketType
Expand Down
37 changes: 14 additions & 23 deletions encryptor/searchable_query_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,31 +105,22 @@ func (filter *SearchableQueryFilter) filterInterestingTables(fromExp sqlparser.T
}

func (filter *SearchableQueryFilter) filterTableExpressions(statement sqlparser.Statement) (sqlparser.TableExprs, error) {
if filter.mode == QueryFilterModeConsistentTokenization {
switch query := statement.(type) {
case *sqlparser.Select:
return query.From, nil
case *sqlparser.Update:
return query.TableExprs, nil
case *sqlparser.Delete:
return query.TableExprs, nil
case *sqlparser.Insert:
// only support INSERT INTO table2 SELECT * FROM test_table WHERE data1='somedata' syntax for INSERTs
if selectInInsert, ok := query.Rows.(*sqlparser.Select); ok {
return selectInInsert.From, nil
}
return nil, ErrUnsupportedQueryType
default:
return nil, ErrUnsupportedQueryType
switch query := statement.(type) {
case *sqlparser.Select:
return query.From, nil
case *sqlparser.Update:
return query.TableExprs, nil
case *sqlparser.Delete:
return query.TableExprs, nil
case *sqlparser.Insert:
// only support INSERT INTO table2 SELECT * FROM test_table WHERE data1='somedata' syntax for INSERTs
if selectInInsert, ok := query.Rows.(*sqlparser.Select); ok {
return selectInInsert.From, nil
}
return nil, ErrUnsupportedQueryType
default:
return nil, ErrUnsupportedQueryType
}

// TODO: extend with more query types for support for searchable encryption
selectStatement, ok := statement.(*sqlparser.Select)
if ok {
return selectStatement.From, nil
}
return nil, ErrUnsupportedQueryType
}

func (filter *SearchableQueryFilter) filterComparisonExprs(statement sqlparser.Statement) []*sqlparser.ComparisonExpr {
Expand Down
163 changes: 135 additions & 28 deletions hmac/decryptor/hashQuery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package decryptor
import (
"bytes"
"context"
"fmt"
"github.com/cossacklabs/acra/crypto"
"github.com/cossacklabs/acra/decryptor/base"
"github.com/cossacklabs/acra/decryptor/base/mocks"
encryptor2 "github.com/cossacklabs/acra/encryptor"
"github.com/cossacklabs/acra/encryptor/config"
mocks2 "github.com/cossacklabs/acra/keystore/mocks"
"github.com/cossacklabs/acra/sqlparser"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"testing"
)
Expand All @@ -35,12 +37,11 @@ func TestSearchablePreparedStatementsWithTextFormat(t *testing.T) {
- table: test_table
columns:
- data1
- data2
encrypted:
- column: data1
searchable: true`

query := `select data1 from test_table where data1=$1`

schema, err := config.MapTableSchemaStoreFromConfig([]byte(schemaConfig))
if err != nil {
t.Fatal(err)
Expand All @@ -64,36 +65,142 @@ func TestSearchablePreparedStatementsWithTextFormat(t *testing.T) {
}), mock.Anything).Return(nil)
_ = bindValue

queryObj := base.NewOnQueryObjectFromQuery(query, parser)
queryObj, _, err = encryptor.OnQuery(ctx, queryObj)
if err != nil {
t.Fatal(err)
}
bindPlaceholders := encryptor2.PlaceholderSettingsFromClientSession(clientSession)
if len(bindPlaceholders) != 1 {
t.Fatal("Not found expected amount of placeholders")
}
queryObj = base.NewOnQueryObjectFromQuery(query, parser)
statement, err := queryObj.Statement()
if err != nil {
t.Fatal(err)
type testcase struct {
Query string
}
newVals, ok, err := encryptor.OnBind(ctx, statement, []base.BoundValue{boundValue})
if err != nil {
t.Fatal(err)
testcases := []testcase{
{Query: "SELECT data1 from test_table WHERE data1=$1"},
{Query: "UPDATE test_table SET kind = 'kind' WHERE data1=$1"},
{Query: "INSERT INTO table2 SELECT * FROM test_table WHERE data1=$1 and data2=$2"},
{Query: "DELETE FROM test_table WHERE data1=$1"},
{Query: "DELETE FROM test_table WHERE data1=$1 OR data2=$2"},
}
if !ok {
t.Fatal("Values should be changed")
for _, testcase := range testcases {
queryObj := base.NewOnQueryObjectFromQuery(testcase.Query, parser)
queryObj, _, err = encryptor.OnQuery(ctx, queryObj)
if err != nil {
t.Fatal(err)
}
bindPlaceholders := encryptor2.PlaceholderSettingsFromClientSession(clientSession)
if len(bindPlaceholders) != 1 {
t.Fatal("Not found expected amount of placeholders")
}
queryObj = base.NewOnQueryObjectFromQuery(testcase.Query, parser)
statement, err := queryObj.Statement()
if err != nil {
t.Fatal(err)
}
newVals, ok, err := encryptor.OnBind(ctx, statement, []base.BoundValue{boundValue})
if err != nil {
t.Fatal(err)
}
if !ok {
t.Fatal("Values should be changed")
}
if len(newVals) != 1 {
t.Fatal("Invalid amount of bound values")
}
setting := schema.GetTableSchema("test_table").GetColumnEncryptionSettings("data1")
newData, err := newVals[0].GetData(setting)
if err != nil {
t.Fatal(err)
}
if bytes.Equal(newData, sourceBindValue) {
t.Fatal("Data wasn't changed")
}
}
if len(newVals) != 1 {
t.Fatal("Invalid amount of bound values")
}

// TestSearchableWithTextFormat process searchable SELECT query without placeholders with text format
func TestSearchableWithTextFormat(t *testing.T) {
clientSession := &mocks.ClientSession{}
sessionData := make(map[string]interface{}, 2)
clientSession.On("GetData", mock.Anything).Return(func(key string) interface{} {
return sessionData[key]
}, func(key string) bool {
_, ok := sessionData[key]
return ok
})
clientSession.On("DeleteData", mock.Anything).Run(func(args mock.Arguments) {
delete(sessionData, args[0].(string))
})
clientSession.On("SetData", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
sessionData[args[0].(string)] = args[1]
})
schemaConfig := `schemas:
- table: test_table
columns:
- data1
- data2
encrypted:
- column: data1
searchable: true`

schema, err := config.MapTableSchemaStoreFromConfig([]byte(schemaConfig))
assert.NoError(t, err)

ctx := base.SetClientSessionToContext(context.Background(), clientSession)
parser := sqlparser.New(sqlparser.ModeDefault)
keyStore := &mocks2.ServerKeyStore{}
keyStore.On("GetHMACSecretKey", mock.Anything).Return([]byte(`some key`), nil)
registryHandler := crypto.NewRegistryHandler(nil)
encryptor := NewPostgresqlHashQuery(keyStore, schema, registryHandler)
dataQueryPart := "test-data"

coder := &encryptor2.PostgresqlDBDataCoder{}

type testcase struct {
Query string
}
setting := schema.GetTableSchema("test_table").GetColumnEncryptionSettings("data1")
newData, err := newVals[0].GetData(setting)
if err != nil {
t.Fatal(err)
testcases := []testcase{
{Query: "SELECT data1 from test_table WHERE data1='%s'"},
{Query: "UPDATE test_table SET kind = 'kind' WHERE data1='%s'"},
{Query: "INSERT INTO table2 SELECT * FROM test_table WHERE data1='%s' and data2='other-data'"},
{Query: "DELETE FROM test_table WHERE data1='%s'"},
{Query: "DELETE FROM test_table WHERE data1='%s' OR data2='other-data'"},
}
if bytes.Equal(newData, sourceBindValue) {
t.Fatal("Data wasn't changed")
for _, testcase := range testcases {
query := fmt.Sprintf(testcase.Query, dataQueryPart)

queryObj := base.NewOnQueryObjectFromQuery(query, parser)
queryObj, _, err = encryptor.OnQuery(ctx, queryObj)
assert.NoError(t, err)

stmt, err := queryObj.Statement()
assert.NoError(t, err)

var whereStatements []*sqlparser.Where
err = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
where, ok := node.(*sqlparser.Where)
if ok {
whereStatements = append(whereStatements, where)
}
return true, nil
}, stmt)
assert.NoError(t, err)
assert.True(t, len(whereStatements) > 0)

var comparisonExpr *sqlparser.ComparisonExpr
switch node := whereStatements[0].Expr.(type) {
case *sqlparser.ComparisonExpr:
comparisonExpr = node
case *sqlparser.AndExpr:
comparisonExpr = node.Left.(*sqlparser.ComparisonExpr)
case *sqlparser.OrExpr:
comparisonExpr = node.Left.(*sqlparser.ComparisonExpr)
}

_, isSubstrExpr := comparisonExpr.Left.(*sqlparser.SubstrExpr)
assert.True(t, isSubstrExpr)

rightVal := comparisonExpr.Right.(*sqlparser.SQLVal)
assert.NotEqual(t, dataQueryPart, string(rightVal.Val))

hmacValue, err := encryptor.calculateHmac(ctx, []byte(dataQueryPart))
assert.NoError(t, err)

newData, err := coder.Encode(rightVal, hmacValue)
assert.NoError(t, err)
assert.Equal(t, len(rightVal.Val), len(newData))
}
}
Loading

0 comments on commit 0d9d499

Please sign in to comment.