From a94a98f115624c0ef90378e595f52d74fc215edc Mon Sep 17 00:00:00 2001 From: Song Gao Date: Fri, 8 Nov 2024 16:02:39 +0800 Subject: [PATCH] fix Signed-off-by: Song Gao --- extensions/impl/sql/source.go | 17 +++++++++++++++-- extensions/impl/sql/source_test.go | 28 +++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/extensions/impl/sql/source.go b/extensions/impl/sql/source.go index 95b0df763b..e6ac96b86b 100644 --- a/extensions/impl/sql/source.go +++ b/extensions/impl/sql/source.go @@ -226,7 +226,8 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns func prepareValues(ctx api.StreamContext, values []interface{}, columnTypes []*sql.ColumnType, columns []string) { if len(columnTypes) > 0 { for idx, columnType := range columnTypes { - if got := buildScanValueByColumnType(ctx, columnType.Name(), columnType.DatabaseTypeName()); got != nil { + nullable, ok := columnType.Nullable() + if got := buildScanValueByColumnType(ctx, columnType.Name(), columnType.DatabaseTypeName(), nullable && ok); got != nil { values[idx] = got continue } @@ -264,15 +265,27 @@ func (sc *SQLConf) resolveDBURL(props map[string]any) (map[string]any, error) { return props, nil } -func buildScanValueByColumnType(ctx api.StreamContext, colName, colType string) interface{} { +func buildScanValueByColumnType(ctx api.StreamContext, colName, colType string, nullable bool) interface{} { switch strings.ToUpper(colType) { case "CHAR", "VARCHAR", "NCHAR", "NVARCHAR", "TEXT", "NTEXT": + if nullable { + return &sql.NullString{} + } return new(string) case "DECIMAL", "NUMERIC", "FLOAT", "REAL": + if nullable { + return &sql.NullFloat64{} + } return new(float64) case "BOOL": + if nullable { + return &sql.NullBool{} + } return new(bool) case "INT", "BIGINT", "SMALLINT", "TINYINT": + if nullable { + return &sql.NullInt64{} + } return new(int64) default: ctx.GetLogger().Infof("sql source meet column %v unknown columnType:%v", colName, colType) diff --git a/extensions/impl/sql/source_test.go b/extensions/impl/sql/source_test.go index b09e904e87..94dce183b7 100644 --- a/extensions/impl/sql/source_test.go +++ b/extensions/impl/sql/source_test.go @@ -15,6 +15,7 @@ package sql import ( + "database/sql" "errors" "fmt" "testing" @@ -396,7 +397,32 @@ func TestBuildScanValueByColumnType(t *testing.T) { } ctx := mockContext.NewMockContext("1", "2") for _, tc := range testcases { - got := buildScanValueByColumnType(ctx, "col", tc.colType) + got := buildScanValueByColumnType(ctx, "col", tc.colType, false) + require.Equal(t, tc.exp, got) + } + testcases2 := []struct { + colType string + exp interface{} + }{ + { + colType: "varchar", + exp: &sql.NullString{}, + }, + { + colType: "DECIMAL", + exp: &sql.NullFloat64{}, + }, + { + colType: "BOOL", + exp: &sql.NullBool{}, + }, + { + colType: "int", + exp: &sql.NullInt64{}, + }, + } + for _, tc := range testcases2 { + got := buildScanValueByColumnType(ctx, "col", tc.colType, true) require.Equal(t, tc.exp, got) } }