diff --git a/pkg/sqlcmd/format_test.go b/pkg/sqlcmd/format_test.go index 0ba31670..39224500 100644 --- a/pkg/sqlcmd/format_test.go +++ b/pkg/sqlcmd/format_test.go @@ -4,6 +4,7 @@ package sqlcmd import ( + "context" "strings" "testing" @@ -62,7 +63,7 @@ func TestCalcColumnDetails(t *testing.T) { if assert.NoError(t, err, "ConnectDB failed") { defer db.Close() for x, test := range tests { - rows, err := db.Query(test.query) + rows, err := db.QueryContext(context.Background(), test.query) if assert.NoError(t, err, "Query failed: %s", test.query) { defer rows.Close() cols, err := rows.ColumnTypes() diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go index ccbdb499..9f42540d 100644 --- a/pkg/sqlcmd/sqlcmd.go +++ b/pkg/sqlcmd/sqlcmd.go @@ -50,7 +50,7 @@ type Console interface { type Sqlcmd struct { lineIo Console workingDirectory string - db *sql.DB + db *sql.Conn out io.WriteCloser err io.WriteCloser batch *Batch @@ -232,8 +232,7 @@ func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error { if err != nil { return err } - db := sql.OpenDB(connector) - err = db.Ping() + db, err := sql.OpenDB(connector).Conn(context.Background()) if err != nil { fmt.Fprintln(s.GetOutput(), err) return err diff --git a/pkg/sqlcmd/sqlcmd_test.go b/pkg/sqlcmd/sqlcmd_test.go index f2666fa7..7014f336 100644 --- a/pkg/sqlcmd/sqlcmd_test.go +++ b/pkg/sqlcmd/sqlcmd_test.go @@ -87,7 +87,7 @@ func TestSqlCmdConnectDb(t *testing.T) { } } -func ConnectDb(t testing.TB) (*sql.DB, error) { +func ConnectDb(t testing.TB) (*sql.Conn, error) { v := InitializeVariables(true) s := &Sqlcmd{vars: v} s.Connect = newConnect(t) @@ -408,6 +408,15 @@ func TestSqlCmdDefersToPrintError(t *testing.T) { } } +func TestSqlCmdMaintainsConnectionBetweenBatches(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + err := runSqlCmd(t, s, []string{"CREATE TABLE #tmp1 (col1 int)", "insert into #tmp1 values (1)", "GO", "select * from #tmp1", "drop table #tmp1", "GO"}) + if assert.NoError(t, err, "runSqlCmd failed") { + assert.Equal(t, oneRowAffected+SqlcmdEol+"1"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, buf.buf.String(), "Sqlcmd uses the same connection for all queries") + } +} + // runSqlCmd uses lines as input for sqlcmd instead of relying on file or console input func runSqlCmd(t testing.TB, s *Sqlcmd, lines []string) error { t.Helper()