Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend SQL syntax for searchable encryption #586

Merged
merged 10 commits into from
Oct 4, 2022
3 changes: 3 additions & 0 deletions CHANGELOG_DEV.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# 0.94.0 - 2022-09-21
- Extended SQL syntax for searchable encryption for PostgreSQL/MySQL with UPDATE/DELETE/INSERTS queries.

# 0.94.0 - 2022-09-21
- Implemented searchable tokenization for PostgreSQL/MySQL for text/binary protocols

Expand Down
20 changes: 12 additions & 8 deletions decryptor/postgresql/pending_packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,13 @@ var ErrRemoveFromEmptyPendingList = errors.New("removing from empty pending list
func (packets *pendingPacketsList) Add(packet interface{}) error {
switch packet.(type) {
case *ParsePacket, *BindPacket, *ExecutePacket, *pgproto3.RowDescription, *pgproto3.ParameterDescription:
packetList, ok := packets.lists[reflect.TypeOf(packet)]
packetType := reflect.TypeOf(packet)
packetList, ok := packets.lists[packetType]
if !ok {
packetList = list.New()
packets.lists[reflect.TypeOf(packet)] = packetList
}
log.WithField("packet", packet).Debugln("Add pending packet")
log.WithField("packet", packetType).Debugln("Add pending packet")
packetList.PushBack(packet)
return nil
}
Expand All @@ -59,15 +60,16 @@ func (packets *pendingPacketsList) Add(packet interface{}) error {
func (packets *pendingPacketsList) RemoveNextPendingPacket(packet interface{}) error {
switch packet.(type) {
case *ParsePacket, *BindPacket, *ExecutePacket, *pgproto3.RowDescription, *pgproto3.ParameterDescription:
packetList, ok := packets.lists[reflect.TypeOf(packet)]
packetType := reflect.TypeOf(packet)
packetList, ok := packets.lists[packetType]
if !ok {
return ErrRemoveFromEmptyPendingList
}
currentElement := packetList.Front()
if currentElement == nil {
return nil
}
log.WithField("packet", currentElement.Value).Debugln("Remove pending packet")
log.WithField("packet", packetType).Debugln("Remove pending packet")
packetList.Remove(currentElement)
return nil
}
Expand All @@ -93,15 +95,16 @@ func (packets *pendingPacketsList) RemoveAll(packet interface{}) error {
func (packets *pendingPacketsList) GetPendingPacket(packet interface{}) (interface{}, error) {
switch packet.(type) {
case *ParsePacket, *BindPacket, *ExecutePacket, *pgproto3.RowDescription, *pgproto3.ParameterDescription:
packetList, ok := packets.lists[reflect.TypeOf(packet)]
packetType := reflect.TypeOf(packet)
packetList, ok := packets.lists[packetType]
if !ok {
return nil, nil
}
currentElement := packetList.Front()
if currentElement == nil {
return nil, nil
}
log.WithField("packet", currentElement.Value).Debugln("Return pending packet")
log.WithField("packet", packetType).Debugln("Return pending packet")
return currentElement.Value, nil
}
return nil, ErrUnsupportedPendingPacketType
Expand All @@ -111,15 +114,16 @@ func (packets *pendingPacketsList) GetPendingPacket(packet interface{}) (interfa
func (packets *pendingPacketsList) GetLastPending(packet interface{}) (interface{}, error) {
switch packet.(type) {
case *ParsePacket, *BindPacket, *ExecutePacket, *pgproto3.RowDescription, *pgproto3.ParameterDescription:
packetList, ok := packets.lists[reflect.TypeOf(packet)]
packetType := reflect.TypeOf(packet)
packetList, ok := packets.lists[packetType]
if !ok {
return nil, nil
}
currentElement := packetList.Back()
if currentElement == nil {
return nil, nil
}
log.WithField("packet", currentElement.Value).Debugln("Return last added packet")
log.WithField("packet", packetType).Debugln("Return last added packet")
return currentElement.Value, nil
}
return nil, ErrUnsupportedPendingPacketType
Expand Down
37 changes: 14 additions & 23 deletions encryptor/searchable_query_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,31 +105,22 @@ func (filter *SearchableQueryFilter) filterInterestingTables(fromExp sqlparser.T
}

func (filter *SearchableQueryFilter) filterTableExpressions(statement sqlparser.Statement) (sqlparser.TableExprs, error) {
if filter.mode == QueryFilterModeConsistentTokenization {
switch query := statement.(type) {
case *sqlparser.Select:
return query.From, nil
case *sqlparser.Update:
return query.TableExprs, nil
case *sqlparser.Delete:
return query.TableExprs, nil
case *sqlparser.Insert:
// only support INSERT INTO table2 SELECT * FROM test_table WHERE data1='somedata' syntax for INSERTs
if selectInInsert, ok := query.Rows.(*sqlparser.Select); ok {
return selectInInsert.From, nil
}
return nil, ErrUnsupportedQueryType
default:
return nil, ErrUnsupportedQueryType
switch query := statement.(type) {
case *sqlparser.Select:
return query.From, nil
case *sqlparser.Update:
return query.TableExprs, nil
case *sqlparser.Delete:
return query.TableExprs, nil
case *sqlparser.Insert:
// only support INSERT INTO table2 SELECT * FROM test_table WHERE data1='somedata' syntax for INSERTs
if selectInInsert, ok := query.Rows.(*sqlparser.Select); ok {
return selectInInsert.From, nil
}
return nil, ErrUnsupportedQueryType
default:
return nil, ErrUnsupportedQueryType
}

// TODO: extend with more query types for support for searchable encryption
selectStatement, ok := statement.(*sqlparser.Select)
if ok {
return selectStatement.From, nil
}
return nil, ErrUnsupportedQueryType
}

func (filter *SearchableQueryFilter) filterComparisonExprs(statement sqlparser.Statement) []*sqlparser.ComparisonExpr {
Expand Down
163 changes: 135 additions & 28 deletions hmac/decryptor/hashQuery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package decryptor
import (
"bytes"
"context"
"fmt"
"github.com/cossacklabs/acra/crypto"
"github.com/cossacklabs/acra/decryptor/base"
"github.com/cossacklabs/acra/decryptor/base/mocks"
encryptor2 "github.com/cossacklabs/acra/encryptor"
"github.com/cossacklabs/acra/encryptor/config"
mocks2 "github.com/cossacklabs/acra/keystore/mocks"
"github.com/cossacklabs/acra/sqlparser"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"testing"
)
Expand All @@ -35,12 +37,11 @@ func TestSearchablePreparedStatementsWithTextFormat(t *testing.T) {
- table: test_table
columns:
- data1
- data2
encrypted:
- column: data1
searchable: true`

query := `select data1 from test_table where data1=$1`

schema, err := config.MapTableSchemaStoreFromConfig([]byte(schemaConfig))
if err != nil {
t.Fatal(err)
Expand All @@ -64,36 +65,142 @@ func TestSearchablePreparedStatementsWithTextFormat(t *testing.T) {
}), mock.Anything).Return(nil)
_ = bindValue

queryObj := base.NewOnQueryObjectFromQuery(query, parser)
queryObj, _, err = encryptor.OnQuery(ctx, queryObj)
if err != nil {
t.Fatal(err)
}
bindPlaceholders := encryptor2.PlaceholderSettingsFromClientSession(clientSession)
if len(bindPlaceholders) != 1 {
t.Fatal("Not found expected amount of placeholders")
}
queryObj = base.NewOnQueryObjectFromQuery(query, parser)
statement, err := queryObj.Statement()
if err != nil {
t.Fatal(err)
type testcase struct {
Query string
}
newVals, ok, err := encryptor.OnBind(ctx, statement, []base.BoundValue{boundValue})
if err != nil {
t.Fatal(err)
testcases := []testcase{
{Query: "SELECT data1 from test_table WHERE data1=$1"},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add unit-tests for these cases but with value literals? to test it without prepared statements too.

{Query: "UPDATE test_table SET kind = 'kind' WHERE data1=$1"},
{Query: "INSERT INTO table2 SELECT * FROM test_table WHERE data1=$1 and data2=$2"},
{Query: "DELETE FROM test_table WHERE data1=$1"},
{Query: "DELETE FROM test_table WHERE data1=$1 OR data2=$2"},
}
if !ok {
t.Fatal("Values should be changed")
for _, testcase := range testcases {
queryObj := base.NewOnQueryObjectFromQuery(testcase.Query, parser)
queryObj, _, err = encryptor.OnQuery(ctx, queryObj)
if err != nil {
t.Fatal(err)
}
bindPlaceholders := encryptor2.PlaceholderSettingsFromClientSession(clientSession)
if len(bindPlaceholders) != 1 {
t.Fatal("Not found expected amount of placeholders")
}
queryObj = base.NewOnQueryObjectFromQuery(testcase.Query, parser)
statement, err := queryObj.Statement()
if err != nil {
t.Fatal(err)
}
newVals, ok, err := encryptor.OnBind(ctx, statement, []base.BoundValue{boundValue})
if err != nil {
t.Fatal(err)
}
if !ok {
t.Fatal("Values should be changed")
}
if len(newVals) != 1 {
t.Fatal("Invalid amount of bound values")
}
setting := schema.GetTableSchema("test_table").GetColumnEncryptionSettings("data1")
newData, err := newVals[0].GetData(setting)
if err != nil {
t.Fatal(err)
}
if bytes.Equal(newData, sourceBindValue) {
t.Fatal("Data wasn't changed")
}
}
if len(newVals) != 1 {
t.Fatal("Invalid amount of bound values")
}

// TestSearchableWithTextFormat process searchable SELECT query without placeholders with text format
func TestSearchableWithTextFormat(t *testing.T) {
clientSession := &mocks.ClientSession{}
sessionData := make(map[string]interface{}, 2)
clientSession.On("GetData", mock.Anything).Return(func(key string) interface{} {
return sessionData[key]
}, func(key string) bool {
_, ok := sessionData[key]
return ok
})
clientSession.On("DeleteData", mock.Anything).Run(func(args mock.Arguments) {
delete(sessionData, args[0].(string))
})
clientSession.On("SetData", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
sessionData[args[0].(string)] = args[1]
})
schemaConfig := `schemas:
- table: test_table
columns:
- data1
- data2
encrypted:
- column: data1
searchable: true`

schema, err := config.MapTableSchemaStoreFromConfig([]byte(schemaConfig))
assert.NoError(t, err)

ctx := base.SetClientSessionToContext(context.Background(), clientSession)
parser := sqlparser.New(sqlparser.ModeDefault)
keyStore := &mocks2.ServerKeyStore{}
keyStore.On("GetHMACSecretKey", mock.Anything).Return([]byte(`some key`), nil)
registryHandler := crypto.NewRegistryHandler(nil)
encryptor := NewPostgresqlHashQuery(keyStore, schema, registryHandler)
dataQueryPart := "test-data"

coder := &encryptor2.PostgresqlDBDataCoder{}

type testcase struct {
Query string
}
setting := schema.GetTableSchema("test_table").GetColumnEncryptionSettings("data1")
newData, err := newVals[0].GetData(setting)
if err != nil {
t.Fatal(err)
testcases := []testcase{
{Query: "SELECT data1 from test_table WHERE data1='%s'"},
{Query: "UPDATE test_table SET kind = 'kind' WHERE data1='%s'"},
{Query: "INSERT INTO table2 SELECT * FROM test_table WHERE data1='%s' and data2='other-data'"},
{Query: "DELETE FROM test_table WHERE data1='%s'"},
{Query: "DELETE FROM test_table WHERE data1='%s' OR data2='other-data'"},
}
if bytes.Equal(newData, sourceBindValue) {
t.Fatal("Data wasn't changed")
for _, testcase := range testcases {
query := fmt.Sprintf(testcase.Query, dataQueryPart)

queryObj := base.NewOnQueryObjectFromQuery(query, parser)
queryObj, _, err = encryptor.OnQuery(ctx, queryObj)
assert.NoError(t, err)

stmt, err := queryObj.Statement()
assert.NoError(t, err)

var whereStatements []*sqlparser.Where
err = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
where, ok := node.(*sqlparser.Where)
if ok {
whereStatements = append(whereStatements, where)
}
return true, nil
}, stmt)
assert.NoError(t, err)
assert.True(t, len(whereStatements) > 0)

var comparisonExpr *sqlparser.ComparisonExpr
switch node := whereStatements[0].Expr.(type) {
case *sqlparser.ComparisonExpr:
comparisonExpr = node
case *sqlparser.AndExpr:
comparisonExpr = node.Left.(*sqlparser.ComparisonExpr)
case *sqlparser.OrExpr:
comparisonExpr = node.Left.(*sqlparser.ComparisonExpr)
}

_, isSubstrExpr := comparisonExpr.Left.(*sqlparser.SubstrExpr)
assert.True(t, isSubstrExpr)

rightVal := comparisonExpr.Right.(*sqlparser.SQLVal)
assert.NotEqual(t, dataQueryPart, string(rightVal.Val))

hmacValue, err := encryptor.calculateHmac(ctx, []byte(dataQueryPart))
assert.NoError(t, err)

newData, err := coder.Encode(rightVal, hmacValue)
assert.NoError(t, err)
assert.Equal(t, len(rightVal.Val), len(newData))
}
}
Loading