Skip to content

Distinguish non-error output from error output #122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions pkg/sqlcmd/sqlcmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ type Sqlcmd struct {
Query string
Cmd Commands
// PrintError allows the host to redirect errors away from the default output. Returns false if the error is not redirected by the host.
PrintError func(msg string, severity uint8) bool
UnicodeOutputFile bool
PrintError func(msg string, severity uint8) bool
UnicodeOutputFile bool
IsInteractiveSession bool
}

// New creates a new Sqlcmd instance
Expand All @@ -86,10 +87,13 @@ func New(l Console, workingDirectory string, vars *Variables) *Sqlcmd {
s.PrintError = func(msg string, severity uint8) bool {
return false
}
s.SetOutput(os.Stdout)
s.SetError(os.Stderr)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s.SetError(os.Stderr)

this is a change in behavior for any other places that call GetError(), now it's going to return stderr by default instead of stdout. Please verify that's the right behavior in those other places.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This functionality change is leading to undesired side effects such as redirecting server errors to stderr as well. (I will close this PR and create new one)
I have reworked the changes to keep existing functionality as is in PR #143 and redirecting only the sqlcmd errors which is what the ODBC is also doing.

return s
}

func (s *Sqlcmd) scanNext() (string, error) {
s.IsInteractiveSession = true
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a comment on why this is being set

return s.lineIo.Readline()
}

Expand Down Expand Up @@ -134,7 +138,11 @@ func (s *Sqlcmd) Run(once bool, processAll bool) error {
args = make([]string, 0)
once = true
} else {
_, _ = s.GetOutput().Write([]byte(err.Error() + SqlcmdEol))
if iactive && s.IsInteractiveSession {
_, _ = s.GetOutput().Write([]byte(err.Error() + SqlcmdEol))
} else {
_, _ = s.GetError().Write([]byte(err.Error() + SqlcmdEol))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't you need to distinguish the case of iactive being true because of the user name parameter being passed in, which is part of your other PR?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe Sqlcmd should have a property like "WriteErrorsToErrorStream bool" that would be set by main before it calls Run.

There are some other settings that control command output destination that this could be like.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added flag in sqlcmd struct to indicate an interactive session.
The interactive session will be determined when the first input is scanned.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it make sense to fix s.GetError to do the right thing based on context and just change this code to use GetError() only?
I'm not 100% certain that every kind of error in sqlcmd goes to the same stream, though.

}
}
}
if cmd != nil {
Expand All @@ -144,7 +152,11 @@ func (s *Sqlcmd) Run(once bool, processAll bool) error {
break
}
if err != nil {
_, _ = s.GetOutput().Write([]byte(err.Error() + SqlcmdEol))
if iactive && s.IsInteractiveSession {
_, _ = s.GetOutput().Write([]byte(err.Error() + SqlcmdEol))
} else {
_, _ = s.GetError().Write([]byte(err.Error() + SqlcmdEol))
}
lastError = err
}
}
Expand Down
62 changes: 62 additions & 0 deletions pkg/sqlcmd/sqlcmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,44 @@ func TestSqlCmdQueryAndExit(t *testing.T) {
}
}

func TestSqlCmdOutputAndError(t *testing.T) {
s, outfile, errfile := setupSqlcmdWithFileErrorOutput(t)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't you need to test the s.IncludeFile path too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added test for s.IncludeFile

defer os.Remove(outfile.Name())
defer os.Remove(errfile.Name())
s.Query = "select $(X"
err := s.Run(true, false)
if assert.NoError(t, err, "s.Run(once = true)") {
bytes, err := os.ReadFile(errfile.Name())
if assert.NoError(t, err, "os.ReadFile") {
assert.Equal(t, "Sqlcmd: Error: Syntax error at line 1."+SqlcmdEol, string(bytes), "Expected syntax error not received for query execution")
}
}
s.Query = "select '1'"
err = s.Run(true, false)
if assert.NoError(t, err, "s.Run(once = true)") {
bytes, err := os.ReadFile(outfile.Name())
if assert.NoError(t, err, "os.ReadFile") {
assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+"(1 row affected)"+SqlcmdEol, string(bytes), "Unexpected output for query execution")
}
}

s, outfile, errfile = setupSqlcmdWithFileErrorOutput(t)
defer os.Remove(outfile.Name())
defer os.Remove(errfile.Name())
dataPath := "testdata" + string(os.PathSeparator)
err = s.IncludeFile(dataPath+"teststdouterr.sql", false)
if assert.NoError(t, err, "IncludeFile teststdouterr.sql false") {
bytes, err := os.ReadFile(outfile.Name())
if assert.NoError(t, err, "os.ReadFile outfile") {
assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+"(1 row affected)"+SqlcmdEol, string(bytes), "Unexpected output for sql file execution in outfile")
}
bytes, err = os.ReadFile(errfile.Name())
if assert.NoError(t, err, "os.ReadFile errfile") {
assert.Equal(t, "Sqlcmd: Error: Syntax error at line 1."+SqlcmdEol, string(bytes), "Expected syntax error not found in errfile")
}
}
}

// Simulate :r command
func TestIncludeFileNoExecutions(t *testing.T) {
s, file := setupSqlcmdWithFileOutput(t)
Expand Down Expand Up @@ -476,6 +514,7 @@ func setupSqlCmdWithMemoryOutput(t testing.TB) (*Sqlcmd, *memoryBuffer) {
s.Format = NewSQLCmdDefaultFormatter(true)
buf := &memoryBuffer{buf: new(bytes.Buffer)}
s.SetOutput(buf)
s.SetError(buf)
err := s.ConnectDb(nil, true)
assert.NoError(t, err, "s.ConnectDB")
return s, buf
Expand All @@ -491,6 +530,7 @@ func setupSqlcmdWithFileOutput(t testing.TB) (*Sqlcmd, *os.File) {
file, err := os.CreateTemp("", "sqlcmdout")
assert.NoError(t, err, "os.CreateTemp")
s.SetOutput(file)
s.SetError(file)
err = s.ConnectDb(nil, true)
if err != nil {
os.Remove(file.Name())
Expand All @@ -499,6 +539,28 @@ func setupSqlcmdWithFileOutput(t testing.TB) (*Sqlcmd, *os.File) {
return s, file
}

func setupSqlcmdWithFileErrorOutput(t testing.TB) (*Sqlcmd, *os.File, *os.File) {
t.Helper()
v := InitializeVariables(true)
v.Set(SQLCMDMAXVARTYPEWIDTH, "0")
s := New(nil, "", v)
s.Connect = newConnect(t)
s.Format = NewSQLCmdDefaultFormatter(true)
outfile, err := os.CreateTemp("", "sqlcmdout")
assert.NoError(t, err, "os.CreateTemp")
errfile, err := os.CreateTemp("", "sqlcmderr")
assert.NoError(t, err, "os.CreateTemp")
s.SetOutput(outfile)
s.SetError(errfile)
err = s.ConnectDb(nil, true)
if err != nil {
os.Remove(outfile.Name())
os.Remove(errfile.Name())
}
assert.NoError(t, err, "s.ConnectDB")
return s, outfile, errfile
}

// Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set
func canTestAzureAuth() bool {
server := os.Getenv(SQLCMDSERVER)
Expand Down
4 changes: 4 additions & 0 deletions pkg/sqlcmd/testdata/teststdouterr.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
select $(X
go
select '1'
go