-
Notifications
You must be signed in to change notification settings - Fork 15
/
dynamic.go
329 lines (271 loc) · 9.1 KB
/
dynamic.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
// SPDX-FileCopyrightText: 2020 SAP SE
// SPDX-FileCopyrightText: 2021 SAP SE
// SPDX-FileCopyrightText: 2022 SAP SE
// SPDX-FileCopyrightText: 2023 SAP SE
//
// SPDX-License-Identifier: Apache-2.0
package ase
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"io"
"github.com/SAP/go-dblib"
"github.com/SAP/go-dblib/asetypes"
"github.com/SAP/go-dblib/namepool"
"github.com/SAP/go-dblib/tds"
)
// Interface satisfaction checks.
var (
_ driver.Stmt = (*Stmt)(nil)
_ driver.StmtExecContext = (*Stmt)(nil)
_ driver.StmtQueryContext = (*Stmt)(nil)
_ driver.NamedValueChecker = (*Stmt)(nil)
stmtIdPool = namepool.Pool("stmt%d")
)
// Stmt implements the driver.Stmt interface.
//
// A Stmt is a dynamic statement prepared on the server through its
// connection. It is created by Conn.NewStmt (and therefore by
// Prepare/PrepareContext). It should be closed to deallocate it on the
// server and to return a pool-acquired name to stmtIdPool.
type Stmt struct {
// conn is the connection whose Channel is used to exchange packages
// with the server.
conn *Conn
// stmtId is the name acquired from stmtIdPool when NewStmt was called
// without an explicit name; nil when the caller supplied a name.
stmtId *namepool.Name
// pkg is the dynamic package carrying the statement ID and text. Its
// Type and Status are set for each operation and reset afterwards.
pkg *tds.DynamicPackage
// paramFmt is the parameter format returned by the server while the
// statement was being prepared; nil if none was returned.
paramFmt *tds.ParamFmtPackage
// rowFmt is the row format returned by the server while the statement
// was being prepared; nil if none was returned.
rowFmt *tds.RowFmtPackage
// cursor is marked closed when the server reports its closure during
// deallocation. NOTE(review): not assigned in this file; may be nil.
cursor *Cursor
}
// Prepare implements the driver.Conn interface. It is equivalent to
// PrepareContext called with a background context.
func (c *Conn) Prepare(query string) (driver.Stmt, error) {
	ctx := context.Background()
	return c.PrepareContext(ctx, query)
}
// PrepareContext implements the driver.ConnPrepareContext interface.
//
// The query is prepared as a dynamic statement with a pool-generated name
// and is always wrapped in a "create proc" statement.
//
// On failure the returned driver.Stmt is a true nil interface, not an
// interface wrapping a nil *Stmt.
func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
	// TODO option for create_proc
	stmt, err := c.NewStmt(ctx, "", query, true)
	if err != nil {
		// Returning the typed nil *Stmt directly would produce a non-nil
		// driver.Stmt interface; return a literal nil instead.
		return nil, err
	}
	return stmt, nil
}
// NewStmt creates a new dynamic statement on the connection and allocates
// it on the server.
//
// If name is empty, a unique name is acquired from stmtIdPool and used as
// the statement ID. If createProc is true, the statement text sent to the
// server is "create proc <name> as <query>"; otherwise query is sent
// verbatim. Allocation errors are wrapped and mention the query.
func (c *Conn) NewStmt(ctx context.Context, name, query string, createProc bool) (*Stmt, error) {
	stmt := &Stmt{conn: c}

	if len(name) == 0 {
		// TODO different pools for procs and prepares
		stmt.stmtId = stmtIdPool.Acquire()
		name = stmt.stmtId.Name()
	}

	text := query
	if createProc {
		text = fmt.Sprintf("create proc %s as %s", name, query)
	}

	stmt.pkg = tds.NewDynamicPackage(true)
	stmt.pkg.ID = name
	stmt.pkg.Stmt = text

	// Reset statement to default before proceeding
	stmt.Reset()

	err := stmt.allocateOnServer(ctx)
	if err != nil {
		return nil, fmt.Errorf("go-ase: error allocating dynamic statement '%s': %w", query, err)
	}
	return stmt, nil
}
// allocateOnServer communicates the allocation of the dynamic statement
// on the server and retrieves the input and output formats.
//
// It sends the package as TDS_DYN_PREPARE and waits for the dynamic
// acknowledgement. It then consumes packages until a DonePackage ends the
// response, storing any ParamFmtPackage in stmt.paramFmt and any
// RowFmtPackage in stmt.rowFmt. Any other package type is an error.
//
// If reading the response fails, the statement is closed again. The
// original error stays wrapped (errors.Is/As keep working). A failure of
// that rollback is appended to the message instead of being dropped.
func (stmt *Stmt) allocateOnServer(ctx context.Context) error {
	stmt.pkg.Type = tds.TDS_DYN_PREPARE
	if err := stmt.conn.Channel.SendPackage(ctx, stmt.pkg); err != nil {
		return fmt.Errorf("error queueing dynamic prepare package: %w", err)
	}
	stmt.Reset()

	if err := stmt.recvDynAck(ctx); err != nil {
		return err
	}

	_, err := stmt.conn.Channel.NextPackageUntil(ctx, true,
		func(pkg tds.Package) (bool, error) {
			switch typed := pkg.(type) {
			case *tds.ParamFmtPackage:
				stmt.paramFmt = typed
				return false, nil
			case *tds.RowFmtPackage:
				stmt.rowFmt = typed
				return false, nil
			case *tds.DonePackage:
				ok, err := handleDonePackage(typed)
				if err != nil {
					return true, err
				}
				return ok, nil
			default:
				return false, fmt.Errorf("unexpected package received: %#v", typed)
			}
		},
	)
	if err != nil && !errors.Is(err, io.EOF) {
		// Roll back the partial allocation; do not silently lose a
		// failure of the rollback itself.
		if closeErr := stmt.close(ctx); closeErr != nil {
			return fmt.Errorf("%w (additionally failed to deallocate statement: %v)", err, closeErr)
		}
		return err
	}
	return nil
}
// Reset resets a statement.
//
// It puts the statement's dynamic package back to type TDS_DYN_INVALID
// with status TDS_DYNAMIC_UNUSED. This clears any type or status flags,
// such as TDS_DYNAMIC_HASARGS, set for the previous operation.
func (stmt *Stmt) Reset() {
	pkg := stmt.pkg
	pkg.Status = tds.TDS_DYNAMIC_UNUSED
	pkg.Type = tds.TDS_DYN_INVALID
}
// Close implements the driver.Stmt interface. It deallocates the
// statement on the server using a background context.
func (stmt *Stmt) Close() error {
	ctx := context.Background()
	return stmt.close(ctx)
}
// close deallocates the dynamic statement on the server.
//
// It sends the package as TDS_DYN_DEALLOC, waits for the dynamic
// acknowledgement and consumes the response until a DonePackage.
//
// A name acquired from stmtIdPool is returned to the pool when close
// returns, even if deallocation fails. stmt.stmtId is cleared first, so
// closing the same statement again cannot release the name twice.
func (stmt *Stmt) close(ctx context.Context) error {
	if id := stmt.stmtId; id != nil {
		// Detach the name before releasing it to prevent a double release
		// on a repeated close.
		stmt.stmtId = nil
		defer stmtIdPool.Release(id)
	}

	// communicate deallocation with server
	// TODO option to not deallocate procs
	stmt.pkg.Type = tds.TDS_DYN_DEALLOC
	if err := stmt.conn.Channel.SendPackage(ctx, stmt.pkg); err != nil {
		return fmt.Errorf("error sending dealloc package: %w", err)
	}
	stmt.Reset()

	if err := stmt.recvDynAck(ctx); err != nil {
		return err
	}

	_, err := stmt.conn.Channel.NextPackageUntil(ctx, true, func(pkg tds.Package) (bool, error) {
		switch typed := pkg.(type) {
		case *tds.CurInfoPackage:
			if typed.Command != tds.TDS_CUR_CMD_INFORM {
				return true, fmt.Errorf("received %T with command %s instead of TDS_CUR_CMD_INFORM",
					typed, typed.Command)
			}
			// A statement need not have a cursor attached; only record the
			// closed state when one exists instead of dereferencing nil.
			if stmt.cursor != nil && typed.Status&tds.TDS_CUR_ISTAT_CLOSED == tds.TDS_CUR_ISTAT_CLOSED {
				stmt.cursor.closed = true
			}
			return false, nil
		case *tds.DonePackage:
			if typed.Status != tds.TDS_DONE_FINAL {
				return false, fmt.Errorf("DonePackage does not have status TDS_DONE_FINAL set: %s", typed)
			}
			return true, io.EOF
		default:
			return true, fmt.Errorf("unhandled package type %T: %s", typed, typed)
		}
	})
	if err != nil && !errors.Is(err, io.EOF) {
		return fmt.Errorf("error handling response to stmt deallocation: %w", err)
	}
	return nil
}
// NumInput implements the driver.Stmt interface.
//
// It reports how many parameter formats the server returned when the
// statement was prepared, or 0 when no parameter format was returned.
func (stmt Stmt) NumInput() int {
	if pf := stmt.paramFmt; pf != nil {
		// len of a nil Fmts is 0, covering the "no formats" case.
		return len(pf.Fmts)
	}
	return 0
}
// Exec implements the driver.Stmt interface. The positional values are
// converted to named values and passed to ExecContext with a background
// context.
func (stmt Stmt) Exec(args []driver.Value) (driver.Result, error) {
	named := dblib.ValuesToNamedValues(args)
	return stmt.ExecContext(context.Background(), named)
}
// ExecContext implements the driver.StmtExecContext interface.
//
// It executes the statement through GenericExec and returns only the
// result. Any rows produced are closed, since Exec callers cannot consume
// them; the Close error is ignored.
func (stmt Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	rows, result, err := stmt.GenericExec(ctx, args)
	if rows == nil {
		return result, err
	}
	rows.Close()
	return result, err
}
// Query implements the driver.Stmt interface. The positional values are
// converted to named values and passed to QueryContext with a background
// context.
func (stmt Stmt) Query(args []driver.Value) (driver.Rows, error) {
	named := dblib.ValuesToNamedValues(args)
	return stmt.QueryContext(context.Background(), named)
}
// QueryContext implements the driver.StmtQueryContext interface.
//
// It executes the statement through GenericExec and returns its rows,
// discarding the driver.Result.
func (stmt Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	rows, _, err := stmt.GenericExec(ctx, args)
	if err != nil {
		return rows, err
	}
	return rows, nil
}
// DirectExec is a wrapper for GenericExec intended for direct use of
// this library, rather than through database/sql.
//
// Its advantage is the variadic args: plain values are accepted and
// converted to driver.NamedValues before being passed to GenericExec.
// Without args, GenericExec receives a nil slice.
func (stmt Stmt) DirectExec(ctx context.Context, args ...interface{}) (driver.Rows, driver.Result, error) {
	if len(args) == 0 {
		return stmt.GenericExec(ctx, nil)
	}

	values := make([]driver.Value, 0, len(args))
	for _, arg := range args {
		values = append(values, arg)
	}
	return stmt.GenericExec(ctx, dblib.ValuesToNamedValues(values))
}
// GenericExec is the central method through which SQL statements are
// sent to ASE.
//
// It queues the dynamic package as TDS_DYN_EXEC. If the server reported
// parameters, it sets TDS_DYNAMIC_HASARGS and also queues the parameter
// format and values. It then sends everything, waits for the dynamic
// acknowledgement and returns the connection's generic results.
//
// len(args) must equal NumInput(), matching the check database/sql performs.
// The check runs before anything is queued, so a rejected call returns
// before the exec package is placed on the channel. It also rejects args
// for a statement without parameters, which were previously dropped.
func (stmt Stmt) GenericExec(ctx context.Context, args []driver.NamedValue) (driver.Rows, driver.Result, error) {
	if want := stmt.NumInput(); len(args) != want {
		return nil, nil, fmt.Errorf("go-ase: statement expects %d arguments, received %d", want, len(args))
	}

	// Prepare and send payload
	stmt.pkg.Type = tds.TDS_DYN_EXEC
	if stmt.paramFmt != nil {
		stmt.pkg.Status |= tds.TDS_DYNAMIC_HASARGS
	}

	if err := stmt.conn.Channel.QueuePackage(ctx, stmt.pkg); err != nil {
		return nil, nil, fmt.Errorf("error queueing dynamic statement exec package: %w", err)
	}
	stmt.Reset()

	if stmt.paramFmt != nil {
		if err := stmt.sendArgs(ctx, args); err != nil {
			return nil, nil, fmt.Errorf("error queueing arguments: %w", err)
		}
	}

	if err := stmt.conn.Channel.SendRemainingPackets(ctx); err != nil {
		return nil, nil, fmt.Errorf("error sending queued packages for dynamic statement execution: %w", err)
	}

	// Receive response
	if err := stmt.recvDynAck(ctx); err != nil {
		return nil, nil, err
	}

	return stmt.conn.genericResults(ctx)
}
// sendArgs converts args into TDS field data using the parameter formats
// the server reported for this statement. It then queues the parameter
// format package followed by a params package with the values.
//
// Arguments are paired with formats by position. Each argument is first
// passed through CheckNamedValue. A nil argument whose format is
// fixed-length switches that format to the corresponding nullable type.
// NOTE(review): if Fmts holds pointer-backed formats, this switch also
// persists in stmt.paramFmt for later executions; confirm this is intended.
func (stmt Stmt) sendArgs(ctx context.Context, args []driver.NamedValue) error {
	fmts := stmt.paramFmt.Fmts

	// Guard the positional indexing below. More args than formats would
	// panic with an index out of range. Fewer would queue a params package
	// that does not match the announced parameter format.
	if len(args) != len(fmts) {
		return fmt.Errorf("go-ase: expected %d arguments, received %d", len(fmts), len(args))
	}

	dataFields := make([]tds.FieldData, 0, len(args))
	for i, arg := range args {
		if err := stmt.CheckNamedValue(&arg); err != nil {
			return fmt.Errorf("error checking argument: %w", err)
		}

		fmtField := fmts[i]

		// If value is nil, we must check if the datatype is nullable
		// and switch to it, if necessary (Nullable datatypes do
		// not have a fixed length).
		if arg.Value == nil && fmtField.IsFixedLength() {
			nullableType, err := fmtField.DataType().NullableType()
			if err != nil {
				return fmt.Errorf("cannot get nullable datatype of %v: %w", fmtField.DataType(), err)
			}
			fmtField.SetDataType(nullableType)
		}

		dataField, err := tds.LookupFieldData(fmtField)
		if err != nil {
			return fmt.Errorf("unable to find FieldData for datatype %s: %w", fmtField.DataType(), err)
		}

		dataField.SetValue(arg.Value)
		dataFields = append(dataFields, dataField)
	}

	if err := stmt.conn.Channel.QueuePackage(ctx, stmt.paramFmt); err != nil {
		return fmt.Errorf("error queueing dynamic statement parameter format: %w", err)
	}

	if err := stmt.conn.Channel.QueuePackage(ctx, tds.NewParamsPackage(dataFields...)); err != nil {
		return fmt.Errorf("error queueing dynamic statement parameters: %w", err)
	}

	return nil
}
// CheckNamedValue implements the driver.NamedValueChecker interface.
//
// It fails if the statement has no reported parameters, or if nv's
// ordinal is beyond the number of reported parameters. Otherwise it
// replaces nv.Value with the result of
// asetypes.DefaultValueConverter.ConvertValue and returns any conversion
// error unchanged.
func (stmt Stmt) CheckNamedValue(nv *driver.NamedValue) error {
	var numFmts int
	if stmt.paramFmt != nil {
		numFmts = len(stmt.paramFmt.Fmts)
	}
	if numFmts == 0 {
		return errors.New("go-ase: statement has no reported arguments")
	}

	// Ordinals are 1-based; compare the derived 0-based index.
	if nv.Ordinal-1 >= numFmts {
		return fmt.Errorf("go-ase: ordinal %d (index %d) is larger than the number of expected arguments %d",
			nv.Ordinal, nv.Ordinal-1, numFmts)
	}

	converted, err := asetypes.DefaultValueConverter.ConvertValue(nv.Value)
	if err != nil {
		return err
	}
	nv.Value = converted
	return nil
}