Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Collations Module Integration #9018

Merged
merged 10 commits into from
Nov 3, 2021
7 changes: 5 additions & 2 deletions go/vt/vtgate/engine/comparer.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package engine

import (
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/vtgate/evalengine"
)
Expand All @@ -38,7 +39,8 @@ func (c *comparer) compare(r1, r2 []sqltypes.Value) (int, error) {
} else {
colIndex = c.orderBy
}
cmp, err := evalengine.NullsafeCompare(r1[colIndex], r2[colIndex])
// TODO(king-11) make collation aware
cmp, err := evalengine.NullsafeCompare(r1[colIndex], r2[colIndex], collations.Unknown)
if err != nil {
_, isComparisonErr := err.(evalengine.UnsupportedComparisonError)
if !(isComparisonErr && c.weightString != -1) {
Expand All @@ -47,7 +49,8 @@ func (c *comparer) compare(r1, r2 []sqltypes.Value) (int, error) {
// in case of a comparison error switch to using the weight string column for ordering
c.orderBy = c.weightString
c.weightString = -1
cmp, err = evalengine.NullsafeCompare(r1[c.orderBy], r2[c.orderBy])
// TODO(king-11) make collation aware
cmp, err = evalengine.NullsafeCompare(r1[c.orderBy], r2[c.orderBy], collations.Unknown)
if err != nil {
return 0, err
}
Expand Down
4 changes: 3 additions & 1 deletion go/vt/vtgate/engine/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package engine

import (
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/vtgate/evalengine"
Expand Down Expand Up @@ -73,7 +74,8 @@ func (pt *probeTable) exists(inputRow row) (bool, error) {

func equal(a, b []sqltypes.Value) (bool, error) {
for i, aVal := range a {
cmp, err := evalengine.NullsafeCompare(aVal, b[i])
// TODO(king-11) make collation aware
cmp, err := evalengine.NullsafeCompare(aVal, b[i], collations.Unknown)
if err != nil {
return false, err
}
Expand Down
36 changes: 24 additions & 12 deletions go/vt/vtgate/engine/ordered_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,17 @@ func (oa *OrderedAggregate) GetTableName() string {
return oa.Input.GetTableName()
}

// getCollations returns a map from column offset (GroupByParams.KeyCol) to the
// collation ID configured for that group-by key. Keys whose CollationID is
// collations.Unknown are left out of the map entirely.
func (oa *OrderedAggregate) getCollations() map[int]collations.ID {
	byColumn := map[int]collations.ID{}
	for i := range oa.GroupByKeys {
		groupBy := oa.GroupByKeys[i]
		if groupBy.CollationID == collations.Unknown {
			// Nothing useful to record for this column.
			continue
		}
		byColumn[groupBy.KeyCol] = groupBy.CollationID
	}
	return byColumn
}

// SetTruncateColumnCount sets the truncate column count.
func (oa *OrderedAggregate) SetTruncateColumnCount(count int) {
oa.TruncateColumnCount = count
Expand All @@ -209,6 +220,7 @@ func (oa *OrderedAggregate) execute(vcursor VCursor, bindVars map[string]*queryp
Fields: oa.convertFields(result.Fields),
Rows: make([][]sqltypes.Value, 0, len(result.Rows)),
}
colls := oa.getCollations()
// This code is similar to the one in StreamExecute.
var current []sqltypes.Value
var curDistincts []sqltypes.Value
Expand All @@ -217,14 +229,13 @@ func (oa *OrderedAggregate) execute(vcursor VCursor, bindVars map[string]*queryp
current, curDistincts = oa.convertRow(row)
continue
}

equal, err := oa.keysEqual(current, row)
equal, err := oa.keysEqual(current, row, colls)
if err != nil {
return nil, err
}

if equal {
current, curDistincts, err = oa.merge(result.Fields, current, row, curDistincts)
current, curDistincts, err = oa.merge(result.Fields, current, row, curDistincts, colls)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -271,20 +282,21 @@ func (oa *OrderedAggregate) TryStreamExecute(vcursor VCursor, bindVars map[strin
return err
}
}
colls := oa.getCollations()
// This code is similar to the one in Execute.
for _, row := range qr.Rows {
if current == nil {
current, curDistincts = oa.convertRow(row)
continue
}

equal, err := oa.keysEqual(current, row)
equal, err := oa.keysEqual(current, row, colls)
if err != nil {
return err
}

if equal {
current, curDistincts, err = oa.merge(fields, current, row, curDistincts)
current, curDistincts, err = oa.merge(fields, current, row, curDistincts, colls)
if err != nil {
return err
}
Expand Down Expand Up @@ -395,16 +407,16 @@ func (oa *OrderedAggregate) NeedsTransaction() bool {
return oa.Input.NeedsTransaction()
}

func (oa *OrderedAggregate) keysEqual(row1, row2 []sqltypes.Value) (bool, error) {
func (oa *OrderedAggregate) keysEqual(row1, row2 []sqltypes.Value, colls map[int]collations.ID) (bool, error) {
for _, key := range oa.GroupByKeys {
cmp, err := evalengine.NullsafeCompare(row1[key.KeyCol], row2[key.KeyCol])
cmp, err := evalengine.NullsafeCompare(row1[key.KeyCol], row2[key.KeyCol], colls[key.KeyCol])
if err != nil {
_, isComparisonErr := err.(evalengine.UnsupportedComparisonError)
if !(isComparisonErr && key.WeightStringCol != -1) {
return false, err
}
key.KeyCol = key.WeightStringCol
cmp, err = evalengine.NullsafeCompare(row1[key.WeightStringCol], row2[key.WeightStringCol])
cmp, err = evalengine.NullsafeCompare(row1[key.WeightStringCol], row2[key.WeightStringCol], colls[key.KeyCol])
if err != nil {
return false, err
}
Expand All @@ -416,14 +428,14 @@ func (oa *OrderedAggregate) keysEqual(row1, row2 []sqltypes.Value) (bool, error)
return true, nil
}

func (oa *OrderedAggregate) merge(fields []*querypb.Field, row1, row2 []sqltypes.Value, curDistincts []sqltypes.Value) ([]sqltypes.Value, []sqltypes.Value, error) {
func (oa *OrderedAggregate) merge(fields []*querypb.Field, row1, row2 []sqltypes.Value, curDistincts []sqltypes.Value, colls map[int]collations.ID) ([]sqltypes.Value, []sqltypes.Value, error) {
result := sqltypes.CopyRow(row1)
for index, aggr := range oa.Aggregates {
if aggr.isDistinct() {
if row2[aggr.KeyCol].IsNull() {
continue
}
cmp, err := evalengine.NullsafeCompare(curDistincts[index], row2[aggr.KeyCol])
cmp, err := evalengine.NullsafeCompare(curDistincts[index], row2[aggr.KeyCol], colls[aggr.KeyCol])
if err != nil {
return nil, nil, err
}
Expand All @@ -439,9 +451,9 @@ func (oa *OrderedAggregate) merge(fields []*querypb.Field, row1, row2 []sqltypes
v2 := row2[aggr.Col]
result[aggr.Col] = evalengine.NullsafeAdd(value, v2, fields[aggr.Col].Type)
case AggregateMin:
result[aggr.Col], err = evalengine.Min(row1[aggr.Col], row2[aggr.Col])
result[aggr.Col], err = evalengine.Min(row1[aggr.Col], row2[aggr.Col], colls[aggr.Col])
case AggregateMax:
result[aggr.Col], err = evalengine.Max(row1[aggr.Col], row2[aggr.Col])
result[aggr.Col], err = evalengine.Max(row1[aggr.Col], row2[aggr.Col], colls[aggr.Col])
case AggregateCountDistinct:
result[aggr.Col] = evalengine.NullsafeAdd(row1[aggr.Col], countOne, OpcodeType[aggr.Opcode])
case AggregateSumDistinct:
Expand Down
138 changes: 136 additions & 2 deletions go/vt/vtgate/engine/ordered_aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/test/utils"

Expand Down Expand Up @@ -652,13 +653,13 @@ func TestMerge(t *testing.T) {
"1|3|2.8|2|bc",
)

merged, _, err := oa.merge(fields, r.Rows[0], r.Rows[1], nil)
merged, _, err := oa.merge(fields, r.Rows[0], r.Rows[1], nil, nil)
assert.NoError(err)
want := sqltypes.MakeTestResult(fields, "1|5|6|2|bc").Rows[0]
assert.Equal(want, merged)

// swap and retry
merged, _, err = oa.merge(fields, r.Rows[1], r.Rows[0], nil)
merged, _, err = oa.merge(fields, r.Rows[1], r.Rows[0], nil, nil)
assert.NoError(err)
assert.Equal(want, merged)
}
Expand Down Expand Up @@ -1050,3 +1051,136 @@ func TestMultiDistinct(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, want, results)
}

func TestOrderedAggregateCollate(t *testing.T) {
assert := assert.New(t)
fields := sqltypes.MakeTestFields(
"col|count(*)",
"varchar|decimal",
)
fp := &fakePrimitive{
results: []*sqltypes.Result{sqltypes.MakeTestResult(
fields,
"a|1",
"A|1",
"Ǎ|1",
"b|2",
"B|-1",
"c|3",
"c|4",
"ß|11",
"ss|2",
)},
}

collationID, _ := collations.IDFromName("utf8mb4_0900_ai_ci")
oa := &OrderedAggregate{
Aggregates: []*AggregateParams{{
Opcode: AggregateCount,
Col: 1,
}},
GroupByKeys: []*GroupByParams{{KeyCol: 0, CollationID: collationID}},
Input: fp,
}

result, err := oa.TryExecute(&noopVCursor{}, nil, false)
assert.NoError(err)

wantResult := sqltypes.MakeTestResult(
fields,
"a|3",
"b|1",
"c|7",
"ß|13",
)
assert.Equal(wantResult, result)
}

func TestOrderedAggregateCollateAS(t *testing.T) {
assert := assert.New(t)
fields := sqltypes.MakeTestFields(
"col|count(*)",
"varchar|decimal",
)
fp := &fakePrimitive{
results: []*sqltypes.Result{sqltypes.MakeTestResult(
fields,
"a|1",
"A|1",
"Ǎ|1",
"b|2",
"c|3",
"c|4",
"Ç|4",
)},
}

collationID, _ := collations.IDFromName("utf8mb4_0900_as_ci")
oa := &OrderedAggregate{
Aggregates: []*AggregateParams{{
Opcode: AggregateCount,
Col: 1,
}},
GroupByKeys: []*GroupByParams{{KeyCol: 0, CollationID: collationID}},
Input: fp,
}

result, err := oa.TryExecute(&noopVCursor{}, nil, false)
assert.NoError(err)

wantResult := sqltypes.MakeTestResult(
fields,
"a|2",
"Ǎ|1",
"b|2",
"c|7",
"Ç|4",
)
assert.Equal(wantResult, result)
}

func TestOrderedAggregateCollateKS(t *testing.T) {
assert := assert.New(t)
fields := sqltypes.MakeTestFields(
"col|count(*)",
"varchar|decimal",
)
fp := &fakePrimitive{
results: []*sqltypes.Result{sqltypes.MakeTestResult(
fields,
"a|1",
"A|1",
"Ǎ|1",
"b|2",
"c|3",
"c|4",
"\xE3\x83\x8F\xE3\x81\xAF|2",
"\xE3\x83\x8F\xE3\x83\x8F|1",
)},
}

collationID, _ := collations.IDFromName("utf8mb4_ja_0900_as_cs_ks")
oa := &OrderedAggregate{
Aggregates: []*AggregateParams{{
Opcode: AggregateCount,
Col: 1,
}},
GroupByKeys: []*GroupByParams{{KeyCol: 0, CollationID: collationID}},
Input: fp,
}

result, err := oa.TryExecute(&noopVCursor{}, nil, false)
assert.NoError(err)

wantResult := sqltypes.MakeTestResult(
fields,
"a|1",
"A|1",
"Ǎ|1",
"b|2",
"c|7",
"\xE3\x83\x8F\xE3\x81\xAF|2",
"\xE3\x83\x8F\xE3\x83\x8F|1",
)
assert.Equal(wantResult, result)
}
Loading