Skip to content

Commit 671214b

Browse files
authored
Merge pull request #12 from agoncear-mwb/main
Implement QueryRow and Exec methods of sql driver interface
2 parents f3d6a72 + d1572fe commit 671214b

File tree

2 files changed

+187
-7
lines changed

2 files changed

+187
-7
lines changed

chdb/driver/driver.go

+97-7
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,60 @@ func init() {
7979
sql.Register("chdb", Driver{})
8080
}
8181

82+
// Row is the result of calling [DB.QueryRow] to select a single row.
83+
type singleRow struct {
84+
// One of these two will be non-nil:
85+
err error // deferred error for easy chaining
86+
rows driver.Rows
87+
}
88+
89+
// Scan copies the columns from the matched row into the values
90+
// pointed at by dest. See the documentation on [Rows.Scan] for details.
91+
// If more than one row matches the query,
92+
// Scan uses the first row and discards the rest. If no row matches
93+
// the query, Scan returns [ErrNoRows].
94+
func (r *singleRow) Scan(dest ...any) error {
95+
if r.err != nil {
96+
return r.err
97+
}
98+
vals := make([]driver.Value, 0)
99+
for _, v := range dest {
100+
vals = append(vals, v)
101+
}
102+
err := r.rows.Next(vals)
103+
if err != nil {
104+
return err
105+
}
106+
// Make sure the query can be processed to completion with no errors.
107+
return r.rows.Close()
108+
}
109+
110+
// Err provides a way for wrapping packages to check for
111+
// query errors without calling [Row.Scan].
112+
// Err returns the error, if any, that was encountered while running the query.
113+
// If this error is not nil, this error will also be returned from [Row.Scan].
114+
func (r *singleRow) Err() error {
115+
return r.err
116+
}
117+
118+
type execResult struct {
119+
err error
120+
}
121+
122+
func (e *execResult) LastInsertId() (int64, error) {
123+
if e.err != nil {
124+
return 0, e.err
125+
}
126+
return -1, fmt.Errorf("does not support LastInsertId")
127+
128+
}
129+
func (e *execResult) RowsAffected() (int64, error) {
130+
if e.err != nil {
131+
return 0, e.err
132+
}
133+
return -1, fmt.Errorf("does not support RowsAffected")
134+
}
135+
82136
type queryHandle func(string, ...string) (*chdbstable.LocalResult, error)
83137

84138
type connector struct {
@@ -192,6 +246,18 @@ type conn struct {
192246
QueryFun queryHandle
193247
}
194248

249+
func prepareValues(values []driver.Value) []driver.NamedValue {
250+
namedValues := make([]driver.NamedValue, len(values))
251+
for i, value := range values {
252+
namedValues[i] = driver.NamedValue{
253+
// nb: Name field is optional
254+
Ordinal: i,
255+
Value: value,
256+
}
257+
}
258+
return namedValues
259+
}
260+
195261
func (c *conn) Close() error {
196262
return nil
197263
}
@@ -204,15 +270,39 @@ func (c *conn) SetupQueryFun() {
204270
}
205271

206272
func (c *conn) Query(query string, values []driver.Value) (driver.Rows, error) {
207-
namedValues := make([]driver.NamedValue, len(values))
208-
for i, value := range values {
209-
namedValues[i] = driver.NamedValue{
210-
// nb: Name field is optional
211-
Ordinal: i,
212-
Value: value,
273+
return c.QueryContext(context.Background(), query, prepareValues(values))
274+
}
275+
276+
func (c *conn) QueryRow(query string, values []driver.Value) *singleRow {
277+
return c.QueryRowContext(context.Background(), query, values)
278+
}
279+
280+
func (c *conn) Exec(query string, values []driver.Value) (sql.Result, error) {
281+
return c.ExecContext(context.Background(), query, prepareValues(values))
282+
}
283+
284+
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
285+
_, err := c.QueryContext(ctx, query, args)
286+
if err != nil && err.Error() != "result is nil" {
287+
return nil, err
288+
}
289+
return &execResult{
290+
err: nil,
291+
}, nil
292+
}
293+
294+
func (c *conn) QueryRowContext(ctx context.Context, query string, values []driver.Value) *singleRow {
295+
296+
v, err := c.QueryContext(ctx, query, prepareValues(values))
297+
if err != nil {
298+
return &singleRow{
299+
err: err,
300+
rows: nil,
213301
}
214302
}
215-
return c.QueryContext(context.Background(), query, namedValues)
303+
return &singleRow{
304+
rows: v,
305+
}
216306
}
217307

218308
func (c *conn) compileArguments(query string, args []driver.NamedValue) (string, error) {

chdb/driver/driver_test.go

+90
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,93 @@ func TestDbWithSession(t *testing.T) {
167167
count++
168168
}
169169
}
170+
171+
func TestQueryRow(t *testing.T) {
172+
sessionDir, err := os.MkdirTemp("", "unittest-sessiondata")
173+
if err != nil {
174+
t.Fatalf("create temp directory fail, err: %s", err)
175+
}
176+
defer os.RemoveAll(sessionDir)
177+
session, err := chdb.NewSession(sessionDir)
178+
if err != nil {
179+
t.Fatalf("new session fail, err: %s", err)
180+
}
181+
defer session.Cleanup()
182+
183+
session.Query("USE testdb; INSERT INTO testtable VALUES (1), (2), (3);")
184+
185+
ret, err := session.Query("SELECT * FROM testtable;")
186+
if err != nil {
187+
t.Fatalf("Query fail, err: %s", err)
188+
}
189+
if string(ret.Buf()) != "1\n2\n3\n" {
190+
t.Errorf("Query result should be 1\n2\n3\n, got %s", string(ret.Buf()))
191+
}
192+
db, err := sql.Open("chdb", fmt.Sprintf("session=%s", sessionDir))
193+
if err != nil {
194+
t.Fatalf("open db fail, err: %s", err)
195+
}
196+
if db.Ping() != nil {
197+
t.Fatalf("ping db fail, err: %s", err)
198+
}
199+
rows := db.QueryRow("select * from testtable;")
200+
201+
var bar = 0
202+
var count = 1
203+
err = rows.Scan(&bar)
204+
if err != nil {
205+
t.Fatalf("scan fail, err: %s", err)
206+
}
207+
if bar != count {
208+
t.Fatalf("result is not match, want: %d actual: %d", count, bar)
209+
}
210+
err2 := rows.Scan(&bar)
211+
if err2 == nil {
212+
t.Fatalf("QueryRow method should return only one item")
213+
}
214+
215+
}
216+
217+
func TestExec(t *testing.T) {
218+
sessionDir, err := os.MkdirTemp("", "unittest-sessiondata")
219+
if err != nil {
220+
t.Fatalf("create temp directory fail, err: %s", err)
221+
}
222+
defer os.RemoveAll(sessionDir)
223+
session, err := chdb.NewSession(sessionDir)
224+
if err != nil {
225+
t.Fatalf("new session fail, err: %s", err)
226+
}
227+
defer session.Cleanup()
228+
session.Query("CREATE DATABASE IF NOT EXISTS testdb; " +
229+
"CREATE TABLE IF NOT EXISTS testdb.testtable (id UInt32) ENGINE = MergeTree() ORDER BY id;")
230+
231+
db, err := sql.Open("chdb", fmt.Sprintf("session=%s", sessionDir))
232+
if err != nil {
233+
t.Fatalf("open db fail, err: %s", err)
234+
}
235+
if db.Ping() != nil {
236+
t.Fatalf("ping db fail, err: %s", err)
237+
}
238+
239+
_, err = db.Exec("INSERT INTO testdb.testtable VALUES (1), (2), (3);")
240+
if err != nil {
241+
t.Fatalf("exec failed, err: %s", err)
242+
}
243+
rows := db.QueryRow("select * from testdb.testtable;")
244+
245+
var bar = 0
246+
var count = 1
247+
err = rows.Scan(&bar)
248+
if err != nil {
249+
t.Fatalf("scan fail, err: %s", err)
250+
}
251+
if bar != count {
252+
t.Fatalf("result is not match, want: %d actual: %d", count, bar)
253+
}
254+
err2 := rows.Scan(&bar)
255+
if err2 == nil {
256+
t.Fatalf("QueryRow method should return only one item")
257+
}
258+
259+
}

0 commit comments

Comments
 (0)