Skip to content

Initialize console if -U is specified #121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions cmd/sqlcmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,22 +201,28 @@ func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sq
connect.ErrorSeverityLevel = args.ErrorSeverityLevel
}

func isConsoleInitializationRequired(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments) bool {
iactive := args.InputFile == nil && args.Query == ""
return iactive || connect.RequiresPassword()
}

func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
wd, err := os.Getwd()
if err != nil {
return 1, err
}

iactive := args.InputFile == nil && args.Query == ""
var connectConfig sqlcmd.ConnectSettings
setConnect(&connectConfig, args, vars)
var line sqlcmd.Console = nil
if iactive {
if isConsoleInitializationRequired(&connectConfig, args) {
line = console.NewConsole("")
defer line.Close()
}

s := sqlcmd.New(line, wd, vars)
s.UnicodeOutputFile = args.UnicodeOutputFile
setConnect(&s.Connect, args, vars)

if args.BatchTerminator != "GO" {
err = s.Cmd.SetBatchTerminator(args.BatchTerminator)
if err != nil {
Expand All @@ -227,7 +233,7 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
return 1, err
}

setConnect(&s.Connect, args, vars)
s.Connect = &connectConfig
s.Format = sqlcmd.NewSQLCmdDefaultFormatter(false)
if args.OutputFile != "" {
err = s.RunCommand(s.Cmd["OUT"], []string{args.OutputFile})
Expand Down Expand Up @@ -257,10 +263,12 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
s.Query = args.Query
}
// connect using no overrides
err = s.ConnectDb(nil, !iactive)
err = s.ConnectDb(nil, line == nil)
if err != nil {
return 1, err
}

iactive := args.InputFile == nil && args.Query == ""
if iactive || s.Query != "" {
err = s.Run(once, false)
} else {
Expand Down
49 changes: 49 additions & 0 deletions cmd/sqlcmd/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"

"github.com/alecthomas/kong"
"github.com/microsoft/go-mssqldb/azuread"
"github.com/microsoft/go-sqlcmd/pkg/sqlcmd"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -327,6 +328,54 @@ func TestMissingInputFile(t *testing.T) {
assert.Equal(t, 1, exitCode, "exitCode")
}

func TestConditionsForPasswordPrompt(t *testing.T) {

type test struct {
authenticationMethod string
inputFile []string
username string
pwd string
expectedResult bool
}
tests := []test{
// Positive Testcases
{sqlcmd.SqlPassword, []string{""}, "someuser", "", true},
{sqlcmd.NotSpecified, []string{"testdata/someFile.sql"}, "someuser", "", true},
{azuread.ActiveDirectoryPassword, []string{""}, "someuser", "", true},
{azuread.ActiveDirectoryPassword, []string{"testdata/someFile.sql"}, "someuser", "", true},
{azuread.ActiveDirectoryServicePrincipal, []string{""}, "someuser", "", true},
{azuread.ActiveDirectoryServicePrincipal, []string{"testdata/someFile.sql"}, "someuser", "", true},
{azuread.ActiveDirectoryApplication, []string{""}, "someuser", "", true},
{azuread.ActiveDirectoryApplication, []string{"testdata/someFile.sql"}, "someuser", "", true},

//Negative Testcases
{sqlcmd.NotSpecified, []string{""}, "", "", false},
{sqlcmd.NotSpecified, []string{"testdata/someFile.sql"}, "", "", false},
{azuread.ActiveDirectoryDefault, []string{""}, "someuser", "", false},
{azuread.ActiveDirectoryDefault, []string{"testdata/someFile.sql"}, "someuser", "", false},
{azuread.ActiveDirectoryInteractive, []string{""}, "someuser", "", false},
{azuread.ActiveDirectoryInteractive, []string{"testdata/someFile.sql"}, "someuser", "", false},
{azuread.ActiveDirectoryManagedIdentity, []string{""}, "someuser", "", false},
{azuread.ActiveDirectoryManagedIdentity, []string{"testdata/someFile.sql"}, "someuser", "", false},
}

for _, testcase := range tests {
t.Log(testcase.authenticationMethod, testcase.inputFile, testcase.username, testcase.pwd, testcase.expectedResult)
args := newArguments()
args.DisableCmdAndWarn = true
args.InputFile = testcase.inputFile
args.UserName = testcase.username
vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
setVars(vars, &args)
var connectConfig sqlcmd.ConnectSettings
setConnect(&connectConfig, &args, vars)
connectConfig.AuthenticationMethod = testcase.authenticationMethod
connectConfig.Password = testcase.pwd
assert.Equal(t, testcase.expectedResult, isConsoleInitializationRequired(&connectConfig, &args), "Unexpected test result encountered for console initialization")
assert.Equal(t, testcase.expectedResult, connectConfig.RequiresPassword() && connectConfig.Password == "", "Unexpected test result encountered for password prompt conditions")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sse

this assert is odd. Shouldn't you assert that if RequiresPassword is true then Password == ""?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. This was to avoid the conditional check and to assert the same condition for password prompt that we use in connectDb()

}
}

// Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set
func canTestAzureAuth() bool {
server := os.Getenv(sqlcmd.SQLCMDSERVER)
Expand Down
2 changes: 1 addition & 1 deletion pkg/sqlcmd/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ func connectCommand(s *Sqlcmd, args []string, line uint) error {
return InvalidCommandError("CONNECT", line)
}

connect := s.Connect
connect := *s.Connect
connect.UserName, _ = resolveArgumentVariables(s, []rune(arguments.Username), false)
connect.Password, _ = resolveArgumentVariables(s, []rune(arguments.Password), false)
connect.ServerName, _ = resolveArgumentVariables(s, []rune(arguments.Server), false)
Expand Down
2 changes: 1 addition & 1 deletion pkg/sqlcmd/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func TestConnectCommand(t *testing.T) {
err := connectCommand(s, []string{"someserver -U someuser"}, 1)
assert.NoError(t, err, "connectCommand with valid arguments doesn't return an error on connect failure")
assert.True(t, prompted, "connectCommand with user name and no password should prompt for password")
assert.NotEqual(t, "someserver", s.Connect.ServerName, "On error, sqlCmd.Connect does not copy inputs")
assert.NotEqual(t, "someserver", s.Connect.ServerName, "On connection failure, sqlCmd.Connect does not copy inputs")

Copy link
Collaborator

@shueybubbles shueybubbles Sep 6, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this changing? s.Connect.ServerName being changed implies that the connection attempt succeeded when it should have failed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the context. This was a side effect of changing type of connect from ConnectSettings to a pointer to ConnectSettings in sqlcmd struct.
In connectCommand() we copy s.Connect to local connect object to establish connection.
I will fix this in my next commit.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the message in the assert didn't provide sufficient context, please update the assert message to be clearer.

err = connectCommand(s, []string{}, 2)
assert.EqualError(t, err, InvalidCommandError("CONNECT", 2).Error(), ":Connect with no arguments should return an error")
Expand Down
2 changes: 1 addition & 1 deletion pkg/sqlcmd/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (connect ConnectSettings) sqlAuthentication() bool {
(!connect.UseTrustedConnection && connect.authenticationMethod() == NotSpecified && connect.UserName != "")
}

func (connect ConnectSettings) requiresPassword() bool {
func (connect ConnectSettings) RequiresPassword() bool {
requiresPassword := connect.sqlAuthentication()
if !requiresPassword {
switch connect.authenticationMethod() {
Expand Down
9 changes: 5 additions & 4 deletions pkg/sqlcmd/sqlcmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ type Sqlcmd struct {
batch *Batch
// Exitcode is returned to the operating system when the process exits
Exitcode int
Connect ConnectSettings
Connect *ConnectSettings
vars *Variables
Format Formatter
Query string
Expand All @@ -79,6 +79,7 @@ func New(l Console, workingDirectory string, vars *Variables) *Sqlcmd {
workingDirectory: workingDirectory,
vars: vars,
Cmd: newCommands(),
Connect: &ConnectSettings{},
}
s.batch = NewBatch(s.scanNext, s.Cmd)
mssql.SetContextLogger(s)
Expand Down Expand Up @@ -213,12 +214,12 @@ func (s *Sqlcmd) SetError(e io.WriteCloser) {
func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error {
newConnection := connect != nil
if connect == nil {
connect = &s.Connect
connect = s.Connect
}

var connector driver.Connector
useAad := !connect.sqlAuthentication() && !connect.integratedAuthentication()
if connect.requiresPassword() && !nopw && connect.Password == "" {
if connect.RequiresPassword() && !nopw && connect.Password == "" {
var err error
if connect.Password, err = s.promptPassword(); err != nil {
return err
Expand Down Expand Up @@ -259,7 +260,7 @@ func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error {
s.vars.Set(SQLCMDUSER, u.Username)
}
if newConnection {
s.Connect = *connect
s.Connect = connect
}
if s.batch != nil {
s.batch.batchline = 1
Expand Down
8 changes: 4 additions & 4 deletions pkg/sqlcmd/sqlcmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,10 +367,10 @@ func TestPromptForPasswordPositive(t *testing.T) {
v := InitializeVariables(true)
s := New(console, "", v)
// attempt without password prompt
err := s.ConnectDb(&c, true)
err := s.ConnectDb(c, true)
assert.False(t, prompted, "ConnectDb with nopw=true should not prompt for password")
assert.Error(t, err, "ConnectDb with nopw==true and no password provided")
err = s.ConnectDb(&c, false)
err = s.ConnectDb(c, false)
assert.True(t, prompted, "ConnectDb with !nopw should prompt for password")
assert.NoError(t, err, "ConnectDb with !nopw and valid password returned from prompt")
if s.Connect.Password != password {
Expand Down Expand Up @@ -506,7 +506,7 @@ func canTestAzureAuth() bool {
return strings.Contains(server, ".database.windows.net") && userName == ""
}

func newConnect(t testing.TB) ConnectSettings {
func newConnect(t testing.TB) *ConnectSettings {
t.Helper()
connect := ConnectSettings{
UserName: os.Getenv(SQLCMDUSER),
Expand All @@ -518,5 +518,5 @@ func newConnect(t testing.TB) ConnectSettings {
t.Log("Using ActiveDirectoryDefault")
connect.AuthenticationMethod = azuread.ActiveDirectoryDefault
}
return connect
return &connect
}