Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wk989898 committed Oct 25, 2024
1 parent c264a09 commit d218ec5
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 40 deletions.
24 changes: 24 additions & 0 deletions pkg/sink/codec/common/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,3 +400,27 @@ func UnsafeStringToBytes(s string) []byte {
}{s, len(s)},
))
}

// UnsafeStringToBinary renders the bytes of s as a string of '0' and '1'
// digits, suitable for parsing with strconv.ParseUint(..., 2, 64).
//
// Every backslash byte is removed from s first (see escapeBackslash). The
// first remaining byte is written without leading zeros; each following byte
// is zero-padded to eight digits. An empty string is returned when no bytes
// remain after the backslashes are removed.
//
// The name is kept for backward compatibility: the function performs no
// unsafe conversion, and it does allocate a new string for the result.
func UnsafeStringToBinary(s string) string {
	raw := escapeBackslash(s)
	if len(raw) == 0 {
		return ""
	}

	var sb strings.Builder
	// Upper bound: eight digits per byte.
	sb.Grow(8 * len(raw))

	// The leading byte is unpadded so the result has no superfluous leading
	// zeros when that byte is non-zero.
	fmt.Fprintf(&sb, "%b", raw[0])
	for i := 1; i < len(raw); i++ {
		fmt.Fprintf(&sb, "%08b", raw[i])
	}
	return sb.String()
}

// escapeBackslash returns s with every backslash byte ('\\', 0x5C) removed.
//
// Despite its name it does not escape anything; it strips backslashes.
// NOTE(review): a genuine 0x5C data byte in s is dropped as well — confirm
// that callers only pass values in which backslashes are escape artifacts.
func escapeBackslash(s string) string {
	return strings.ReplaceAll(s, `\`, "")
}
53 changes: 45 additions & 8 deletions pkg/sink/codec/debezium/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,31 @@ func (c *dbzCodec) writeDebeziumFieldSchema(
switch col.GetType() {
case mysql.TypeBit:
n := ft.GetFlen()
var v uint64
var err error
if col.GetDefaultValue() != nil {
val, ok := col.GetDefaultValue().(string)
if !ok {
log.Error(
"GetDefaultValue meet error",
zap.Any("column", col.GetName()), zap.Error(err))
return
}
v, err = strconv.ParseUint(common.UnsafeStringToBinary(val), 2, 64)
if err != nil {
log.Error(
"parsing uint from bit meet error",
zap.Any("column", col.GetName()), zap.Error(err))
return
}
}
if n == 1 {
writer.WriteObjectElement(func() {
writer.WriteStringField("type", "boolean")
writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag()))
writer.WriteStringField("field", col.GetName())
if col.GetDefaultValue() != nil {
writer.WriteAnyField("default", col.GetDefaultValue())
writer.WriteBoolField("default", v != 0) // bool
}
})
} else {
Expand All @@ -89,7 +107,13 @@ func (c *dbzCodec) writeDebeziumFieldSchema(
})
writer.WriteStringField("field", col.GetName())
if col.GetDefaultValue() != nil {
writer.WriteAnyField("default", col.GetDefaultValue())
var buf [8]byte
binary.LittleEndian.PutUint64(buf[:], v)
numBytes := n / 8
if n%8 != 0 {
numBytes += 1
}
c.writeBinaryField(writer, "default", buf[:numBytes]) // binary
}
})
}
Expand Down Expand Up @@ -156,7 +180,6 @@ func (c *dbzCodec) writeDebeziumFieldSchema(
if err != nil {
return
}

writer.WriteFloat64Field("default", floatV)
}
})
Expand Down Expand Up @@ -437,8 +460,9 @@ func (c *dbzCodec) writeDebeziumFieldValue(
var v uint64
switch val := col.Value.(type) {
case uint64:
v = val
case string:
hexValue, err := strconv.ParseUint(val, 0, 64)
hexValue, err := strconv.ParseUint(common.UnsafeStringToBinary(val), 2, 64)
if err != nil {
return cerror.ErrDebeziumEncodeFailed.GenWithStack(
"unexpected column value type string for bit column %s, error:%s",
Expand All @@ -460,6 +484,7 @@ func (c *dbzCodec) writeDebeziumFieldValue(
writer.WriteBoolField(col.GetName(), v != 0)
return nil
} else {
// 10110000100001111
var buf [8]byte
binary.LittleEndian.PutUint64(buf[:], v)
numBytes := n / 8
Expand Down Expand Up @@ -710,6 +735,20 @@ func (c *dbzCodec) writeDebeziumFieldValue(
}
return nil

case mysql.TypeDouble, mysql.TypeFloat:
if v, ok := col.Value.(string); ok {
val, err := strconv.ParseFloat(v, 64)
if err != nil {
return cerror.ErrDebeziumEncodeFailed.GenWithStack(
"unexpected column value type string for int column %s",
col.GetName())
}
writer.WriteFloat64Field(col.GetName(), val)
} else {
writer.WriteAnyField(col.GetName(), col.Value)
}
return nil

case mysql.TypeTiDBVectorFloat32:
v, ok := col.Value.(types.VectorFloat32)
if !ok {
Expand Down Expand Up @@ -1223,10 +1262,8 @@ func (c *dbzCodec) EncodeDDLEvent(
// jWriter.WriteAnyField("defaultValueExpression", "CURRENT_TIMESTAMP")
} else if v == "<nil>" {
jWriter.WriteNullField("defaultValueExpression")
} else if col.DefaultValueBit != nil && (strings.HasPrefix(v, "0x")) {
var hexValue int64
hexValue, err = strconv.ParseInt(v, 0, 64)
jWriter.WriteStringField("defaultValueExpression", fmt.Sprintf("%b", hexValue))
} else if col.DefaultValueBit != nil {
jWriter.WriteStringField("defaultValueExpression", common.UnsafeStringToBinary(v))
} else {
jWriter.WriteStringField("defaultValueExpression", v)
}
Expand Down
8 changes: 2 additions & 6 deletions pkg/sink/codec/debezium/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,6 @@ func parseType(c *timodel.ColumnInfo, col *ast.ColumnDef) {
if c.OriginDefaultValue != nil {
c.SetDefaultValue(c.OriginDefaultValue)
}
case mysql.TypeBit:
if c.OriginDefaultValue != nil {
c.SetDefaultValue(c.OriginDefaultValue)
}
default:
}
}
Expand Down Expand Up @@ -146,7 +142,7 @@ func getLen(ft types.FieldType) int {
case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDuration:
return decimal
case mysql.TypeBit, mysql.TypeVarchar, mysql.TypeString, mysql.TypeNewDecimal, mysql.TypeSet,
mysql.TypeVarString, mysql.TypeTiDBVectorFloat32, mysql.TypeTiny, mysql.TypeYear, mysql.TypeShort:
mysql.TypeVarString, mysql.TypeTiDBVectorFloat32, mysql.TypeYear:
return flen
case mysql.TypeEnum:
return 1
Expand All @@ -158,7 +154,7 @@ func getLen(ft types.FieldType) int {
if flen != defaultFlen {
return flen
}
case mysql.TypeLong:
case mysql.TypeLong, mysql.TypeTiny, mysql.TypeShort:
defaultFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(ft.GetType())
if flen != defaultFlen {
return flen
Expand Down
20 changes: 9 additions & 11 deletions tests/integration_tests/debezium/sql/debezium/default_value.sql
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ CREATE TABLE STRING_TABLE (
I VARCHAR(10) NULL DEFAULT '100'
);
INSERT INTO STRING_TABLE
VALUES (1, DEFAULT, DEFAULT ,DEFAULT ,DEFAULT ,DEFAULT, NULL);
VALUES (1, DEFAULT, DEFAULT, DEFAULT, DEFAULT, DEFAULT, NULL);

CREATE TABLE BIT_TABLE (
id int PRIMARY KEY,
Expand All @@ -77,15 +77,13 @@ CREATE TABLE BIT_TABLE (
C BIT(1) DEFAULT 1,
D BIT(1) DEFAULT b'0',
E BIT(1) DEFAULT b'1',
F BIT(1) DEFAULT TRUE,
G BIT(1) DEFAULT FALSE,
H BIT(10) DEFAULT b'101000010',
I BIT(10) DEFAULT NULL,
J BIT(25) DEFAULT b'10110000100001111',
J BIT(25) DEFAULT b'00000000100001111',
K BIT(25) DEFAULT b'10110000100001111'
);
INSERT INTO BIT_TABLE
VALUES (1, false ,DEFAULT ,DEFAULT ,DEFAULT ,DEFAULT ,DEFAULT ,DEFAULT, DEFAULT ,NULL ,DEFAULT, NULL);
VALUES (1, false, DEFAULT, DEFAULT, DEFAULT, DEFAULT, DEFAULT, NULL, DEFAULT, NULL);

CREATE TABLE NUMBER_TABLE (
id int PRIMARY KEY,
Expand All @@ -94,22 +92,22 @@ CREATE TABLE NUMBER_TABLE (
C INTEGER NOT NULL DEFAULT 0,
D BIGINT NOT NULL DEFAULT 20,
E INT NULL DEFAULT NULL,
F INT NULL DEFAULT 30,
G TINYINT(1) NOT NULL DEFAULT TRUE,
H INT(1) NOT NULL DEFAULT TRUE
F INT NULL DEFAULT 30
);
INSERT INTO NUMBER_TABLE
VALUES (1, DEFAULT ,DEFAULT ,DEFAULT ,DEFAULT ,DEFAULT, NULL, DEFAULT, DEFAULT);
VALUES (1, DEFAULT, DEFAULT, DEFAULT, DEFAULT, DEFAULT, NULL);

CREATE TABLE FlOAT_DOUBLE_TABLE (
id int PRIMARY KEY,
F FLOAT NULL DEFAULT 0,
G DOUBLE NOT NULL DEFAULT 1.0,
H DOUBLE NULL DEFAULT 3.0
G DOUBLE NOT NULL DEFAULT 1.1,
H DOUBLE NULL DEFAULT 3.3
);
INSERT INTO FlOAT_DOUBLE_TABLE
VALUES (1, DEFAULT, DEFAULT, NULL);

-- Setting sql_mode REAL_AS_FLOAT is necessary so that REAL is treated as FLOAT instead of DOUBLE.
set @@session.sql_mode=concat(@@session.sql_mode, ',REAL_AS_FLOAT');
CREATE TABLE REAL_TABLE (
id int PRIMARY KEY,
A REAL NOT NULL DEFAULT 1,
Expand Down
29 changes: 14 additions & 15 deletions tests/integration_tests/debezium/src/test_cases.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ import (
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/format"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/segmentio/kafka-go"
"go.uber.org/zap"
)

var timeOut = time.Second * 30
var timeOut = time.Second * 10

var (
nFailed = 0
Expand Down Expand Up @@ -91,6 +92,11 @@ func runAllTestCases(dir string) bool {

for _, path := range files {
logger.Info("Run", zap.String("case", path))
failed := runTestCase(path)
if failed {
logger.Info("failed", zap.String("case", path))
return false
}
}

if nFailed > 0 {
Expand Down Expand Up @@ -165,18 +171,18 @@ func runTestCase(testCasePath string) bool {
}

func fetchNextCDCRecord(reader *kafka.Reader, kind Kind, timeout time.Duration) (map[string]any, map[string]any, bool, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
for {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
m, err := reader.FetchMessage(ctx)
cancel()
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return nil, nil, false, nil
}
return nil, nil, false, fmt.Errorf("Failed to read CDC record of %s: %w", kind, err)
}

if err = reader.CommitMessages(context.Background(), m); err != nil {
if err = reader.CommitMessages(ctx, m); err != nil {
return nil, nil, false, fmt.Errorf("Failed to commit CDC record of %s: %w", kind, err)
}

Expand Down Expand Up @@ -241,18 +247,10 @@ func fetchNextCDCRecord(reader *kafka.Reader, kind Kind, timeout time.Duration)
col["typeName"] = replaceString(col["typeName"], "NUMERIC", "DECIMAL")
col["typeExpression"] = replaceString(col["typeExpression"], "NUMERIC", "DECIMAL")
col["jdbcType"] = float64(3)
case "NVARCHAR":
col["typeName"] = replaceString(col["typeName"], "NVARCHAR", "VARCHAR")
col["typeExpression"] = replaceString(col["typeExpression"], "NVARCHAR", "VARCHAR")
col["jdbcType"] = float64(12)
case "NCHAR":
col["typeName"] = replaceString(col["typeName"], "NCHAR", "CHAR")
col["typeExpression"] = replaceString(col["typeExpression"], "NCHAR", "CHAR")
col["jdbcType"] = float64(1)
case "REAL":
col["typeName"] = replaceString(col["typeName"], "REAL", "DOUBLE")
col["typeExpression"] = replaceString(col["typeExpression"], "REAL", "DOUBLE")
col["jdbcType"] = float64(7)
col["typeName"] = replaceString(col["typeName"], "REAL", "FLOAT")
col["typeExpression"] = replaceString(col["typeExpression"], "REAL", "FLOAT")
col["jdbcType"] = float64(6)
}
}
}
Expand Down Expand Up @@ -284,6 +282,7 @@ func printRecord(obj any) {

func normalizeSQL(sql string) string {
p := parser.New()
p.SetSQLMode(mysql.ModeRealAsFloat) // necessary
stmt, err := p.ParseOneStmt(sql, "", "")
buf := new(bytes.Buffer)
if err != nil {
Expand Down

0 comments on commit d218ec5

Please sign in to comment.