diff --git a/infile.go b/infile.go index 952bd3d6d..121a04c71 100644 --- a/infile.go +++ b/infile.go @@ -9,7 +9,6 @@ package mysql import ( - "database/sql/driver" "fmt" "io" "os" @@ -21,11 +20,6 @@ var ( readerRegister map[string]func() io.Reader ) -func init() { - fileRegister = make(map[string]bool) - readerRegister = make(map[string]func() io.Reader) -} - // RegisterLocalFile adds the given file to the file whitelist, // so that it can be used by "LOAD DATA LOCAL INFILE ". // Alternatively you can allow the use of all local files with @@ -38,6 +32,11 @@ func init() { // ... // func RegisterLocalFile(filePath string) { + // lazy map init + if fileRegister == nil { + fileRegister = make(map[string]bool) + } + fileRegister[strings.Trim(filePath, `"`)] = true } @@ -62,6 +61,11 @@ func DeregisterLocalFile(filePath string) { // ... // func RegisterReaderHandler(name string, handler func() io.Reader) { + // lazy map init + if readerRegister == nil { + readerRegister = make(map[string]func() io.Reader) + } + readerRegister[name] = handler } @@ -71,71 +75,81 @@ func DeregisterReaderHandler(name string) { delete(readerRegister, name) } +func deferredClose(err *error, closer io.Closer) { + closeErr := closer.Close() + if *err == nil { + *err = closeErr + } +} + func (mc *mysqlConn) handleInFileRequest(name string) (err error) { var rdr io.Reader - data := make([]byte, 4+mc.maxWriteSize) + var data []byte if strings.HasPrefix(name, "Reader::") { // io.Reader name = name[8:] - handler, inMap := readerRegister[name] - if handler != nil { + if handler, inMap := readerRegister[name]; inMap { rdr = handler() - } - if rdr == nil { - if !inMap { - err = fmt.Errorf("Reader '%s' is not registered", name) + if rdr != nil { + data = make([]byte, 4+mc.maxWriteSize) + + if cl, ok := rdr.(io.Closer); ok { + defer deferredClose(&err, cl) + } } else { err = fmt.Errorf("Reader '%s' is ", name) } + } else { + err = fmt.Errorf("Reader '%s' is not registered", name) } } else { // File name = strings.Trim(name, `"`) if mc.cfg.allowAllFiles || fileRegister[name] { - rdr, err = os.Open(name) + var file *os.File + var fi os.FileInfo + + if file, err = os.Open(name); err == nil { + defer deferredClose(&err, file) + + // get file size + if fi, err = file.Stat(); err == nil { + rdr = file + if fileSize := int(fi.Size()); fileSize <= mc.maxWriteSize { + data = make([]byte, 4+fileSize) + } else if fileSize <= mc.maxPacketAllowed { + data = make([]byte, 4+mc.maxWriteSize) + } else { + err = fmt.Errorf("Local File '%s' too large: Size: %d, Max: %d", name, fileSize, mc.maxPacketAllowed) + } + } + } } else { err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name) } } - if rdc, ok := rdr.(io.ReadCloser); ok { - defer func() { - if err == nil { - err = rdc.Close() - } else { - rdc.Close() - } - }() - } - // send content packets - var ioErr error if err == nil { var n int - for err == nil && ioErr == nil { + for err == nil { n, err = rdr.Read(data[4:]) if n > 0 { - ioErr = mc.writePacket(data[:4+n]) + if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { + return ioErr + } } } if err == io.EOF { err = nil } - if ioErr != nil { - errLog.Print(ioErr.Error()) - return driver.ErrBadConn - } } // send empty packet (termination) - ioErr = mc.writePacket([]byte{ - 0x00, - 0x00, - 0x00, - mc.sequence, - }) - if ioErr != nil { - errLog.Print(ioErr.Error()) - return driver.ErrBadConn + if data == nil { + data = make([]byte, 4) + } + if ioErr := mc.writePacket(data[:4]); ioErr != nil { + return ioErr } // read OK packet