diff --git a/lib/column/fixed_string.go b/lib/column/fixed_string.go index d9c5e84dcf..5db155968f 100644 --- a/lib/column/fixed_string.go +++ b/lib/column/fixed_string.go @@ -19,6 +19,7 @@ package column import ( "database/sql" + "database/sql/driver" "encoding" "fmt" "github.com/ClickHouse/ch-go/proto" @@ -153,15 +154,38 @@ func (col *FixedString) AppendRow(v interface{}) (err error) { return err } default: - if s, ok := v.(fmt.Stringer); ok { - return col.AppendRow(s.String()) - } else { + if s, ok := v.(driver.Valuer); ok { + val, err := s.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "String", + From: fmt.Sprintf("%T", s), + Hint: "could not get driver.Valuer value", + } + } + + if s, ok := val.(string); ok { + return col.AppendRow(s) + } + return &ColumnConverterError{ Op: "AppendRow", - To: "FixedString", + To: "String", From: fmt.Sprintf("%T", v), + Hint: "driver.Valuer value is not a string", } } + + if s, ok := v.(fmt.Stringer); ok { + return col.AppendRow(s.String()) + } + + return &ColumnConverterError{ + Op: "AppendRow", + To: "String", + From: fmt.Sprintf("%T", v), + } } col.col.Append(data) return nil diff --git a/lib/column/string.go b/lib/column/string.go index af358cc9f8..8639c19e2e 100644 --- a/lib/column/string.go +++ b/lib/column/string.go @@ -19,6 +19,7 @@ package column import ( "database/sql" + "database/sql/driver" "encoding" "fmt" "github.com/ClickHouse/ch-go/proto" @@ -115,15 +116,38 @@ func (col *String) AppendRow(v interface{}) error { case nil: col.col.Append("") default: - if s, ok := v.(fmt.Stringer); ok { - return col.AppendRow(s.String()) - } else { + if s, ok := v.(driver.Valuer); ok { + val, err := s.Value() + if err != nil { + return &ColumnConverterError{ + Op: "AppendRow", + To: "String", + From: fmt.Sprintf("%T", s), + Hint: "could not get driver.Valuer value", + } + } + + if s, ok := val.(string); ok { + return col.AppendRow(s) + } + return &ColumnConverterError{ Op: "AppendRow", To: "String", From: fmt.Sprintf("%T", v), + Hint: "driver.Valuer value is not a string", } } + + if s, ok := v.(fmt.Stringer); ok { + return col.AppendRow(s.String()) + } + + return &ColumnConverterError{ + Op: "AppendRow", + To: "String", + From: fmt.Sprintf("%T", v), + } } return nil } diff --git a/tests/fixed_string_test.go b/tests/fixed_string_test.go index 3bdbe628bb..a38d9371e9 100644 --- a/tests/fixed_string_test.go +++ b/tests/fixed_string_test.go @@ -20,6 +20,7 @@ package tests import ( "context" "crypto/rand" + "fmt" "github.com/stretchr/testify/require" "testing" @@ -362,3 +363,44 @@ func TestFixedStringFlush(t *testing.T) { } require.Equal(t, 1000, i) } + +func TestFixedStringFromDriverValuerType(t *testing.T) { + conn, err := GetConnection("native", nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + + require.NoError(t, err) + require.NoError(t, conn.Ping(ctx)) + if !CheckMinServerServerVersion(conn, 21, 9, 0) { + t.Skip(fmt.Errorf("unsupported clickhouse version")) + return + } + const ddl = ` + CREATE TABLE test_fixed_string ( + Col1 FixedString(5) + , Col2 FixedString(5) + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE test_fixed_string") + }() + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_fixed_string") + require.NoError(t, err) + + type data struct { + Col1 string `ch:"Col1"` + Col2 testStringSerializer `ch:"Col2"` + } + require.NoError(t, batch.AppendStruct(&data{ + Col1: "Value", + Col2: testStringSerializer{"Value"}, + })) + require.NoError(t, batch.Send()) + + var dest data + require.NoError(t, conn.QueryRow(ctx, "SELECT * FROM test_fixed_string").ScanStruct(&dest)) + assert.Equal(t, "Value", dest.Col1) + assert.Equal(t, testStringSerializer{"Value"}, dest.Col2) +} diff --git a/tests/string_test.go b/tests/string_test.go index d7f4c1e8a9..d9a5a2523a 100644 --- a/tests/string_test.go +++ b/tests/string_test.go @@ -20,6 +20,7 @@ package tests import ( "context" "database/sql" + "database/sql/driver" "fmt" "github.com/stretchr/testify/require" "testing" @@ -309,3 +310,60 @@ func TestStringFlush(t *testing.T) { } require.Equal(t, 1000, i) } + +type testStringSerializer struct { + val string +} + +func (c testStringSerializer) Value() (driver.Value, error) { + return c.val, nil +} + +func (c *testStringSerializer) Scan(src any) error { + if t, ok := src.(string); ok { + *c = testStringSerializer{val: t} + return nil + } + return fmt.Errorf("cannot scan %T into testStringSerializer", src) +} + +func TestStringFromDriverValuerType(t *testing.T) { + conn, err := GetConnection("native", nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + + require.NoError(t, err) + require.NoError(t, conn.Ping(ctx)) + if !CheckMinServerServerVersion(conn, 21, 9, 0) { + t.Skip(fmt.Errorf("unsupported clickhouse version")) + return + } + const ddl = ` + CREATE TABLE test_string ( + Col1 String + , Col2 String + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE test_string") + }() + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_string") + require.NoError(t, err) + + type data struct { + Col1 string `ch:"Col1"` + Col2 testStringSerializer `ch:"Col2"` + } + require.NoError(t, batch.AppendStruct(&data{ + Col1: "Value", + Col2: testStringSerializer{"Value"}, + })) + require.NoError(t, batch.Send()) + + var dest data + require.NoError(t, conn.QueryRow(ctx, "SELECT * FROM test_string").ScanStruct(&dest)) + assert.Equal(t, "Value", dest.Col1) + assert.Equal(t, testStringSerializer{"Value"}, dest.Col2) +}