Skip to content

Commit

Permalink
Merge pull request #18266 from richardwu/unwrap-tuples
Browse files Browse the repository at this point in the history
sql: properly handle parenthesized/nested tuples as operands for ANY/ALL operations
  • Loading branch information
richardwu authored Sep 9, 2017
2 parents 9a3fdbb + 7a02e3a commit 933c40a
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 24 deletions.
32 changes: 30 additions & 2 deletions pkg/sql/logictest/testdata/logic_test/suboperators
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ SELECT 1 = ANY(ARRAY[1, 2])
----
true

query B
SELECT 1 = ANY (((ARRAY[1, 2])))
----
true

query B
SELECT 1 = SOME(ARRAY[1, 2])
----
Expand All @@ -23,6 +28,11 @@ SELECT 1 = ANY(ARRAY[3, 4])
----
false

query B
SELECT 1 = ANY (((ARRAY[3, 4])))
----
false

query B
SELECT 1 < ANY(ARRAY[0, 5])
----
Expand Down Expand Up @@ -96,7 +106,7 @@ query III
SELECT * FROM abc WHERE a = ANY(ARRAY[NULL, NULL])
----

query error unsupported comparison operator: 1 = ANY \(ARRAY\['foo', 'bar'\]\)
query error unsupported comparison operator: 1 = ANY ARRAY\['foo', 'bar'\]
SELECT 1 = ANY(ARRAY['foo', 'bar'])

query error unsupported comparison operator: <int> = ANY <string\[\]>
Expand Down Expand Up @@ -255,7 +265,7 @@ query III
SELECT * FROM abc WHERE a > ALL(ARRAY[NULL, NULL])
----

query error unsupported comparison operator: 1 = ALL \(ARRAY\['foo', 'bar'\]\)
query error unsupported comparison operator: 1 = ALL ARRAY\['foo', 'bar'\]
SELECT 1 = ALL(ARRAY['foo', 'bar'])

query error unsupported comparison operator: <int> = ALL <string\[\]>
Expand Down Expand Up @@ -332,11 +342,21 @@ SELECT 1 = ANY (1, 2, 3)
----
true

query B
SELECT 1 = ANY (((1, 2, 3)))
----
true

query B
SELECT 1 = ANY (2, 3, 4)
----
false

query B
SELECT 1 = ANY (((2, 3, 4)))
----
false

query error incompatible tuple element type: decimal
SELECT 1 = ANY (1, 1.1)

Expand All @@ -348,11 +368,19 @@ true
query error incompatible tuple element type: decimal
SELECT 1 = ANY (1.0, 1.1)

query error incompatible tuple element type: decimal
SELECT 1 = ANY (((1.0, 1.1)))

query B
SELECT 1::decimal = ANY (1.0, 1.1)
----
true

query B
SELECT 1::decimal = ANY (((1.0, 1.1)))
----
true

query error could not parse \"hello\" as type int
SELECT 1 = ANY (1, 'hello', 3)

Expand Down
35 changes: 16 additions & 19 deletions pkg/sql/parser/type_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -603,9 +603,9 @@ func (expr *ParenExpr) TypeCheck(ctx *SemaContext, desired Type) (TypedExpr, err
if err != nil {
return nil, err
}
expr.Expr = exprTyped
expr.typ = exprTyped.ResolvedType()
return expr, nil
// Parentheses are semantically unimportant and can be removed/replaced
// with its nested expression in our plan. This makes type checking cleaner.
return exprTyped, nil
}

// presetTypesForTesting is a mapping of qualified names to types that can be mocked out
Expand Down Expand Up @@ -908,7 +908,10 @@ func (d dNull) TypeCheck(_ *SemaContext, desired Type) (TypedExpr, error) { retu
// typeCheckAndRequireTupleElems asserts that all elements in the Tuple
// can be typed as required and are equivalent to required. Note that one would invoke
// with the required element type and NOT TTuple (as opposed to how Tuple.TypeCheck operates).
// For example, (1, 2.5) with required TypeDecimal would raise a sane error whereas (1.0, 2.5) with required TypeDecimal would pass.
// For example, (1, 2.5) with required TypeDecimal would raise a sane error whereas (1.0, 2.5)
// with required TypeDecimal would pass.
//
// It is only valid to pass in a Tuple expression
func typeCheckAndRequireTupleElems(ctx *SemaContext, expr Expr, required Type) (TypedExpr, error) {
tuple := expr.(*Tuple)
tuple.types = make(TTuple, len(tuple.Exprs))
Expand Down Expand Up @@ -956,14 +959,19 @@ const (
func typeCheckComparisonOpWithSubOperator(
ctx *SemaContext, op, subOp ComparisonOperator, left, right Expr,
) (TypedExpr, TypedExpr, CmpOp, error) {
// Parentheses are semantically unimportant and can be removed/replaced
// with its nested expression in our plan. This makes type checking cleaner.
left = StripParens(left)
right = StripParens(right)

// Determine the set of comparisons are possible for the sub-operation,
// which will be memoized.
foldedOp, _, _, _, _ := foldComparisonExpr(subOp, nil, nil)
ops := CmpOps[foldedOp]

var cmpTypeLeft, cmpTypeRight Type
var leftTyped, rightTyped TypedExpr
if array, isConstructor := StripParens(right).(*Array); isConstructor {
if array, isConstructor := right.(*Array); isConstructor {
// If the right expression is an (optionally nested) array constructor, we
// perform type inference on the array elements and the left expression.
sameTypeExprs := make([]Expr, len(array.Exprs)+1)
Expand All @@ -988,16 +996,7 @@ func typeCheckComparisonOpWithSubOperator(
}
array.typ = TArray{retType}

rightParen := right
for {
if p, ok := rightParen.(*ParenExpr); ok {
p.typ = array.typ
rightParen = p.Expr
continue
}
break
}
rightTyped = right.(TypedExpr)
rightTyped = array
cmpTypeRight = retType

// Return early without looking up a CmpOp if the comparison type is TypeNull.
Expand All @@ -1014,12 +1013,10 @@ func typeCheckComparisonOpWithSubOperator(
}
cmpTypeLeft = leftTyped.ResolvedType()

// TODO(richardwu): Write an Unwrap function for BinaryExpr to handle the case where the tuple
// is nested within ParenExpr.
if _, ok := right.(*Tuple); ok {
if tuple, ok := right.(*Tuple); ok {
// If right expression is a tuple, we require that all elements' inferred
// type is equivalent to the left's type.
rightTyped, err = typeCheckAndRequireTupleElems(ctx, right, cmpTypeLeft)
rightTyped, err = typeCheckAndRequireTupleElems(ctx, tuple, cmpTypeLeft)
if err != nil {
return nil, nil, CmpOp{}, err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/parser/type_check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestTypeCheck(t *testing.T) {
{`NULL || 'hello'::bytes`, `NULL`},
{`NULL::int`, `NULL::INT`},
{`INTERVAL '1s'`, `'1s':::INTERVAL`},
{`(1.1::decimal)::decimal`, `(1.1:::DECIMAL::DECIMAL)::DECIMAL`},
{`(1.1::decimal)::decimal`, `1.1:::DECIMAL::DECIMAL::DECIMAL`},
{`NULL = 1`, `NULL`},
{`1 = NULL`, `NULL`},
{`true AND NULL`, `true AND NULL`},
Expand Down Expand Up @@ -82,9 +82,9 @@ func TestTypeCheck(t *testing.T) {
{`ARRAY[1.5, 2.5, 3.5]`, `ARRAY[1.5:::DECIMAL, 2.5:::DECIMAL, 3.5:::DECIMAL]`},
{`ARRAY[NULL]`, `ARRAY[NULL]`},
{`1 = ANY ARRAY[1.5, 2.5, 3.5]`, `1:::DECIMAL = ANY ARRAY[1.5:::DECIMAL, 2.5:::DECIMAL, 3.5:::DECIMAL]`},
{`true = SOME (ARRAY[true, false])`, `true = SOME (ARRAY[true, false])`},
{`true = SOME (ARRAY[true, false])`, `true = SOME ARRAY[true, false]`},
{`1.3 = ALL ARRAY[1, 2, 3]`, `1.3:::DECIMAL = ALL ARRAY[1:::DECIMAL, 2:::DECIMAL, 3:::DECIMAL]`},
{`1.3 = ALL ((ARRAY[]))`, `1.3:::DECIMAL = ALL ((ARRAY[]))`},
{`1.3 = ALL ((ARRAY[]))`, `1.3:::DECIMAL = ALL ARRAY[]`},
{`NULL = ALL ARRAY[1.5, 2.5, 3.5]`, `NULL`},
{`NULL = ALL ARRAY[NULL, NULL]`, `NULL`},
{`1 = ALL NULL`, `NULL`},
Expand Down

0 comments on commit 933c40a

Please sign in to comment.