Skip to content

Commit

Permalink
Merge pull request #9097 from planetscale/default-collation
Browse files Browse the repository at this point in the history
Add default collation to VTGate and VTTablet
  • Loading branch information
vmg authored Nov 16, 2021
2 parents b89ffe5 + 5936dff commit 9085201
Show file tree
Hide file tree
Showing 64 changed files with 2,086 additions and 1,399 deletions.
5 changes: 2 additions & 3 deletions go/cmd/vttablet/vttablet.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ package main

import (
"bytes"
"context"
"flag"
"os"

"context"

"vitess.io/vitess/go/vt/binlog"
"vitess.io/vitess/go/vt/dbconfigs"
"vitess.io/vitess/go/vt/log"
Expand Down Expand Up @@ -91,7 +90,7 @@ func main() {
if servenv.GRPCPort != nil {
gRPCPort = int32(*servenv.GRPCPort)
}
tablet, err := tabletmanager.BuildTabletFromInput(tabletAlias, int32(*servenv.Port), gRPCPort)
tablet, err := tabletmanager.BuildTabletFromInput(tabletAlias, int32(*servenv.Port), gRPCPort, mysqld.GetVersionString())
if err != nil {
log.Exitf("failed to parse -tablet-path: %v", err)
}
Expand Down
3 changes: 2 additions & 1 deletion go/cmd/vttestserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ func init() {
flag.IntVar(&topo.rdonly, "rdonly_count", 1,
"Rdonly tablets per shard")

flag.StringVar(&config.Charset, "charset", "utf8", "MySQL charset")
flag.StringVar(&config.Charset, "charset", "utf8mb4", "MySQL charset")
flag.StringVar(&config.Collation, "collation", "", "MySQL collation")
flag.StringVar(&config.SnapshotFile, "snapshot_file", "",
"A MySQL DB snapshot file")

Expand Down
145 changes: 109 additions & 36 deletions go/mysql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,17 @@ limitations under the License.
package mysql

import (
"context"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"net"
"strconv"
"strings"
"time"

"context"

"vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/mysql/collations"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vttls"
)
Expand Down Expand Up @@ -63,12 +61,6 @@ func Connect(ctx context.Context, params *ConnParams) (*Conn, error) {
addr = net.JoinHostPort(params.Host, fmt.Sprintf("%v", params.Port))
}

// Figure out the character set we want.
characterSet, err := parseCharacterSet(params.Charset)
if err != nil {
return nil, err
}

// Start a background connection routine. It first
// establishes a network connection, returns it on the channel,
// then starts the negotiation, and returns the result on the channel.
Expand Down Expand Up @@ -123,7 +115,7 @@ func Connect(ctx context.Context, params *ConnParams) (*Conn, error) {
// make any read or write just return with an error
// right away.
status <- connectResult{
err: c.clientHandshake(characterSet, params),
err: c.clientHandshake(params),
}
}()

Expand Down Expand Up @@ -174,6 +166,15 @@ func Connect(ctx context.Context, params *ConnParams) (*Conn, error) {
return nil, cr.err
}
}

// Once we are connected to the server, we set the collation for this connection.
// This step usually occurs during the handshake, however, the handshake protocol
// grants us 8 bits for the collation ID, which is lower than the range of supported
// collations. For this reason, we manually set the collation for the connection.
if err := setCollationForConnection(c, params); err != nil {
return nil, err
}

return c, nil
}

Expand All @@ -198,36 +199,86 @@ func (c *Conn) Ping() error {
case ErrPacket:
return ParseErrorPacket(data)
}
return vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected packet type: %d", data[0])
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unexpected packet type: %d", data[0])
}

// parseCharacterSet parses the provided character set.
// Returns SQLError(CRCantReadCharset) if it can't.
func parseCharacterSet(cs string) (uint8, error) {
// Check if it's empty, return utf8. This is a reasonable default.
if cs == "" {
return CharacterSetUtf8, nil
// setCollationForConnection sets the connection's collation to the given collation.
//
// The charset should always be set as it has a default value ("utf8mb4"),
// however, one can always override its default to an empty string, which
// is not a problem as long as the user has specified the collation.
// If the collation flag was not specified when starting the tablet, we
// attempt to find the default collation for the current charset.
// If both the collation and the charset are missing, we fall back to the
// "utf8mb4" charset. If the resulting collation cannot be resolved, we error out.
//
// This method is also responsible for creating and storing the collation
// environment that will be used by this connection. The collation environment
// allows us to make informed decisions around charset's default collation
// depending on the MySQL/MariaDB version we are using.
func setCollationForConnection(c *Conn, params *ConnParams) error {
// Once we have done the initial handshake with MySQL, we receive the server version
// string. This string is critical as it enables the instantiation of a new collation
// environment variable.
// Certain MySQL or MariaDB versions might have different default collations for some
// charsets, so it is important to use a database-version-aware collation system/API.
env := collations.NewEnvironment(c.ServerVersion)

// If neither a collation nor a charset was given, we default to the utf8mb4 charset.
charset := params.Charset
if params.Collation == "" && charset == "" {
charset = "utf8mb4"
}

var coll collations.Collation
if params.Collation == "" {
// If there is no collation we will just use the charset's default collation
// otherwise we directly use the given collation.
coll = env.DefaultCollationForCharset(charset)
} else {
// Here we call the collations API to ensure the collation/charset exist
// and is supported by Vitess.
coll = env.LookupByName(params.Collation)
}

// Check if it's in our map.
characterSet, ok := CharacterSetMap[strings.ToLower(cs)]
if ok {
return characterSet, nil
if coll == nil {
// The given collation is most likely unknown or unsupported, we need to fail.
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot resolve collation: '%s'", params.Collation)
}

// As a fallback, try to parse a number. So we support more values.
if i, err := strconv.ParseInt(cs, 10, 8); err == nil {
return uint8(i), nil
// We send a query to MySQL to set the connection's collation.
// See: https://dev.mysql.com/doc/refman/8.0/en/charset-connection.html
querySetCollation := fmt.Sprintf("SET collation_connection = %s;", coll.Name())
if _, err := c.ExecuteFetch(querySetCollation, 1, false); err != nil {
return err
}

// No luck.
return 0, NewSQLError(CRCantReadCharset, SSUnknownSQLState, "failed to interpret character set '%v'. Try using an integer value if needed", cs)
// The collation environment and the collation ID are stored on the connection itself.
// We will use them to verify that execution requests issued by VTGate match the
// same collation as the one used to communicate with MySQL.
c.CollationEnvironment = env
c.Collation = coll.ID()
return nil
}

// getHandshakeCharacterSet resolves DefaultCollation through the default
// collation environment and returns its ID narrowed to the single byte that
// is placed in the handshake packet's character set field.
//
// An error is returned when DefaultCollation cannot be resolved, or when its
// ID is larger than 255 and therefore would not fit in that byte.
func getHandshakeCharacterSet() (uint8, error) {
	handshakeColl := collations.Default().LookupByName(DefaultCollation)

	switch {
	case handshakeColl == nil:
		// DefaultCollation is chosen by the code base, not by the end user,
		// so failing to resolve it is not expected in practice.
		return 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot resolve collation ID for collation: '%s'", DefaultCollation)
	case handshakeColl.ID() > 255:
		// Likewise unexpected: the handshake collation must fit in 8 bits.
		return 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "collation ID for '%s' will overflow, value: %d", DefaultCollation, handshakeColl.ID())
	}

	return uint8(handshakeColl.ID()), nil
}

// clientHandshake handles the client side of the handshake.
// Note the connection can be closed while this is running.
// Returns a SQLError.
func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error {
func (c *Conn) clientHandshake(params *ConnParams) error {
// Wait for the server initial handshake packet, and parse it.
data, err := c.readPacket()
if err != nil {
Expand All @@ -252,6 +303,28 @@ func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error {
c.Capabilities = capabilities & (CapabilityClientDeprecateEOF)
}

// The MySQL handshake packet uses the "character set" field to define
// which character set must be used. But the value we give to this field
// corresponds in fact to the collation ID. MySQL will then deduce what the
// character set for this collation ID is, and use it.
// Problem is, this field is 8-bits long meaning that the ID can range from
// 0 to 255, which is smaller than the range of IDs we support.
// If, for instance, we used the collation "utf8mb4_0900_as_ci" that has an
// ID equal to 305, the value would overflow when transformed into an 8 bits
// integer.
// To alleviate this issue, we use a default and safe collation for the handshake
// and once the connection is established, we will manually set the collation.
// The code below gets that default character set for the Handshake packet.
//
// Note: this character set might be different from the one we will use
// for the connection.
//
// See: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html
characterSet, err := getHandshakeCharacterSet()
if err != nil {
return err
}

// Handle switch to SSL if necessary.
if params.SslEnabled() {
// If client asked for SSL, but server doesn't support it,
Expand Down Expand Up @@ -720,7 +793,7 @@ func (c *Conn) handleAuthMoreDataPacket(data byte, params *ConnParams) error {
// Encrypt password with public key
enc, err := EncryptPasswordWithPublicKey(c.salt, []byte(params.Pass), pub)
if err != nil {
return vterrors.Errorf(vtrpc.Code_INTERNAL, "error encrypting password with public key: %v", err)
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error encrypting password with public key: %v", err)
}
// Write encrypted password
if err := c.writeScrambledPassword(enc); err != nil {
Expand All @@ -738,7 +811,7 @@ func parseAuthSwitchRequest(data []byte) (AuthMethodDescription, []byte, error)
pos := 1
pluginName, pos, ok := readNullString(data, pos)
if !ok {
return "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "cannot get plugin name from AuthSwitchRequest: %v", data)
return "", nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot get plugin name from AuthSwitchRequest: %v", data)
}

// If this was a request with a salt in it, max 20 bytes
Expand All @@ -755,7 +828,7 @@ func (c *Conn) requestPublicKey() (rsaKey *rsa.PublicKey, err error) {
data, pos := c.startEphemeralPacketWithHeader(1)
data[pos] = 0x02
if err := c.writeEphemeralPacket(); err != nil {
return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "error sending public key request packet: %v", err)
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error sending public key request packet: %v", err)
}

response, err := c.readPacket()
Expand All @@ -771,7 +844,7 @@ func (c *Conn) requestPublicKey() (rsaKey *rsa.PublicKey, err error) {
block, _ := pem.Decode(response[1:])
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "failed to parse public key from server: %v", err)
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to parse public key from server: %v", err)
}

return pub.(*rsa.PublicKey), nil
Expand All @@ -785,7 +858,7 @@ func (c *Conn) writeClearTextPassword(params *ConnParams) error {
pos = writeNullString(data, pos, params.Pass)
// Sanity check.
if pos != len(data) {
return vterrors.Errorf(vtrpc.Code_INTERNAL, "error building ClearTextPassword packet: got %v bytes expected %v", pos, len(data))
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error building ClearTextPassword packet: got %v bytes expected %v", pos, len(data))
}
return c.writeEphemeralPacket()
}
Expand All @@ -797,7 +870,7 @@ func (c *Conn) writeScrambledPassword(scrambledPassword []byte) error {
pos += copy(data[pos:], scrambledPassword)
// Sanity check.
if pos != len(data) {
return vterrors.Errorf(vtrpc.Code_INTERNAL, "error building %v packet: got %v bytes expected %v", c.authPluginName, pos, len(data))
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error building %v packet: got %v bytes expected %v", c.authPluginName, pos, len(data))
}
return c.writeEphemeralPacket()
}
24 changes: 13 additions & 11 deletions go/mysql/collations/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package collations

import (
"fmt"
"strings"
"sync"
)
Expand Down Expand Up @@ -124,29 +123,32 @@ func fetchCacheEnvironment(version collver) *Environment {
// NewEnvironment creates a collation Environment for the given MySQL version string.
// The version string must be in the format that is sent by the server as the version packet
// when opening a new MySQL connection
func NewEnvironment(serverVersion string) (*Environment, error) {
var version collver
func NewEnvironment(serverVersion string) *Environment {
var version collver = collverMySQL56
switch {
case strings.HasSuffix(serverVersion, "-ripple"):
// the ripple binlog server can mask the actual version of mysqld;
// assume we have the highest
version = collverMySQL80
case strings.Contains(serverVersion, "MariaDB"):
switch {
case strings.HasPrefix(serverVersion, "10.0."):
case strings.Contains(serverVersion, "10.0."):
version = collverMariaDB100
case strings.HasPrefix(serverVersion, "10.1."):
case strings.Contains(serverVersion, "10.1."):
version = collverMariaDB101
case strings.HasPrefix(serverVersion, "10.2."):
case strings.Contains(serverVersion, "10.2."):
version = collverMariaDB102
case strings.HasPrefix(serverVersion, "10.3."):
case strings.Contains(serverVersion, "10.3."):
version = collverMariaDB103
}
case strings.HasPrefix(serverVersion, "5.6."):
version = collverMySQL56
case strings.HasPrefix(serverVersion, "5.7."):
version = collverMySQL57
case strings.HasPrefix(serverVersion, "8.0."):
version = collverMySQL80
}
if version == collverInvalid {
return nil, fmt.Errorf("unsupported ServerVersion value: %q", serverVersion)
}
return fetchCacheEnvironment(version), nil
return fetchCacheEnvironment(version)
}

func makeEnv(version collver) *Environment {
Expand Down
Loading

0 comments on commit 9085201

Please sign in to comment.