From 944568a2c0715eaa25cf0a4eb6d1d1493d27a31f Mon Sep 17 00:00:00 2001 From: Apoorv Deshmukh Date: Fri, 19 Aug 2022 19:31:51 +0530 Subject: [PATCH 1/3] Distinguish non-error output from error output In certain cases such as non-interactive mode it was observed that error messages meant for error stream were captured under output stream. This is not desirable if the user only wants to retrieve the error messages. This commit aligns the go-sqlcmd behaviour with ODBC sqlcmd where error messages are captured in the appropriate error stream. --- pkg/sqlcmd/sqlcmd.go | 14 ++++++++++-- pkg/sqlcmd/sqlcmd_test.go | 46 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go index a3d34898..fb41addd 100644 --- a/pkg/sqlcmd/sqlcmd.go +++ b/pkg/sqlcmd/sqlcmd.go @@ -85,6 +85,8 @@ func New(l Console, workingDirectory string, vars *Variables) *Sqlcmd { s.PrintError = func(msg string, severity uint8) bool { return false } + s.SetOutput(os.Stdout) + s.SetError(os.Stderr) return s } @@ -133,7 +135,11 @@ func (s *Sqlcmd) Run(once bool, processAll bool) error { args = make([]string, 0) once = true } else { - _, _ = s.GetOutput().Write([]byte(err.Error() + SqlcmdEol)) + if iactive { + _, _ = s.GetOutput().Write([]byte(err.Error() + SqlcmdEol)) + } else { + _, _ = s.GetError().Write([]byte(err.Error() + SqlcmdEol)) + } } } if cmd != nil { @@ -143,7 +149,11 @@ func (s *Sqlcmd) Run(once bool, processAll bool) error { break } if err != nil { - _, _ = s.GetOutput().Write([]byte(err.Error() + SqlcmdEol)) + if iactive { + _, _ = s.GetOutput().Write([]byte(err.Error() + SqlcmdEol)) + } else { + _, _ = s.GetError().Write([]byte(err.Error() + SqlcmdEol)) + } lastError = err } } diff --git a/pkg/sqlcmd/sqlcmd_test.go b/pkg/sqlcmd/sqlcmd_test.go index 60f78263..0f405eda 100644 --- a/pkg/sqlcmd/sqlcmd_test.go +++ b/pkg/sqlcmd/sqlcmd_test.go @@ -109,6 +109,28 @@ func TestSqlCmdQueryAndExit(t *testing.T) { } } +func TestSqlCmdOutputAndError(t *testing.T) { + s, outfile, errfile := setupSqlcmdWithFileErrorOutput(t) + defer os.Remove(outfile.Name()) + defer os.Remove(errfile.Name()) + s.Query = "select $(X" + err := s.Run(true, false) + if assert.NoError(t, err, "s.Run(once = true)") { + bytes, err := os.ReadFile(errfile.Name()) + if assert.NoError(t, err, "os.ReadFile") { + assert.Equal(t, "Sqlcmd: Error: Syntax error at line 1."+SqlcmdEol, string(bytes), "Incorrect output from Run") + } + } + s.Query = "select '1'" + err = s.Run(true, false) + if assert.NoError(t, err, "s.Run(once = true)") { + bytes, err := os.ReadFile(outfile.Name()) + if assert.NoError(t, err, "os.ReadFile") { + assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+"(1 row affected)"+SqlcmdEol, string(bytes), "Incorrect output from Run") + } + } +} + // Simulate :r command func TestIncludeFileNoExecutions(t *testing.T) { s, file := setupSqlcmdWithFileOutput(t) @@ -475,6 +497,7 @@ func setupSqlCmdWithMemoryOutput(t testing.TB) (*Sqlcmd, *memoryBuffer) { s.Format = NewSQLCmdDefaultFormatter(true) buf := &memoryBuffer{buf: new(bytes.Buffer)} s.SetOutput(buf) + s.SetError(buf) err := s.ConnectDb(nil, true) assert.NoError(t, err, "s.ConnectDB") return s, buf @@ -490,6 +513,7 @@ func setupSqlcmdWithFileOutput(t testing.TB) (*Sqlcmd, *os.File) { file, err := os.CreateTemp("", "sqlcmdout") assert.NoError(t, err, "os.CreateTemp") s.SetOutput(file) + s.SetError(file) err = s.ConnectDb(nil, true) if err != nil { os.Remove(file.Name()) @@ -498,6 +522,28 @@ func setupSqlcmdWithFileOutput(t testing.TB) (*Sqlcmd, *os.File) { return s, file } +func setupSqlcmdWithFileErrorOutput(t testing.TB) (*Sqlcmd, *os.File, *os.File) { + t.Helper() + v := InitializeVariables(true) + v.Set(SQLCMDMAXVARTYPEWIDTH, "0") + s := New(nil, "", v) + s.Connect = newConnect(t) + s.Format = NewSQLCmdDefaultFormatter(true) + outfile, err := os.CreateTemp("", "sqlcmdout") + assert.NoError(t, err, "os.CreateTemp") + errfile, err := os.CreateTemp("", "sqlcmderr") + assert.NoError(t, err, "os.CreateTemp") + s.SetOutput(outfile) + s.SetError(errfile) + err = s.ConnectDb(nil, true) + if err != nil { + os.Remove(outfile.Name()) + os.Remove(errfile.Name()) + } + assert.NoError(t, err, "s.ConnectDB") + return s, outfile, errfile +} + // Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set func canTestAzureAuth() bool { server := os.Getenv(SQLCMDSERVER) From a008562a53ec002f1b2e323aef4f9dfe43a235cd Mon Sep 17 00:00:00 2001 From: Apoorv Deshmukh Date: Tue, 6 Sep 2022 16:39:48 +0530 Subject: [PATCH 2/3] Add flag to indicate interactive session --- pkg/sqlcmd/sqlcmd.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go index fb41addd..5e7b09a9 100644 --- a/pkg/sqlcmd/sqlcmd.go +++ b/pkg/sqlcmd/sqlcmd.go @@ -68,8 +68,9 @@ type Sqlcmd struct { Query string Cmd Commands // PrintError allows the host to redirect errors away from the default output. Returns false if the error is not redirected by the host. - PrintError func(msg string, severity uint8) bool - UnicodeOutputFile bool + PrintError func(msg string, severity uint8) bool + UnicodeOutputFile bool + IsInteractiveSession bool } // New creates a new Sqlcmd instance @@ -91,6 +92,7 @@ func New(l Console, workingDirectory string, vars *Variables) *Sqlcmd { } func (s *Sqlcmd) scanNext() (string, error) { + s.IsInteractiveSession = true return s.lineIo.Readline() } @@ -135,7 +137,7 @@ func (s *Sqlcmd) Run(once bool, processAll bool) error { args = make([]string, 0) once = true } else { - if iactive { + if iactive && s.IsInteractiveSession { _, _ = s.GetOutput().Write([]byte(err.Error() + SqlcmdEol)) } else { _, _ = s.GetError().Write([]byte(err.Error() + SqlcmdEol)) @@ -149,7 +151,7 @@ func (s *Sqlcmd) Run(once bool, processAll bool) error { break } if err != nil { - if iactive { + if iactive && s.IsInteractiveSession { _, _ = s.GetOutput().Write([]byte(err.Error() + SqlcmdEol)) } else { _, _ = s.GetError().Write([]byte(err.Error() + SqlcmdEol)) From e04603020e728b3439d1c11fe4a1aeeec141a282 Mon Sep 17 00:00:00 2001 From: Apoorv Deshmukh Date: Wed, 7 Sep 2022 21:46:54 +0530 Subject: [PATCH 3/3] Add test scenario for includeFile --- pkg/sqlcmd/sqlcmd_test.go | 20 ++++++++++++++++++-- pkg/sqlcmd/testdata/teststdouterr.sql | 4 ++++ 2 files changed, 22 insertions(+), 2 deletions(-) create mode 100644 pkg/sqlcmd/testdata/teststdouterr.sql diff --git a/pkg/sqlcmd/sqlcmd_test.go b/pkg/sqlcmd/sqlcmd_test.go index 25987d37..accaa965 100644 --- a/pkg/sqlcmd/sqlcmd_test.go +++ b/pkg/sqlcmd/sqlcmd_test.go @@ -118,7 +118,7 @@ func TestSqlCmdOutputAndError(t *testing.T) { if assert.NoError(t, err, "s.Run(once = true)") { bytes, err := os.ReadFile(errfile.Name()) if assert.NoError(t, err, "os.ReadFile") { - assert.Equal(t, "Sqlcmd: Error: Syntax error at line 1."+SqlcmdEol, string(bytes), "Incorrect output from Run") + assert.Equal(t, "Sqlcmd: Error: Syntax error at line 1."+SqlcmdEol, string(bytes), "Expected syntax error not received for query execution") } } s.Query = "select '1'" @@ -126,7 +126,23 @@ func TestSqlCmdOutputAndError(t *testing.T) { if assert.NoError(t, err, "s.Run(once = true)") { bytes, err := os.ReadFile(outfile.Name()) if assert.NoError(t, err, "os.ReadFile") { - assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+"(1 row affected)"+SqlcmdEol, string(bytes), "Incorrect output from Run") + assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+"(1 row affected)"+SqlcmdEol, string(bytes), "Unexpected output for query execution") + } + } + + s, outfile, errfile = setupSqlcmdWithFileErrorOutput(t) + defer os.Remove(outfile.Name()) + defer os.Remove(errfile.Name()) + dataPath := "testdata" + string(os.PathSeparator) + err = s.IncludeFile(dataPath+"teststdouterr.sql", false) + if assert.NoError(t, err, "IncludeFile teststdouterr.sql false") { + bytes, err := os.ReadFile(outfile.Name()) + if assert.NoError(t, err, "os.ReadFile outfile") { + assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+"(1 row affected)"+SqlcmdEol, string(bytes), "Unexpected output for sql file execution in outfile") + } + bytes, err = os.ReadFile(errfile.Name()) + if assert.NoError(t, err, "os.ReadFile errfile") { + assert.Equal(t, "Sqlcmd: Error: Syntax error at line 1."+SqlcmdEol, string(bytes), "Expected syntax error not found in errfile") } } } diff --git a/pkg/sqlcmd/testdata/teststdouterr.sql b/pkg/sqlcmd/testdata/teststdouterr.sql new file mode 100644 index 00000000..d09b4113 --- /dev/null +++ b/pkg/sqlcmd/testdata/teststdouterr.sql @@ -0,0 +1,4 @@ +select $(X +go +select '1' +go \ No newline at end of file