Skip to content

Commit

Permalink
feat!: transpile filters to expressions with parameters
Browse files Browse the repository at this point in the history
Instead of transpiling filters to a single spansql.BoolExpr with
spansql.[...]Literal, filters are transpiled to a spansql.BoolExpr with
every literal returned as a parameter.

This fixes potential security issue where malicous filters are
transpiled directly to SQL without proper escaping.

Parameters are automatically named using a counter based on order of
appearance in the expr.Expr AST. This could be prettier, but atleast
provides a deterministic naming which *could* improve ability to cache
the filter (this is an assumption).
  • Loading branch information
ericwenn committed Feb 18, 2021
1 parent 0f127dc commit 64cf475
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 34 deletions.
6 changes: 4 additions & 2 deletions spanfiltering/transpile.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"go.einride.tech/aip/filtering"
)

// TranspileFilter transpiles a parsed AIP filter expression to a spansql.BoolExpr.
func TranspileFilter(filter filtering.Filter) (spansql.BoolExpr, error) {
// TranspileFilter transpiles a parsed AIP filter expression to a spansql.BoolExpr, and
// parameters used in the expression.
// The parameter map is nil if the expression does not contain any parameters.
func TranspileFilter(filter filtering.Filter) (spansql.BoolExpr, map[string]interface{}, error) {
var t Transpiler
t.Init(filter)
return t.Transpile()
Expand Down
39 changes: 29 additions & 10 deletions spanfiltering/transpile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package spanfiltering

import (
"testing"
"time"

syntaxv1 "go.einride.tech/aip/examples/proto/gen/einride/example/syntax/v1"
"go.einride.tech/aip/filtering"
Expand All @@ -11,11 +12,12 @@ import (
func TestTranspileFilter(t *testing.T) {
t.Parallel()
for _, tt := range []struct {
name string
filter string
declarations []filtering.DeclarationOption
expectedSQL string
errorContains string
name string
filter string
declarations []filtering.DeclarationOption
expectedSQL string
expectedParams map[string]interface{}
errorContains string
}{
{
name: "simple flag",
Expand Down Expand Up @@ -44,17 +46,23 @@ func TestTranspileFilter(t *testing.T) {
filtering.DeclareIdent("author", filtering.TypeString),
filtering.DeclareIdent("read", filtering.TypeBool),
},
expectedSQL: `((author = "Karin Boye") AND (NOT read))`,
expectedSQL: `((author = @param_0) AND (NOT read))`,
expectedParams: map[string]interface{}{
"param_0": "Karin Boye",
},
},

{
name: "string equality and flag",
name: "timestamp",
filter: `create_time > timestamp("2021-02-14T14:49:34+01:00")`,
declarations: []filtering.DeclarationOption{
filtering.DeclareStandardFunctions(),
filtering.DeclareIdent("create_time", filtering.TypeTimestamp),
},
expectedSQL: `(create_time > (TIMESTAMP '2021-02-14 14:49:34.000000 +01:00'))`,
expectedSQL: `(create_time > (@param_0))`,
expectedParams: map[string]interface{}{
"param_0": mustParseTime(t, "2021-02-14T14:49:34+01:00"),
},
},

{
Expand All @@ -63,7 +71,10 @@ func TestTranspileFilter(t *testing.T) {
declarations: []filtering.DeclarationOption{
filtering.DeclareEnumIdent("example_enum", syntaxv1.Enum(0).Type()),
},
expectedSQL: `(example_enum = 1)`,
expectedSQL: `(example_enum = @param_0)`,
expectedParams: map[string]interface{}{
"param_0": int64(1),
},
},

{
Expand All @@ -86,17 +97,25 @@ func TestTranspileFilter(t *testing.T) {
return
}
assert.NilError(t, err)
actual, err := TranspileFilter(filter)
actual, params, err := TranspileFilter(filter)
if err != nil && tt.errorContains != "" {
assert.ErrorContains(t, err, tt.errorContains)
return
}
assert.NilError(t, err)
assert.Equal(t, tt.expectedSQL, actual.SQL())
assert.DeepEqual(t, tt.expectedParams, params)
})
}
}

func mustParseTime(t *testing.T, s string) time.Time {
t.Helper()
tm, err := time.Parse(time.RFC3339, s)
assert.NilError(t, err)
return tm
}

type mockRequest struct {
filter string
}
Expand Down
66 changes: 44 additions & 22 deletions spanfiltering/transpiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package spanfiltering

import (
"fmt"
"strconv"
"time"

"cloud.google.com/go/spanner/spansql"
Expand All @@ -12,28 +13,35 @@ import (
)

type Transpiler struct {
filter filtering.Filter
filter filtering.Filter
params map[string]interface{}
paramCounter int
}

func (t *Transpiler) Init(filter filtering.Filter) {
*t = Transpiler{
filter: filter,
params: make(map[string]interface{}),
}
}

func (t *Transpiler) Transpile() (spansql.BoolExpr, error) {
func (t *Transpiler) Transpile() (spansql.BoolExpr, map[string]interface{}, error) {
if t.filter.CheckedExpr == nil {
return spansql.True, nil
return spansql.True, nil, nil
}
resultExpr, err := t.transpileExpr(t.filter.CheckedExpr.Expr)
if err != nil {
return nil, err
return nil, nil, err
}
resultBoolExpr, ok := resultExpr.(spansql.BoolExpr)
if !ok {
return nil, fmt.Errorf("not a bool expr")
return nil, nil, fmt.Errorf("not a bool expr")
}
params := t.params
if t.paramCounter == 0 {
params = nil
}
return resultBoolExpr, nil
return resultBoolExpr, params, nil
}

func (t *Transpiler) transpileExpr(e *expr.Expr) (spansql.Expr, error) {
Expand All @@ -58,15 +66,16 @@ func (t *Transpiler) transpileExpr(e *expr.Expr) (spansql.Expr, error) {
func (t *Transpiler) transpileConstExpr(e *expr.Expr) (spansql.Expr, error) {
switch kind := e.GetConstExpr().ConstantKind.(type) {
case *expr.Constant_BoolValue:
return spansql.BoolLiteral(kind.BoolValue), nil
return t.param(kind.BoolValue), nil
case *expr.Constant_DoubleValue:
return spansql.FloatLiteral(kind.DoubleValue), nil
return t.param(kind.DoubleValue), nil
case *expr.Constant_Int64Value:
return spansql.IntegerLiteral(kind.Int64Value), nil
return t.param(kind.Int64Value), nil
case *expr.Constant_StringValue:
return spansql.StringLiteral(kind.StringValue), nil
return t.param(kind.StringValue), nil
case *expr.Constant_Uint64Value:
return spansql.IntegerLiteral(kind.Uint64Value), nil
// spanner does not support uint64
return t.param(int64(kind.Uint64Value)), nil
default:
return nil, fmt.Errorf("unsupported const expr: %v", kind)
}
Expand Down Expand Up @@ -109,7 +118,8 @@ func (t *Transpiler) transpileIdentExpr(e *expr.Expr) (spansql.Expr, error) {
if enumType, err := protoregistry.GlobalTypes.FindEnumByName(protoreflect.FullName(messageType)); err == nil {
if enumValue := enumType.Descriptor().Values().ByName(protoreflect.Name(identExpr.Name)); enumValue != nil {
// TODO: Configurable support for string literals.
return spansql.IntegerLiteral(enumValue.Number()), nil
// spanner does not support int32
return t.param(int64(enumValue.Number())), nil
}
}
}
Expand Down Expand Up @@ -225,24 +235,36 @@ func (t *Transpiler) transpileHasCallExpr(e *expr.Expr) (spansql.BoolExpr, error
return nil, fmt.Errorf("TODO: add support for transpiling `:`")
}

func (t *Transpiler) transpileTimestampCallExpr(e *expr.Expr) (spansql.TimestampLiteral, error) {
func (t *Transpiler) transpileTimestampCallExpr(e *expr.Expr) (spansql.Expr, error) {
callExpr := e.GetCallExpr()
if len(callExpr.Args) != 1 {
return spansql.TimestampLiteral{}, fmt.Errorf(
return nil, fmt.Errorf(
"unexpected number of arguments to `%s`: %d", callExpr.Function, len(callExpr.Args),
)
}
arg, err := t.transpileExpr(callExpr.Args[0])
if err != nil {
return spansql.TimestampLiteral{}, err
constArg, ok := callExpr.Args[0].ExprKind.(*expr.Expr_ConstExpr)
if !ok {
return nil, fmt.Errorf("expected constant string arg to %s", callExpr.Function)
}
stringArg, ok := arg.(spansql.StringLiteral)
stringArg, ok := constArg.ConstExpr.ConstantKind.(*expr.Constant_StringValue)
if !ok {
return spansql.TimestampLiteral{}, fmt.Errorf("expected string arg to %s", callExpr.Function)
return nil, fmt.Errorf("expected constant string arg to %s", callExpr.Function)
}
timeArg, err := time.Parse(time.RFC3339, string(stringArg))
timeArg, err := time.Parse(time.RFC3339, stringArg.StringValue)
if err != nil {
return spansql.TimestampLiteral{}, fmt.Errorf("invalid string arg to %s: %w", callExpr.Function, err)
return nil, fmt.Errorf("invalid string arg to %s: %w", callExpr.Function, err)
}
return spansql.TimestampLiteral(timeArg), nil
return t.param(timeArg), nil
}

func (t *Transpiler) param(param interface{}) spansql.Param {
p := t.nextParam()
t.params[p] = param
return spansql.Param(p)
}

func (t *Transpiler) nextParam() string {
param := "param_" + strconv.Itoa(t.paramCounter)
t.paramCounter++
return param
}

0 comments on commit 64cf475

Please sign in to comment.