Skip to content

Commit c87f84b

Browse files
committed
Merge pull request go-sql-driver#134 from go-sql-driver/write-buffer
Use the Connection Buffer for Writing
2 parents e29272a + 228ba34 commit c87f84b

File tree

7 files changed

+448
-337
lines changed

7 files changed

+448
-337
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ Changes:
77
- Made closing rows and connections error tolerant. This allows for example deferring rows.Close() without checking for errors
88
- New Logo
99
- Changed the copyright header to include all contributors
10-
- Optimized the read buffer
10+
- Optimized the buffer for reading
11+
- Use the buffer also for writing. This results in zero allocations (by the driver) for most queries
1112
- Improved the LOAD INFILE documentation
1213
- The driver struct is now exported to make the driver directly accessible
1314
- Refactored the driver tests

benchmark_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,23 +68,26 @@ func BenchmarkQuery(b *testing.B) {
6868

6969
stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?"))
7070
defer stmt.Close()
71-
b.StartTimer()
7271

7372
remain := int64(b.N)
7473
var wg sync.WaitGroup
7574
wg.Add(concurrencyLevel)
7675
defer wg.Wait()
76+
b.StartTimer()
77+
7778
for i := 0; i < concurrencyLevel; i++ {
7879
go func() {
79-
defer wg.Done()
8080
for {
8181
if atomic.AddInt64(&remain, -1) < 0 {
82+
wg.Done()
8283
return
8384
}
85+
8486
var got string
8587
tb.check(stmt.QueryRow(1).Scan(&got))
8688
if got != "one" {
8789
b.Errorf("query = %q; want one", got)
90+
wg.Done()
8891
return
8992
}
9093
}

buffer.go

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ import "io"
1212

1313
const defaultBufSize = 4096
1414

15-
// A read buffer similar to bufio.Reader but zero-copy-ish
15+
// A buffer which is used for both reading and writing.
16+
// This is possible since communication on each connection is synchronous.
17+
// In other words, we can't write and read simultaneously on the same connection.
18+
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
1619
// Also highly optimized for this particular use case.
1720
type buffer struct {
1821
buf []byte
@@ -37,8 +40,11 @@ func (b *buffer) fill(need int) (err error) {
3740
}
3841

3942
// grow buffer if necessary
43+
// TODO: let the buffer shrink again at some point
44+
// Maybe keep the org buf slice and swap back?
4045
if need > len(b.buf) {
41-
newBuf := make([]byte, need)
46+
// Round up to the next multiple of the default size
47+
newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
4248
copy(newBuf, b.buf)
4349
b.buf = newBuf
4450
}
@@ -74,3 +80,44 @@ func (b *buffer) readNext(need int) (p []byte, err error) {
7480
b.length -= need
7581
return
7682
}
83+
84+
// returns a buffer with the requested size.
85+
// If possible, a slice from the existing buffer is returned.
86+
// Otherwise a bigger buffer is made.
87+
// Only one buffer (total) can be used at a time.
88+
func (b *buffer) takeBuffer(length int) []byte {
89+
if b.length > 0 {
90+
return nil
91+
}
92+
93+
// test (cheap) general case first
94+
if length <= defaultBufSize || length <= cap(b.buf) {
95+
return b.buf[:length]
96+
}
97+
98+
if length < maxPacketSize {
99+
b.buf = make([]byte, length)
100+
return b.buf
101+
}
102+
return make([]byte, length)
103+
}
104+
105+
// shortcut which can be used if the requested buffer is guaranteed to be
106+
// smaller than defaultBufSize
107+
// Only one buffer (total) can be used at a time.
108+
func (b *buffer) takeSmallBuffer(length int) []byte {
109+
if b.length == 0 {
110+
return b.buf[:length]
111+
}
112+
return nil
113+
}
114+
115+
// takeCompleteBuffer returns the complete existing buffer.
116+
// This can be used if the necessary buffer size is unknown.
117+
// Only one buffer (total) can be used at a time.
118+
func (b *buffer) takeCompleteBuffer() []byte {
119+
if b.length == 0 {
120+
return b.buf
121+
}
122+
return nil
123+
}

connection.go

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -136,14 +136,14 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
136136
columnCount, err := stmt.readPrepareResultPacket()
137137
if err == nil {
138138
if stmt.paramCount > 0 {
139-
stmt.params, err = stmt.mc.readColumns(stmt.paramCount)
139+
stmt.params, err = mc.readColumns(stmt.paramCount)
140140
if err != nil {
141141
return nil, err
142142
}
143143
}
144144

145145
if columnCount > 0 {
146-
err = stmt.mc.readUntilEOF()
146+
err = mc.readUntilEOF()
147147
}
148148
}
149149

@@ -171,26 +171,24 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
171171
}
172172

173173
// Internal function to execute commands
174-
func (mc *mysqlConn) exec(query string) (err error) {
174+
func (mc *mysqlConn) exec(query string) error {
175175
// Send command
176-
err = mc.writeCommandPacketStr(comQuery, query)
176+
err := mc.writeCommandPacketStr(comQuery, query)
177177
if err != nil {
178-
return
178+
return err
179179
}
180180

181181
// Read Result
182-
var resLen int
183-
resLen, err = mc.readResultSetHeaderPacket()
182+
resLen, err := mc.readResultSetHeaderPacket()
184183
if err == nil && resLen > 0 {
185-
err = mc.readUntilEOF()
186-
if err != nil {
187-
return
184+
if err = mc.readUntilEOF(); err != nil {
185+
return err
188186
}
189187

190188
err = mc.readUntilEOF()
191189
}
192190

193-
return
191+
return err
194192
}
195193

196194
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
@@ -211,7 +209,6 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
211209
return rows, err
212210
}
213211
}
214-
215212
return nil, err
216213
}
217214

@@ -221,29 +218,29 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
221218

222219
// Gets the value of the given MySQL System Variable
223220
// The returned byte slice is only valid until the next read
224-
func (mc *mysqlConn) getSystemVar(name string) (val []byte, err error) {
221+
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
225222
// Send command
226-
err = mc.writeCommandPacketStr(comQuery, "SELECT @@"+name)
223+
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
224+
return nil, err
225+
}
226+
227+
// Read Result
228+
resLen, err := mc.readResultSetHeaderPacket()
227229
if err == nil {
228-
// Read Result
229-
var resLen int
230-
resLen, err = mc.readResultSetHeaderPacket()
231-
if err == nil {
232-
rows := &mysqlRows{mc, false, nil, false}
230+
rows := &mysqlRows{mc, false, nil, false}
233231

234-
if resLen > 0 {
235-
// Columns
236-
rows.columns, err = mc.readColumns(resLen)
232+
if resLen > 0 {
233+
// Columns
234+
rows.columns, err = mc.readColumns(resLen)
235+
if err != nil {
236+
return nil, err
237237
}
238+
}
238239

239-
dest := make([]driver.Value, resLen)
240-
err = rows.readRow(dest)
241-
if err == nil {
242-
val = dest[0].([]byte)
243-
err = mc.readUntilEOF()
244-
}
240+
dest := make([]driver.Value, resLen)
241+
if err = rows.readRow(dest); err == nil {
242+
return dest[0].([]byte), mc.readUntilEOF()
245243
}
246244
}
247-
248-
return
245+
return nil, err
249246
}

0 commit comments

Comments
 (0)