Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed bug with matching default table with alias #594

Merged
merged 4 commits into from
Oct 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion encryptor/queryDataEncryptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ func (encryptor *QueryDataEncryptor) OnColumn(ctx context.Context, data []byte)
const allColumnsName = "*"

func (encryptor *QueryDataEncryptor) onSelect(ctx context.Context, statement *sqlparser.Select) (bool, error) {
columns, err := mapColumnsToAliases(statement)
columns, err := mapColumnsToAliases(statement, encryptor.schemaStore)
if err != nil {
logrus.WithError(err).Errorln("Can't extract columns from SELECT statement")
return false, err
Expand Down
19 changes: 1 addition & 18 deletions encryptor/searchable_query_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (filter *SearchableQueryFilter) filterInterestingTables(fromExp sqlparser.T
var defaultTableName string
// if query contains table without alias we need to detect default table
// if no, we can ignore default table and AliasToTableMap will be used to map ColName with encryptor_config
if filter.hasTablesWithoutAliases(fromExp) {
if hasTablesWithoutAliases(fromExp) {
var err error
defaultTableName, err = getFirstTableWithoutAlias(fromExp)
if err != nil {
Expand Down Expand Up @@ -234,23 +234,6 @@ func isSupportedSQLVal(val *sqlparser.SQLVal) bool {
return false
}

func (filter *SearchableQueryFilter) hasTablesWithoutAliases(stmt sqlparser.SQLNode) bool {
var hasTableWithoutAlias bool
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
if tableExpr, ok := node.(*sqlparser.AliasedTableExpr); ok {
if tableExpr.As.IsEmpty() {
hasTableWithoutAlias = true
}
}
return true, nil
}, stmt)
if err != nil {
return false
}

return hasTableWithoutAlias
}

// getEqualComparisonExprs return only <ColName> = <VALUE> or <ColName> != <VALUE> or <ColName> <=> <VALUE> expressions
func (filter *SearchableQueryFilter) getEqualComparisonExprs(stmt sqlparser.SQLNode, defaultTable *AliasedTableName, aliasedTables AliasToTableMap) ([]*sqlparser.ComparisonExpr, error) {
var exprs []*sqlparser.ComparisonExpr
Expand Down
101 changes: 96 additions & 5 deletions encryptor/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

var errNotFoundtable = errors.New("not found table for alias")
var errNotSupported = errors.New("not supported type of sql node")
var errTableAlreadyMatched = errors.New("aliased table name already matched")

type columnInfo struct {
Name string
Expand Down Expand Up @@ -175,6 +176,67 @@ func getFirstTableWithoutAlias(fromExpr sqlparser.TableExprs) (string, error) {
return name, nil
}

func getMatchedAliasedTable(fromExpr sqlparser.TableExprs, colName *sqlparser.ColName, tableSchemaStore config.TableSchemaStore) (string, error) {
if len(fromExpr) == 0 {
return "", errEmptyTableExprs
}

isTableColumn := func(tableSchema config.TableSchema, colName *sqlparser.ColName) bool {
for _, column := range tableSchema.Columns() {
if column == colName.Name.ValueForConfig() {
return true
}
}
return false
}

var alisedName string
for _, exp := range fromExpr {
aliased, ok := exp.(*sqlparser.AliasedTableExpr)
if !ok {
return "", errUnsupportedExpression
}

tableName, ok := aliased.Expr.(sqlparser.TableName)
if !ok {
return "", errUnsupportedExpression
}

tableSchema := tableSchemaStore.GetTableSchema(tableName.Name.ValueForConfig())
if tableSchema == nil {
continue
}

if isTableColumn(tableSchema, colName) {
tName, ok := getAliasedName(aliased)
if !ok {
return "", errUnsupportedExpression
}

if alisedName != "" {
logrus.WithField("alias", alisedName).Infoln("Ambiguous column found, several tables contain the same column")
return "", errTableAlreadyMatched
}

alisedName = tName
}
}

return alisedName, nil
}

func getAliasedName(aliased *sqlparser.AliasedTableExpr) (string, bool) {
if _, ok := aliased.Expr.(sqlparser.TableName); !ok {
return "", false
}

if aliased.As.IsEmpty() {
return "", false
}

return aliased.As.ValueForConfig(), true
}

func getNonAliasedName(aliased *sqlparser.AliasedTableExpr) (string, bool) {
if !aliased.As.IsEmpty() {
return "", false
Expand Down Expand Up @@ -292,7 +354,7 @@ func findTableName(alias, columnName string, expr sqlparser.SQLNode) (columnInfo
return columnInfo{}, errNotFoundtable
}

func mapColumnsToAliases(selectQuery *sqlparser.Select) ([]*columnInfo, error) {
func mapColumnsToAliases(selectQuery *sqlparser.Select, tableSchemaStore config.TableSchemaStore) ([]*columnInfo, error) {
out := make([]*columnInfo, 0, len(selectQuery.SelectExprs))
var joinTables []string
var joinAliases map[string]string
Expand All @@ -306,6 +368,11 @@ func mapColumnsToAliases(selectQuery *sqlparser.Select) ([]*columnInfo, error) {
}
}

// from DB prospective its valid to have columns without aliases in select but do have alias on table
// SELECT "id", "email", "mobile_number" AS "mobileNumber" FROM "users" AS "User""
// in such case Acra should consider aliased table as default table
hasTablesWithoutAliases := hasTablesWithoutAliases(selectQuery.From)

for _, expr := range selectQuery.SelectExprs {
aliased, ok := expr.(*sqlparser.AliasedExpr)
if ok {
Expand All @@ -322,7 +389,7 @@ func mapColumnsToAliases(selectQuery *sqlparser.Select) ([]*columnInfo, error) {
return nil, errUnsupportedExpression
}

subColumn, err := mapColumnsToAliases(subSelect)
subColumn, err := mapColumnsToAliases(subSelect, tableSchemaStore)
if err != nil {
return nil, err
}
Expand All @@ -334,14 +401,21 @@ func mapColumnsToAliases(selectQuery *sqlparser.Select) ([]*columnInfo, error) {
colName, ok := aliased.Expr.(*sqlparser.ColName)
if ok {
if colName.Qualifier.Name.IsEmpty() {
firstTable, err := getFirstTableWithoutAlias(selectQuery.From)
var columnTable string
var err error
if hasTablesWithoutAliases {
columnTable, err = getFirstTableWithoutAlias(selectQuery.From)
} else {
columnTable, err = getMatchedAliasedTable(selectQuery.From, colName, tableSchemaStore)
}
if err != nil {
out = append(out, nil)
continue
}
info, err := findTableName(firstTable, colName.Name.String(), selectQuery.From)

info, err := findTableName(columnTable, colName.Name.String(), selectQuery.From)
if err == nil {
info.Alias = firstTable
info.Alias = columnTable
out = append(out, &info)
continue
}
Expand Down Expand Up @@ -398,6 +472,23 @@ func mapColumnsToAliases(selectQuery *sqlparser.Select) ([]*columnInfo, error) {
return out, nil
}

func hasTablesWithoutAliases(stmt sqlparser.SQLNode) bool {
var hasTableWithoutAlias bool
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
if tableExpr, ok := node.(*sqlparser.AliasedTableExpr); ok {
if tableExpr.As.IsEmpty() {
hasTableWithoutAlias = true
}
}
return true, nil
}, stmt)
if err != nil {
return false
}

return hasTableWithoutAlias
}

// InvalidPlaceholderIndex value that represent invalid index for sql placeholders
const InvalidPlaceholderIndex = -1

Expand Down
135 changes: 128 additions & 7 deletions encryptor/utils_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package encryptor

import (
"testing"

"github.com/cossacklabs/acra/decryptor/base/mocks"
"github.com/cossacklabs/acra/encryptor/config"
"github.com/cossacklabs/acra/sqlparser"
"github.com/cossacklabs/acra/sqlparser/dialect"
"github.com/cossacklabs/acra/sqlparser/dialect/mysql"
"github.com/cossacklabs/acra/sqlparser/dialect/postgresql"
"github.com/stretchr/testify/mock"
"testing"
)

func TestGetFirstTableWithoutAlias(t *testing.T) {
Expand Down Expand Up @@ -85,7 +89,7 @@ inner join table6 on table6.col1=t1.col1
if !ok {
t.Fatal("Test query should be Select expression")
}
columns, err := mapColumnsToAliases(selectExpr)
columns, err := mapColumnsToAliases(selectExpr, &config.MapTableSchemaStore{})
if err != nil {
t.Fatal(err)
}
Expand All @@ -103,6 +107,123 @@ inner join table6 on table6.col1=t1.col1
}
}
})
t.Run("with aliased table and non-aliased colum name", func(t *testing.T) {
testConfig := `
schemas:
- table: users
columns:
- id
- email
- mobile_number
encrypted:
- column: id

- table: users_duplicate
columns:
- id
- email
- mobile_number
encrypted:
- column: id

- table: users_temp
columns:
- id_tmp
- email_tmp
- mobile_number_tmp
encrypted:
- column: id_tmp
`
schemaStore, err := config.MapTableSchemaStoreFromConfig([]byte(testConfig))
if err != nil {
t.Fatal(err)
}

testcases := []struct {
query string
dialect dialect.Dialect
expectedValues []*columnInfo
}{
{
query: `SELECT "id", "email", "mobile_number" AS "mobileNumber" FROM "users" AS "User" where "User"."is_active"`,
expectedValues: []*columnInfo{
{Alias: "User", Table: "users", Name: "id"},
{Alias: "User", Table: "users", Name: "email"},
{Alias: "User", Table: "users", Name: "mobile_number"},
},
},
{
query: `SELECT "id", "email", "mobile_number" AS "mobileNumber" FROM "users" AS "User", "table1" as "test_table"`,
expectedValues: []*columnInfo{
{Alias: "User", Table: "users", Name: "id"},
{Alias: "User", Table: "users", Name: "email"},
{Alias: "User", Table: "users", Name: "mobile_number"},
},
},
{
query: `SELECT "id", "email", "mobile_number" AS "mobileNumber" FROM "users" AS "User", "users_duplicate" as "User2"`,
expectedValues: []*columnInfo{
nil, nil, nil,
},
},
{
query: `SELECT "id", "email", "mobile_number", "id_tmp", "email_tmp", "mobile_number_tmp" AS "mobileNumber" FROM "users" AS "User", "users_temp" as "temp"`,
expectedValues: []*columnInfo{
{Alias: "User", Table: "users", Name: "id"},
{Alias: "User", Table: "users", Name: "email"},
{Alias: "User", Table: "users", Name: "mobile_number"},
{Alias: "temp", Table: "users_temp", Name: "id_tmp"},
{Alias: "temp", Table: "users_temp", Name: "email_tmp"},
{Alias: "temp", Table: "users_temp", Name: "mobile_number_tmp"},
},
},
{
query: `SELECT id, email, mobile_number FROM users AS alias where alias.is_active`,
dialect: mysql.NewMySQLDialect(),
expectedValues: []*columnInfo{
{Alias: "alias", Table: "users", Name: "id"},
{Alias: "alias", Table: "users", Name: "email"},
{Alias: "alias", Table: "users", Name: "mobile_number"},
},
},
}
for i, tcase := range testcases {
var dialect dialect.Dialect = postgresql.NewPostgreSQLDialect()
if tcase.dialect != nil {
dialect = tcase.dialect
}
sqlparser.SetDefaultDialect(dialect)

parsed, err := parser.Parse(tcase.query)
if err != nil {
t.Fatal(err)
}
selectExpr, ok := parsed.(*sqlparser.Select)
if !ok {
t.Fatal("Test query should be Select expression")
}
columns, err := mapColumnsToAliases(selectExpr, schemaStore)
if err != nil {
t.Fatal(err)
}
if len(columns) != len(tcase.expectedValues) {
t.Fatal("Returned incorrect length of values")
}

for y, column := range columns {
if column == nil {
if tcase.expectedValues[y] != nil {
t.Fatalf("[%d] expected nil column value ", i)
}
continue
}

if *column != *tcase.expectedValues[y] {
t.Fatalf("[%d] Column info is not equal to expected - %+v, actual - %+v", i, tcase.expectedValues[i], *column)
}
}
}
})
t.Run("Join enumeration fields query", func(t *testing.T) {
queries := []string{
`select table1.number, from_number, to_number, type, amount, created_date
Expand Down Expand Up @@ -168,7 +289,7 @@ inner join table6 on table6.col1=t1.col1
t.Fatal("Test query should be Select expression")
}

columns, err := mapColumnsToAliases(selectExpr)
columns, err := mapColumnsToAliases(selectExpr, &config.MapTableSchemaStore{})
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -226,7 +347,7 @@ inner join table6 on table6.col1=t1.col1
t.Fatal("Test query should be Select expression")
}

columns, err := mapColumnsToAliases(selectExpr)
columns, err := mapColumnsToAliases(selectExpr, &config.MapTableSchemaStore{})
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -261,7 +382,7 @@ inner join table6 on table6.col1=t1.col1

expectedValue := columnInfo{Alias: "*", Table: "test_table", Name: "*"}

columns, err := mapColumnsToAliases(selectExpr)
columns, err := mapColumnsToAliases(selectExpr, &config.MapTableSchemaStore{})
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -307,7 +428,7 @@ inner join table6 on table6.col1=t1.col1
t.Fatal("Test query should be Select expression")
}

columns, err := mapColumnsToAliases(selectExpr)
columns, err := mapColumnsToAliases(selectExpr, &config.MapTableSchemaStore{})
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -345,7 +466,7 @@ inner join table6 on table6.col1=t1.col1
{Alias: allColumnsName, Table: "test_table", Name: allColumnsName},
}

columns, err := mapColumnsToAliases(selectExpr)
columns, err := mapColumnsToAliases(selectExpr, &config.MapTableSchemaStore{})
if err != nil {
t.Fatal(err)
}
Expand Down
Loading