Skip to content

Commit

Permalink
Merge pull request #13 from shrek/master
Browse files Browse the repository at this point in the history
return ErrBadConn upon connection errors
  • Loading branch information
asifjalil authored Aug 2, 2019
2 parents de4f9ef + 4c0e4e7 commit b3a0e23
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 3 deletions.
2 changes: 1 addition & 1 deletion column.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ loop:
break loop
case C.SQL_SUCCESS_WITH_INFO:
err := formatError(C.SQL_HANDLE_STMT, C.SQLHANDLE(c.h))
if err.SQLState() != "01004" {
if cliErr, ok := err.(*cliError); !ok || cliErr.SQLState() != "01004" {
return nil, err
}
// buf is not big enough; data has been truncated
Expand Down
51 changes: 51 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ func newTestDB() (*testDB, error) {
connStr := fmt.Sprintf("DATABASE = %s; UID = %s; PWD = %s;",
config.database, config.uid, config.pwd)

if os.Getenv("DATABASE_DSN") != "" {
connStr = os.Getenv("DATABASE_DSN")
}

db, err := sql.Open("cli", connStr)
if err != nil {
return nil, err
Expand Down Expand Up @@ -1351,6 +1355,53 @@ func TestOverflow(t *testing.T) {
info(t, "values: %+v err: %+v", values, rows.Err())
}

func TestConnError(t *testing.T) {
db, err := newTestDB()
if err != nil {
die(t, "Failed to connect to db: %v", err)
}
defer db.Close()

fmt.Println("connected to db")
query := "select current timestamp from sysibm.sysdummy1"
now := time.Time{}

// verify
err = db.QueryRow(query).Scan(&now)
if err != nil {
die(t, "%+v", err)
}
fmt.Printf("%v\n", now)
fmt.Printf("%+v\n", db.Stats())

fmt.Printf("Testing failure, stop the database\n")
for err == nil {
err = db.QueryRow(query).Scan(&now)
if err != nil {
fmt.Printf("found error: %s\n", err.Error())
break
} else {
fmt.Printf("%v\n", now)
fmt.Printf("%+v\n", db.Stats())
}
time.Sleep(2 * time.Second)
}

fmt.Printf("Testing recovery, restart the database\n")
for err != nil {
err = db.QueryRow(query).Scan(&now)
if err == nil {
fmt.Printf("%v\n", now)
fmt.Printf("%+v\n", db.Stats())
break
}
fmt.Printf("found error: %s\n", err.Error())
fmt.Printf("%+v\n", db.Stats())
time.Sleep(2 * time.Second)
}

}

func logf(t *testing.T, format string, a ...interface{}) {
t.Logf(format, a...)
}
Expand Down
14 changes: 12 additions & 2 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package cli
*/
import "C"
import (
"database/sql/driver"
"fmt"
"strings"
"unsafe"
Expand Down Expand Up @@ -42,12 +43,12 @@ func success(ret C.SQLRETURN) bool {
return int(ret) == C.SQL_SUCCESS || int(ret) == C.SQL_SUCCESS_WITH_INFO
}

func formatError(ht C.SQLSMALLINT, h C.SQLHANDLE) (err *cliError) {
func formatError(ht C.SQLSMALLINT, h C.SQLHANDLE) error {
sqlState := make([]uint16, 6)
var sqlCode C.SQLINTEGER
messageText := make([]uint16, C.SQL_MAX_MESSAGE_LENGTH)
var textLength C.SQLSMALLINT
err = &cliError{}
err := &cliError{}
for i := 1; ; i++ {
ret := C.SQLGetDiagRecW(C.SQLSMALLINT(ht),
h,
Expand All @@ -71,6 +72,15 @@ func formatError(ht C.SQLSMALLINT, h C.SQLHANDLE) (err *cliError) {
err.message = strings.TrimSpace(err.message)
}

// https://www.ibm.com/support/knowledgecenter/en/SSEPGG_11.1.0/com.ibm.db2.luw.messages.cli.doc/com.ibm.db2.luw.messages.cli.doc-gentopic1.html
if strings.Contains(err.message, "CLI0106E") ||
strings.Contains(err.message, "CLI0107E") ||
strings.Contains(err.message, "CLI0108E") ||
// http://www-01.ibm.com/support/docview.wss?uid=swg21164785
strings.Contains(err.message, "SQL30081N") {
return driver.ErrBadConn
}

return err
}

Expand Down

0 comments on commit b3a0e23

Please sign in to comment.