Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement compress() , uncompress(), and uncompressed_length() #2668

Merged
merged 10 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions enginetest/queries/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -9995,6 +9995,30 @@ from typestable`,
{"latin1"},
},
},
{
Query: "select uncompress(compress('thisisastring'))",
Expected: []sql.Row{
{[]byte{0x74, 0x68, 0x69, 0x73, 0x69, 0x73, 0x61, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67}},
},
},
{
Query: "select length(compress(repeat('a', 1000)))",
Expected: []sql.Row{
{24}, // 21 in MySQL because of library implementation differences
},
},
{
Query: "select length(uncompress(compress(repeat('a', 1000))))",
Expected: []sql.Row{
{1000},
},
},
{
Query: "select uncompressed_length(compress(repeat('a', 1000)))",
Expected: []sql.Row{
{uint32(1000)},
},
},
}

var KeylessQueries = []QueryTest{
Expand Down
239 changes: 239 additions & 0 deletions sql/expression/function/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
package function

import (
"bytes"
"compress/zlib"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/binary"
"encoding/hex"
"fmt"
"hash"
Expand Down Expand Up @@ -232,3 +235,239 @@ func (f *SHA2) WithChildren(children ...sql.Expression) (sql.Expression, error)
}
return NewSHA2(children[0], children[1]), nil
}

// Compress function returns the compressed binary string of the input.
// https://dev.mysql.com/doc/refman/8.4/en/encryption-functions.html#function_compress
type Compress struct {
*UnaryFunc
}

var _ sql.FunctionExpression = (*Compress)(nil)
var _ sql.CollationCoercible = (*Compress)(nil)

// NewCompress returns a new Compress function expression
func NewCompress(arg sql.Expression) sql.Expression {
return &Compress{NewUnaryFunc(arg, "Compress", types.LongBlob)}
}

// Description implements sql.FunctionExpression
func (f *Compress) Description() string {
return "compresses a string and returns the result as a binary string."
}

func (f *Compress) Type() sql.Type {
return types.LongBlob
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (*Compress) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return ctx.GetCollation(), 4
}

// Eval implements sql.Expression
func (f *Compress) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
arg, err := f.EvalChild(ctx, row)
if err != nil {
return nil, err
}
if arg == nil {
return nil, nil
}

val, _, err := types.LongBlob.Convert(arg)
if err != nil {
return nil, err
}
valBytes := val.([]byte)
if len(valBytes) == 0 {
return []byte{}, nil
}

// TODO: the golang standard library implementation of zlib is different than the original C implementation that
// MySQL uses. This means that the output of compressed data will be different. However, this library claims to be
// able to uncompress MySQL compressed data. There are unit tests for this in hash_test.go.
var buf bytes.Buffer
writer, err := zlib.NewWriterLevel(&buf, zlib.BestCompression)
if err != nil {
return nil, err
}
_, err = writer.Write(valBytes)
if err != nil {
return nil, err
}
err = writer.Close()
if err != nil {
return nil, err
}

// Prepend length of original string
lenHeader := make([]byte, 4)
binary.LittleEndian.PutUint32(lenHeader, uint32(len(valBytes)))
res := append(lenHeader, buf.Bytes()...)
return res, nil
}

// WithChildren implements sql.Expression
func (f *Compress) WithChildren(children ...sql.Expression) (sql.Expression, error) {
if len(children) != 1 {
return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 1)
}
return NewCompress(children[0]), nil
}

// Uncompress function returns the binary string from the compressed input.
// https://dev.mysql.com/doc/refman/8.4/en/encryption-functions.html#function_uncompress
type Uncompress struct {
*UnaryFunc
}

var _ sql.FunctionExpression = (*Uncompress)(nil)
var _ sql.CollationCoercible = (*Uncompress)(nil)

const (
compressHeaderSize = 4
compressMaxSize = 0x04000000
)

// NewUncompress returns a new Uncompress function expression
func NewUncompress(arg sql.Expression) sql.Expression {
return &Uncompress{NewUnaryFunc(arg, "Uncompress", types.LongBlob)}
}

// Description implements sql.FunctionExpression
func (f *Uncompress) Description() string {
return "uncompresses a string compressed by the COMPRESS() function."
}

func (f *Uncompress) Type() sql.Type {
return types.LongBlob
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (*Uncompress) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return ctx.GetCollation(), 4
}

// Eval implements sql.Expression
func (f *Uncompress) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
arg, err := f.EvalChild(ctx, row)
if err != nil {
return nil, err
}
if arg == nil {
return nil, nil
}

val, _, err := types.LongBlob.Convert(arg)
if err != nil {
ctx.Warn(1258, err.Error())
return nil, nil
}
valBytes := val.([]byte)
if len(valBytes) == 0 {
return []byte{}, nil
}
if len(valBytes) <= compressHeaderSize {
ctx.Warn(1258, "input data corrupted")
return nil, nil
}

var inBuf bytes.Buffer
inBuf.Write(valBytes[compressHeaderSize:]) // skip length header
reader, err := zlib.NewReader(&inBuf)
if err != nil {
return nil, err
}
defer reader.Close()

outLen := binary.LittleEndian.Uint32(valBytes[:compressHeaderSize])
if outLen > compressMaxSize {
ctx.Warn(1258, fmt.Sprintf("Uncompressed data too large; the maximum size is %d", compressMaxSize))
return nil, nil
}

outBuf := make([]byte, outLen)
readLen, err := reader.Read(outBuf)
if err != nil && err != io.EOF {
ctx.Warn(1258, err.Error())
return nil, nil
}
// if we don't receive io.EOF, then received outLen was too small
if err == nil {
ctx.Warn(1258, "not enough room in output buffer")
return nil, nil
}
return outBuf[:readLen], nil
}

// WithChildren implements sql.Expression
func (f *Uncompress) WithChildren(children ...sql.Expression) (sql.Expression, error) {
if len(children) != 1 {
return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 1)
}
return NewUncompress(children[0]), nil
}

// UncompressedLength function returns the length of the original string from the compressed string input.
// https://dev.mysql.com/doc/refman/8.4/en/encryption-functions.html#function_uncompress
type UncompressedLength struct {
*UnaryFunc
}

var _ sql.FunctionExpression = (*UncompressedLength)(nil)
var _ sql.CollationCoercible = (*UncompressedLength)(nil)

// NewUncompressedLength returns a new UncompressedLength function expression
func NewUncompressedLength(arg sql.Expression) sql.Expression {
return &UncompressedLength{NewUnaryFunc(arg, "UncompressedLength", types.Uint32)}
}

// Description implements sql.FunctionExpression
func (f *UncompressedLength) Description() string {
return "returns length of original uncompressed string from compressed string input."
}

func (f *UncompressedLength) Type() sql.Type {
return types.Uint32
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (*UncompressedLength) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return ctx.GetCollation(), 4
}

// Eval implements sql.Expression
func (f *UncompressedLength) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
arg, err := f.EvalChild(ctx, row)
if err != nil {
return nil, err
}
if arg == nil {
return nil, nil
}

val, _, err := types.LongBlob.Convert(arg)
if err != nil {
ctx.Warn(1258, err.Error())
return nil, nil
}
valBytes := val.([]byte)
if len(valBytes) == 0 {
return uint32(0), nil
}
if len(valBytes) <= compressHeaderSize {
ctx.Warn(1258, "input data corrupted")
return nil, nil
}

outLen := binary.LittleEndian.Uint32(valBytes[:compressHeaderSize])
return outLen, nil
}

// WithChildren implements sql.Expression
func (f *UncompressedLength) WithChildren(children ...sql.Expression) (sql.Expression, error) {
if len(children) != 1 {
return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 1)
}
return NewUncompressedLength(children[0]), nil
}
Loading
Loading