From 779958451ebaeb08fcdc23c82fd31639a4e89abf Mon Sep 17 00:00:00 2001 From: stuartpa Date: Wed, 22 Jun 2022 08:21:49 -0700 Subject: [PATCH 1/8] Add --version flag --- build/azure-pipelines/build-common.yml | 2 +- cmd/sqlcmd/main.go | 564 ++++++++++++------------ cmd/sqlcmd/main_test.go | 577 +++++++++++++------------ 3 files changed, 579 insertions(+), 564 deletions(-) diff --git a/build/azure-pipelines/build-common.yml b/build/azure-pipelines/build-common.yml index 99a4bf13..b56108cb 100644 --- a/build/azure-pipelines/build-common.yml +++ b/build/azure-pipelines/build-common.yml @@ -45,7 +45,7 @@ steps: displayName: 'Go: build sqlcmd' inputs: command: 'build' - arguments: '-o $(Build.BinariesDirectory)' + arguments: '-o $(Build.BinariesDirectory) -ldflags="-X main.version=$(Build.BuildNumber)"' workingDirectory: '$(Build.SourcesDirectory)/cmd/sqlcmd' env: GOOS: ${{ parameters.OS }} diff --git a/cmd/sqlcmd/main.go b/cmd/sqlcmd/main.go index f9a893be..ac166527 100644 --- a/cmd/sqlcmd/main.go +++ b/cmd/sqlcmd/main.go @@ -1,276 +1,288 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -//go:generate go-winres make --file-version=git-tag --product-version=git-tag -package main - -import ( - "fmt" - "os" - - "github.com/alecthomas/kong" - "github.com/microsoft/go-mssqldb/azuread" - "github.com/microsoft/go-sqlcmd/pkg/console" - "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" -) - -// SQLCmdArguments defines the command line arguments for sqlcmd -// The exhaustive list is at https://docs.microsoft.com/sql/tools/sqlcmd-utility?view=sql-server-ver15 -type SQLCmdArguments struct { - // Which batch terminator to use. Default is GO - BatchTerminator string `short:"c" default:"GO" arghelp:"Specifies the batch terminator. The default value is GO."` - // Whether to trust the server certificate on an encrypted connection - TrustServerCertificate bool `short:"C" help:"Implicitly trust the server certificate without validation."` - DatabaseName string `short:"d" help:"This option sets the sqlcmd scripting variable SQLCMDDBNAME. This parameter specifies the initial database. The default is your login's default-database property. If the database does not exist, an error message is generated and sqlcmd exits."` - UseTrustedConnection bool `short:"E" xor:"uid, auth" help:"Uses a trusted connection instead of using a user name and password to sign in to SQL Server, ignoring any any environment variables that define user name and password."` - UserName string `short:"U" xor:"uid" help:"The login name or contained database user name. For contained database users, you must provide the database name option"` - // Files from which to read query text - InputFile []string `short:"i" xor:"input1, input2" type:"existingFile" help:"Identifies one or more files that contain batches of SQL statements. If one or more files do not exist, sqlcmd will exit. Mutually exclusive with -Q/-q."` - OutputFile string `short:"o" type:"path" help:"Identifies the file that receives output from sqlcmd."` - // First query to run in interactive mode - InitialQuery string `short:"q" xor:"input1" help:"Executes a query when sqlcmd starts, but does not exit sqlcmd when the query has finished running. Multiple-semicolon-delimited queries can be executed."` - // Query to run then exit - Query string `short:"Q" xor:"input2" help:"Executes a query when sqlcmd starts and then immediately exits sqlcmd. Multiple-semicolon-delimited queries can be executed."` - Server string `short:"S" help:"[tcp:]server[\\instance_name][,port]Specifies the instance of SQL Server to which to connect. It sets the sqlcmd scripting variable SQLCMDSERVER."` - // Disable syscommands with a warning - DisableCmdAndWarn bool `short:"X" xor:"syscmd" help:"Disables commands that might compromise system security. Sqlcmd issues a warning and continues."` - // AuthenticationMethod is new for go-sqlcmd - AuthenticationMethod string `xor:"auth" help:"Specifies the SQL authentication method to use to connect to Azure SQL Database. One of:ActiveDirectoryDefault,ActiveDirectoryIntegrated,ActiveDirectoryPassword,ActiveDirectoryInteractive,ActiveDirectoryManagedIdentity,ActiveDirectoryServicePrincipal,SqlPassword"` - UseAad bool `short:"G" xor:"auth" help:"Tells sqlcmd to use Active Directory authentication. If no user name is provided, authentication method ActiveDirectoryDefault is used. If a password is provided, ActiveDirectoryPassword is used. Otherwise ActiveDirectoryInteractive is used."` - DisableVariableSubstitution bool `short:"x" help:"Causes sqlcmd to ignore scripting variables. This parameter is useful when a script contains many INSERT statements that may contain strings that have the same format as regular variables, such as $(variable_name)."` - Variables map[string]string `short:"v" help:"Creates a sqlcmd scripting variable that can be used in a sqlcmd script. Enclose the value in quotation marks if the value contains spaces. You can specify multiple var=values values. If there are errors in any of the values specified, sqlcmd generates an error message and then exits"` - PacketSize int `short:"a" help:"Requests a packet of a different size. This option sets the sqlcmd scripting variable SQLCMDPACKETSIZE. packet_size must be a value between 512 and 32767. The default = 4096. A larger packet size can enhance performance for execution of scripts that have lots of SQL statements between GO commands. You can request a larger packet size. However, if the request is denied, sqlcmd uses the server default for packet size."` - LoginTimeout int `short:"l" default:"-1" help:"Specifies the number of seconds before a sqlcmd login to the go-mssqldb driver times out when you try to connect to a server. This option sets the sqlcmd scripting variable SQLCMDLOGINTIMEOUT. The default value is 30. 0 means infinite."` - WorkstationName string `short:"H" help:"This option sets the sqlcmd scripting variable SQLCMDWORKSTATION. The workstation name is listed in the hostname column of the sys.sysprocesses catalog view and can be returned using the stored procedure sp_who. If this option is not specified, the default is the current computer name. This name can be used to identify different sqlcmd sessions."` - ApplicationIntent string `short:"K" default:"default" enum:"default,ReadOnly" help:"Declares the application workload type when connecting to a server. The only currently supported value is ReadOnly. If -K is not specified, the sqlcmd utility will not support connectivity to a secondary replica in an Always On availability group."` - EncryptConnection string `short:"N" default:"default" enum:"default,false,true,disable" help:"This switch is used by the client to request an encrypted connection."` - DriverLoggingLevel int `help:"Level of mssql driver messages to print."` - ExitOnError bool `short:"b" help:"Specifies that sqlcmd exits and returns a DOS ERRORLEVEL value when an error occurs."` - ErrorSeverityLevel uint8 `short:"V" help:"Controls the severity level that is used to set the ERRORLEVEL variable on exit."` - ErrorLevel int `short:"m" help:"Controls which error messages are sent to stdout. Messages that have severity level greater than or equal to this level are sent."` - Format string `short:"F" help:"Specifies the formatting for results." default:"horiz" enum:"horiz,horizontal,vert,vertical"` - ErrorsToStderr int `short:"r" help:"Redirects the error message output to the screen (stderr). A value of 0 means messages with severity >= 11 will b redirected. A value of 1 means all error message output including PRINT is redirected." enum:"-1,0,1" default:"-1"` - Headers int `short:"h" help:"Specifies the number of rows to print between the column headings. Use -h-1 to specify that headers not be printed."` - UnicodeOutputFile bool `short:"u" help:"Specifies that all output files are encoded with little-endian Unicode"` - // Keep Help at the end of the list - Help bool `short:"?" help:"Show syntax summary."` -} - -// Validate accounts for settings not described by Kong attributes -func (a *SQLCmdArguments) Validate() error { - if a.PacketSize != 0 && (a.PacketSize < 512 || a.PacketSize > 32767) { - return fmt.Errorf(`'-a %d': Packet size has to be a number between 512 and 32767.`, a.PacketSize) - } - // Ignore 0 even though it's technically an invalid input - if a.Headers < -1 { - return fmt.Errorf(`'-h %d': header value must be either -1 or a value between 1 and 2147483647`, a.Headers) - } - return nil -} - -// newArguments constructs a SQLCmdArguments instance with default values -// Any parameter with a "default" Kong attribute should have an assignment here -func newArguments() SQLCmdArguments { - return SQLCmdArguments{ - BatchTerminator: "GO", - } -} - -// Breaking changes in command line are listed here. -// Any switch not listed in breaking changes and not also included in SqlCmdArguments just has not been implemented yet -// 1. -P: Passwords have to be provided through SQLCMDPASSWORD environment variable or typed when prompted -// 2. -R: Go runtime doesn't expose user locale information and syscall would only enable it on Windows, so we won't try to implement it - -var args SQLCmdArguments - -func (a SQLCmdArguments) authenticationMethod(hasPassword bool) string { - if a.UseTrustedConnection { - return sqlcmd.NotSpecified - } - if a.UseAad { - switch { - case a.UserName == "": - return azuread.ActiveDirectoryIntegrated - case hasPassword: - return azuread.ActiveDirectoryPassword - default: - return azuread.ActiveDirectoryInteractive - } - } - if a.AuthenticationMethod == "" { - return sqlcmd.NotSpecified - } - return a.AuthenticationMethod -} - -func main() { - ctx := kong.Parse(&args, kong.NoDefaultHelp()) - if args.Help { - _ = ctx.PrintUsage(false) - os.Exit(0) - } - vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) - setVars(vars, &args) - - // so far sqlcmd prints all the errors itself so ignore it - exitCode, _ := run(vars, &args) - os.Exit(exitCode) -} - -// setVars initializes scripting variables from command line arguments -func setVars(vars *sqlcmd.Variables, args *SQLCmdArguments) { - varmap := map[string]func(*SQLCmdArguments) string{ - sqlcmd.SQLCMDDBNAME: func(a *SQLCmdArguments) string { return a.DatabaseName }, - sqlcmd.SQLCMDLOGINTIMEOUT: func(a *SQLCmdArguments) string { - if a.LoginTimeout > -1 { - return fmt.Sprint(a.LoginTimeout) - } - return "" - }, - sqlcmd.SQLCMDUSEAAD: func(a *SQLCmdArguments) string { - if a.UseAad { - return "true" - } - switch a.AuthenticationMethod { - case azuread.ActiveDirectoryIntegrated: - case azuread.ActiveDirectoryInteractive: - case azuread.ActiveDirectoryPassword: - return "true" - } - return "" - }, - sqlcmd.SQLCMDWORKSTATION: func(a *SQLCmdArguments) string { return args.WorkstationName }, - sqlcmd.SQLCMDSERVER: func(a *SQLCmdArguments) string { return a.Server }, - sqlcmd.SQLCMDERRORLEVEL: func(a *SQLCmdArguments) string { return fmt.Sprint(a.ErrorLevel) }, - sqlcmd.SQLCMDPACKETSIZE: func(a *SQLCmdArguments) string { - if args.PacketSize > 0 { - return fmt.Sprint(args.PacketSize) - } - return "" - }, - sqlcmd.SQLCMDUSER: func(a *SQLCmdArguments) string { return a.UserName }, - sqlcmd.SQLCMDSTATTIMEOUT: func(a *SQLCmdArguments) string { return "" }, - sqlcmd.SQLCMDHEADERS: func(a *SQLCmdArguments) string { return fmt.Sprint(a.Headers) }, - sqlcmd.SQLCMDCOLSEP: func(a *SQLCmdArguments) string { return "" }, - sqlcmd.SQLCMDCOLWIDTH: func(a *SQLCmdArguments) string { return "" }, - sqlcmd.SQLCMDMAXVARTYPEWIDTH: func(a *SQLCmdArguments) string { return "" }, - sqlcmd.SQLCMDMAXFIXEDTYPEWIDTH: func(a *SQLCmdArguments) string { return "" }, - sqlcmd.SQLCMDFORMAT: func(a *SQLCmdArguments) string { return a.Format }, - } - for varname, set := range varmap { - val := set(args) - if val != "" { - vars.Set(varname, val) - } - } - - // Following sqlcmd tradition there's no validation of -v kvps - for v := range args.Variables { - vars.Set(v, args.Variables[v]) - } - -} - -func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sqlcmd.Variables) { - if !args.DisableCmdAndWarn { - connect.Password = os.Getenv(sqlcmd.SQLCMDPASSWORD) - } - connect.ServerName = args.Server - if connect.ServerName == "" { - connect.ServerName, _ = vars.Get(sqlcmd.SQLCMDSERVER) - } - connect.Database = args.DatabaseName - if connect.Database == "" { - connect.Database, _ = vars.Get(sqlcmd.SQLCMDDBNAME) - } - connect.UserName = args.UserName - if connect.UserName == "" { - connect.UserName, _ = vars.Get(sqlcmd.SQLCMDUSER) - } - connect.UseTrustedConnection = args.UseTrustedConnection - connect.TrustServerCertificate = args.TrustServerCertificate - connect.AuthenticationMethod = args.authenticationMethod(connect.Password != "") - connect.DisableEnvironmentVariables = args.DisableCmdAndWarn - connect.DisableVariableSubstitution = args.DisableVariableSubstitution - connect.ApplicationIntent = args.ApplicationIntent - connect.LoginTimeoutSeconds = args.LoginTimeout - connect.Encrypt = args.EncryptConnection - connect.PacketSize = args.PacketSize - connect.WorkstationName = args.WorkstationName - connect.LogLevel = args.DriverLoggingLevel - connect.ExitOnError = args.ExitOnError - connect.ErrorSeverityLevel = args.ErrorSeverityLevel -} - -func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { - wd, err := os.Getwd() - if err != nil { - return 1, err - } - - iactive := args.InputFile == nil && args.Query == "" - var line sqlcmd.Console = nil - if iactive { - line = console.NewConsole("") - defer line.Close() - } - - s := sqlcmd.New(line, wd, vars) - s.UnicodeOutputFile = args.UnicodeOutputFile - setConnect(&s.Connect, args, vars) - if args.BatchTerminator != "GO" { - err = s.Cmd.SetBatchTerminator(args.BatchTerminator) - if err != nil { - err = fmt.Errorf("invalid batch terminator '%s'", args.BatchTerminator) - } - } - if err != nil { - return 1, err - } - - setConnect(&s.Connect, args, vars) - s.Format = sqlcmd.NewSQLCmdDefaultFormatter(false) - if args.OutputFile != "" { - err = s.RunCommand(s.Cmd["OUT"], []string{args.OutputFile}) - if err != nil { - return 1, err - } - } else { - var stderrSeverity uint8 = 11 - if args.ErrorsToStderr == 1 { - stderrSeverity = 0 - } - if args.ErrorsToStderr >= 0 { - s.PrintError = func(msg string, severity uint8) bool { - if severity >= stderrSeverity { - _, _ = os.Stderr.Write([]byte(msg)) - return true - } - return false - } - } - } - once := false - if args.InitialQuery != "" { - s.Query = args.InitialQuery - } else if args.Query != "" { - once = true - s.Query = args.Query - } - // connect using no overrides - err = s.ConnectDb(nil, !iactive) - if err != nil { - return 1, err - } - if iactive || s.Query != "" { - err = s.Run(once, false) - } else { - for f := range args.InputFile { - if err = s.IncludeFile(args.InputFile[f], true); err != nil { - break - } - } - } - s.SetOutput(nil) - s.SetError(nil) - return s.Exitcode, err -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +//go:generate go-winres make --file-version=git-tag --product-version=git-tag +package main + +import ( + "fmt" + "os" + + "github.com/alecthomas/kong" + "github.com/microsoft/go-mssqldb/azuread" + "github.com/microsoft/go-sqlcmd/pkg/console" + "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" +) + +var version = "Local-build" // overridden in pipeline builds with: -ldflags="-X main.version=$(Build.BuildNumber)" + +// SQLCmdArguments defines the command line arguments for sqlcmd +// The exhaustive list is at https://docs.microsoft.com/sql/tools/sqlcmd-utility?view=sql-server-ver15 +type SQLCmdArguments struct { + // Which batch terminator to use. Default is GO + BatchTerminator string `short:"c" default:"GO" arghelp:"Specifies the batch terminator. The default value is GO."` + // Whether to trust the server certificate on an encrypted connection + TrustServerCertificate bool `short:"C" help:"Implicitly trust the server certificate without validation."` + DatabaseName string `short:"d" help:"This option sets the sqlcmd scripting variable SQLCMDDBNAME. This parameter specifies the initial database. The default is your login's default-database property. If the database does not exist, an error message is generated and sqlcmd exits."` + UseTrustedConnection bool `short:"E" xor:"uid, auth" help:"Uses a trusted connection instead of using a user name and password to sign in to SQL Server, ignoring any any environment variables that define user name and password."` + UserName string `short:"U" xor:"uid" help:"The login name or contained database user name. For contained database users, you must provide the database name option"` + // Files from which to read query text + InputFile []string `short:"i" xor:"input1, input2" type:"existingFile" help:"Identifies one or more files that contain batches of SQL statements. If one or more files do not exist, sqlcmd will exit. Mutually exclusive with -Q/-q."` + OutputFile string `short:"o" type:"path" help:"Identifies the file that receives output from sqlcmd."` + // First query to run in interactive mode + InitialQuery string `short:"q" xor:"input1" help:"Executes a query when sqlcmd starts, but does not exit sqlcmd when the query has finished running. Multiple-semicolon-delimited queries can be executed."` + // Query to run then exit + Query string `short:"Q" xor:"input2" help:"Executes a query when sqlcmd starts and then immediately exits sqlcmd. Multiple-semicolon-delimited queries can be executed."` + Server string `short:"S" help:"[tcp:]server[\\instance_name][,port]Specifies the instance of SQL Server to which to connect. It sets the sqlcmd scripting variable SQLCMDSERVER."` + // Disable syscommands with a warning + DisableCmdAndWarn bool `short:"X" xor:"syscmd" help:"Disables commands that might compromise system security. Sqlcmd issues a warning and continues."` + // AuthenticationMethod is new for go-sqlcmd + AuthenticationMethod string `xor:"auth" help:"Specifies the SQL authentication method to use to connect to Azure SQL Database. One of:ActiveDirectoryDefault,ActiveDirectoryIntegrated,ActiveDirectoryPassword,ActiveDirectoryInteractive,ActiveDirectoryManagedIdentity,ActiveDirectoryServicePrincipal,SqlPassword"` + UseAad bool `short:"G" xor:"auth" help:"Tells sqlcmd to use Active Directory authentication. If no user name is provided, authentication method ActiveDirectoryDefault is used. If a password is provided, ActiveDirectoryPassword is used. Otherwise ActiveDirectoryInteractive is used."` + DisableVariableSubstitution bool `short:"x" help:"Causes sqlcmd to ignore scripting variables. This parameter is useful when a script contains many INSERT statements that may contain strings that have the same format as regular variables, such as $(variable_name)."` + Variables map[string]string `short:"v" help:"Creates a sqlcmd scripting variable that can be used in a sqlcmd script. Enclose the value in quotation marks if the value contains spaces. You can specify multiple var=values values. If there are errors in any of the values specified, sqlcmd generates an error message and then exits"` + PacketSize int `short:"a" help:"Requests a packet of a different size. This option sets the sqlcmd scripting variable SQLCMDPACKETSIZE. packet_size must be a value between 512 and 32767. The default = 4096. A larger packet size can enhance performance for execution of scripts that have lots of SQL statements between GO commands. You can request a larger packet size. However, if the request is denied, sqlcmd uses the server default for packet size."` + LoginTimeout int `short:"l" default:"-1" help:"Specifies the number of seconds before a sqlcmd login to the go-mssqldb driver times out when you try to connect to a server. This option sets the sqlcmd scripting variable SQLCMDLOGINTIMEOUT. The default value is 30. 0 means infinite."` + WorkstationName string `short:"H" help:"This option sets the sqlcmd scripting variable SQLCMDWORKSTATION. The workstation name is listed in the hostname column of the sys.sysprocesses catalog view and can be returned using the stored procedure sp_who. If this option is not specified, the default is the current computer name. This name can be used to identify different sqlcmd sessions."` + ApplicationIntent string `short:"K" default:"default" enum:"default,ReadOnly" help:"Declares the application workload type when connecting to a server. The only currently supported value is ReadOnly. If -K is not specified, the sqlcmd utility will not support connectivity to a secondary replica in an Always On availability group."` + EncryptConnection string `short:"N" default:"default" enum:"default,false,true,disable" help:"This switch is used by the client to request an encrypted connection."` + DriverLoggingLevel int `help:"Level of mssql driver messages to print."` + ExitOnError bool `short:"b" help:"Specifies that sqlcmd exits and returns a DOS ERRORLEVEL value when an error occurs."` + ErrorSeverityLevel uint8 `short:"V" help:"Controls the severity level that is used to set the ERRORLEVEL variable on exit."` + ErrorLevel int `short:"m" help:"Controls which error messages are sent to stdout. Messages that have severity level greater than or equal to this level are sent."` + Format string `short:"F" help:"Specifies the formatting for results." default:"horiz" enum:"horiz,horizontal,vert,vertical"` + ErrorsToStderr int `short:"r" help:"Redirects the error message output to the screen (stderr). A value of 0 means messages with severity >= 11 will b redirected. A value of 1 means all error message output including PRINT is redirected." enum:"-1,0,1" default:"-1"` + Headers int `short:"h" help:"Specifies the number of rows to print between the column headings. Use -h-1 to specify that headers not be printed."` + UnicodeOutputFile bool `short:"u" help:"Specifies that all output files are encoded with little-endian Unicode"` + + Version bool `help:"Show the sqlcmd version information"` + + // Keep Help at the end of the list + Help bool `short:"?" help:"Show syntax summary."` +} + +// Validate accounts for settings not described by Kong attributes +func (a *SQLCmdArguments) Validate() error { + if a.PacketSize != 0 && (a.PacketSize < 512 || a.PacketSize > 32767) { + return fmt.Errorf(`'-a %d': Packet size has to be a number between 512 and 32767.`, a.PacketSize) + } + // Ignore 0 even though it's technically an invalid input + if a.Headers < -1 { + return fmt.Errorf(`'-h %d': header value must be either -1 or a value between 1 and 2147483647`, a.Headers) + } + return nil +} + +// newArguments constructs a SQLCmdArguments instance with default values +// Any parameter with a "default" Kong attribute should have an assignment here +func newArguments() SQLCmdArguments { + return SQLCmdArguments{ + BatchTerminator: "GO", + } +} + +// Breaking changes in command line are listed here. +// Any switch not listed in breaking changes and not also included in SqlCmdArguments just has not been implemented yet +// 1. -P: Passwords have to be provided through SQLCMDPASSWORD environment variable or typed when prompted +// 2. -R: Go runtime doesn't expose user locale information and syscall would only enable it on Windows, so we won't try to implement it + +var args SQLCmdArguments + +func (a SQLCmdArguments) authenticationMethod(hasPassword bool) string { + if a.UseTrustedConnection { + return sqlcmd.NotSpecified + } + if a.UseAad { + switch { + case a.UserName == "": + return azuread.ActiveDirectoryIntegrated + case hasPassword: + return azuread.ActiveDirectoryPassword + default: + return azuread.ActiveDirectoryInteractive + } + } + if a.AuthenticationMethod == "" { + return sqlcmd.NotSpecified + } + return a.AuthenticationMethod +} + +func main() { + + ctx := kong.Parse(&args, kong.NoDefaultHelp()) + + if args.Version { + ctx.Printf("v%v", version) + os.Exit(0) + } + + if args.Help { + _ = ctx.PrintUsage(false) + os.Exit(0) + } + vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + setVars(vars, &args) + + // so far sqlcmd prints all the errors itself so ignore it + exitCode, _ := run(vars, &args) + os.Exit(exitCode) +} + +// setVars initializes scripting variables from command line arguments +func setVars(vars *sqlcmd.Variables, args *SQLCmdArguments) { + varmap := map[string]func(*SQLCmdArguments) string{ + sqlcmd.SQLCMDDBNAME: func(a *SQLCmdArguments) string { return a.DatabaseName }, + sqlcmd.SQLCMDLOGINTIMEOUT: func(a *SQLCmdArguments) string { + if a.LoginTimeout > -1 { + return fmt.Sprint(a.LoginTimeout) + } + return "" + }, + sqlcmd.SQLCMDUSEAAD: func(a *SQLCmdArguments) string { + if a.UseAad { + return "true" + } + switch a.AuthenticationMethod { + case azuread.ActiveDirectoryIntegrated: + case azuread.ActiveDirectoryInteractive: + case azuread.ActiveDirectoryPassword: + return "true" + } + return "" + }, + sqlcmd.SQLCMDWORKSTATION: func(a *SQLCmdArguments) string { return args.WorkstationName }, + sqlcmd.SQLCMDSERVER: func(a *SQLCmdArguments) string { return a.Server }, + sqlcmd.SQLCMDERRORLEVEL: func(a *SQLCmdArguments) string { return fmt.Sprint(a.ErrorLevel) }, + sqlcmd.SQLCMDPACKETSIZE: func(a *SQLCmdArguments) string { + if args.PacketSize > 0 { + return fmt.Sprint(args.PacketSize) + } + return "" + }, + sqlcmd.SQLCMDUSER: func(a *SQLCmdArguments) string { return a.UserName }, + sqlcmd.SQLCMDSTATTIMEOUT: func(a *SQLCmdArguments) string { return "" }, + sqlcmd.SQLCMDHEADERS: func(a *SQLCmdArguments) string { return fmt.Sprint(a.Headers) }, + sqlcmd.SQLCMDCOLSEP: func(a *SQLCmdArguments) string { return "" }, + sqlcmd.SQLCMDCOLWIDTH: func(a *SQLCmdArguments) string { return "" }, + sqlcmd.SQLCMDMAXVARTYPEWIDTH: func(a *SQLCmdArguments) string { return "" }, + sqlcmd.SQLCMDMAXFIXEDTYPEWIDTH: func(a *SQLCmdArguments) string { return "" }, + sqlcmd.SQLCMDFORMAT: func(a *SQLCmdArguments) string { return a.Format }, + } + for varname, set := range varmap { + val := set(args) + if val != "" { + vars.Set(varname, val) + } + } + + // Following sqlcmd tradition there's no validation of -v kvps + for v := range args.Variables { + vars.Set(v, args.Variables[v]) + } + +} + +func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sqlcmd.Variables) { + if !args.DisableCmdAndWarn { + connect.Password = os.Getenv(sqlcmd.SQLCMDPASSWORD) + } + connect.ServerName = args.Server + if connect.ServerName == "" { + connect.ServerName, _ = vars.Get(sqlcmd.SQLCMDSERVER) + } + connect.Database = args.DatabaseName + if connect.Database == "" { + connect.Database, _ = vars.Get(sqlcmd.SQLCMDDBNAME) + } + connect.UserName = args.UserName + if connect.UserName == "" { + connect.UserName, _ = vars.Get(sqlcmd.SQLCMDUSER) + } + connect.UseTrustedConnection = args.UseTrustedConnection + connect.TrustServerCertificate = args.TrustServerCertificate + connect.AuthenticationMethod = args.authenticationMethod(connect.Password != "") + connect.DisableEnvironmentVariables = args.DisableCmdAndWarn + connect.DisableVariableSubstitution = args.DisableVariableSubstitution + connect.ApplicationIntent = args.ApplicationIntent + connect.LoginTimeoutSeconds = args.LoginTimeout + connect.Encrypt = args.EncryptConnection + connect.PacketSize = args.PacketSize + connect.WorkstationName = args.WorkstationName + connect.LogLevel = args.DriverLoggingLevel + connect.ExitOnError = args.ExitOnError + connect.ErrorSeverityLevel = args.ErrorSeverityLevel +} + +func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { + wd, err := os.Getwd() + if err != nil { + return 1, err + } + + iactive := args.InputFile == nil && args.Query == "" + var line sqlcmd.Console = nil + if iactive { + line = console.NewConsole("") + defer line.Close() + } + + s := sqlcmd.New(line, wd, vars) + s.UnicodeOutputFile = args.UnicodeOutputFile + setConnect(&s.Connect, args, vars) + if args.BatchTerminator != "GO" { + err = s.Cmd.SetBatchTerminator(args.BatchTerminator) + if err != nil { + err = fmt.Errorf("invalid batch terminator '%s'", args.BatchTerminator) + } + } + if err != nil { + return 1, err + } + + setConnect(&s.Connect, args, vars) + s.Format = sqlcmd.NewSQLCmdDefaultFormatter(false) + if args.OutputFile != "" { + err = s.RunCommand(s.Cmd["OUT"], []string{args.OutputFile}) + if err != nil { + return 1, err + } + } else { + var stderrSeverity uint8 = 11 + if args.ErrorsToStderr == 1 { + stderrSeverity = 0 + } + if args.ErrorsToStderr >= 0 { + s.PrintError = func(msg string, severity uint8) bool { + if severity >= stderrSeverity { + _, _ = os.Stderr.Write([]byte(msg)) + return true + } + return false + } + } + } + once := false + if args.InitialQuery != "" { + s.Query = args.InitialQuery + } else if args.Query != "" { + once = true + s.Query = args.Query + } + // connect using no overrides + err = s.ConnectDb(nil, !iactive) + if err != nil { + return 1, err + } + if iactive || s.Query != "" { + err = s.Run(once, false) + } else { + for f := range args.InputFile { + if err = s.IncludeFile(args.InputFile[f], true); err != nil { + break + } + } + } + s.SetOutput(nil) + s.SetError(nil) + return s.Exitcode, err +} diff --git a/cmd/sqlcmd/main_test.go b/cmd/sqlcmd/main_test.go index 3933c8e4..2f3ea914 100644 --- a/cmd/sqlcmd/main_test.go +++ b/cmd/sqlcmd/main_test.go @@ -1,287 +1,290 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -package main - -import ( - "os" - "runtime" - "strings" - "testing" - - "github.com/alecthomas/kong" - "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -const oneRowAffected = "(1 row affected)" - -func newKong(t *testing.T, cli interface{}, options ...kong.Option) *kong.Kong { - t.Helper() - options = append([]kong.Option{ - kong.Name("test"), - kong.NoDefaultHelp(), - kong.Exit(func(int) { - t.Helper() - t.Fatalf("unexpected exit()") - }), - }, options...) - parser, err := kong.New(cli, options...) - require.NoError(t, err) - return parser -} - -func TestValidCommandLineToArgsConversion(t *testing.T) { - type cmdLineTest struct { - commandLine []string - check func(SQLCmdArguments) bool - } - - // These tests only cover compatibility with the native sqlcmd, which only supports the short flags - // The long flag names are up for debate. - commands := []cmdLineTest{ - {[]string{}, func(args SQLCmdArguments) bool { - return args.Server == "" && !args.UseTrustedConnection && args.UserName == "" - }}, - {[]string{"-c", "MYGO", "-C", "-E", "-i", "file1", "-o", "outfile", "-i", "file2"}, func(args SQLCmdArguments) bool { - return args.BatchTerminator == "MYGO" && args.TrustServerCertificate && len(args.InputFile) == 2 && strings.HasSuffix(args.OutputFile, "outfile") - }}, - {[]string{"-U", "someuser", "-d", "somedatabase", "-S", "someserver"}, func(args SQLCmdArguments) bool { - return args.BatchTerminator == "GO" && !args.TrustServerCertificate && args.UserName == "someuser" && args.DatabaseName == "somedatabase" && args.Server == "someserver" - }}, - // native sqlcmd allows both -q and -Q but only runs the -Q query and exits. We could make them mutually exclusive if desired. - {[]string{"-q", "select 1", "-Q", "select 2"}, func(args SQLCmdArguments) bool { - return args.Server == "" && args.InitialQuery == "select 1" && args.Query == "select 2" - }}, - {[]string{"-S", "someserver/someinstance"}, func(args SQLCmdArguments) bool { - return args.Server == "someserver/someinstance" - }}, - {[]string{"-S", "tcp:someserver,10245"}, func(args SQLCmdArguments) bool { - return args.Server == "tcp:someserver,10245" && !args.DisableVariableSubstitution - }}, - {[]string{"-X", "-x"}, func(args SQLCmdArguments) bool { - return args.DisableCmdAndWarn && args.DisableVariableSubstitution - }}, - // Notice no "" around the value with a space in it. It seems quotes get stripped out somewhere before Parse when invoking on a real command line - {[]string{"-v", "x=y", "-v", `y=a space`}, func(args SQLCmdArguments) bool { - return args.LoginTimeout == -1 && args.Variables["x"] == "y" && args.Variables["y"] == "a space" - }}, - {[]string{"-a", "550", "-l", "45", "-H", "mystation", "-K", "ReadOnly", "-N", "true"}, func(args SQLCmdArguments) bool { - return args.PacketSize == 550 && args.LoginTimeout == 45 && args.WorkstationName == "mystation" && args.ApplicationIntent == "ReadOnly" && args.EncryptConnection == "true" - }}, - {[]string{"-b", "-m", "15", "-V", "20"}, func(args SQLCmdArguments) bool { - return args.ExitOnError && args.ErrorLevel == 15 && args.ErrorSeverityLevel == 20 - }}, - {[]string{"-F", "vert"}, func(args SQLCmdArguments) bool { - return args.Format == "vert" - }}, - {[]string{"-r", "1"}, func(args SQLCmdArguments) bool { - return args.ErrorsToStderr == 1 - }}, - {[]string{"-h", "2", "-?"}, func(args SQLCmdArguments) bool { - return args.Help && args.Headers == 2 - }}, - {[]string{"-u"}, func(args SQLCmdArguments) bool { - return args.UnicodeOutputFile - }}, - } - - for _, test := range commands { - arguments := &SQLCmdArguments{} - parser := newKong(t, arguments) - _, err := parser.Parse(test.commandLine) - msg := "" - if err != nil { - msg = err.Error() - } - if assert.Nil(t, err, "Unable to parse commandLine:%v\n%s", test.commandLine, msg) { - assert.True(t, test.check(*arguments), "Unexpected SqlCmdArguments from: %v\n%+v", test.commandLine, *arguments) - } - } -} - -func TestInvalidCommandLine(t *testing.T) { - type cmdLineTest struct { - commandLine []string - errorMessage string - } - - commands := []cmdLineTest{ - {[]string{"-E", "-U", "someuser"}, "--use-trusted-connection and --user-name can't be used together"}, - // the test prefix is a kong artifact https://github.com/alecthomas/kong/issues/221 - {[]string{"-a", "100"}, "test: '-a 100': Packet size has to be a number between 512 and 32767."}, - {[]string{"-F", "what"}, "--format must be one of \"horiz\",\"horizontal\",\"vert\",\"vertical\" but got \"what\""}, - {[]string{"-r", "5"}, `--errors-to-stderr must be one of "-1","0","1" but got '\x05'`}, - {[]string{"-h-4"}, "test: '-h -4': header value must be either -1 or a value between 1 and 2147483647"}, - } - - for _, test := range commands { - arguments := &SQLCmdArguments{} - parser := newKong(t, arguments) - _, err := parser.Parse(test.commandLine) - assert.EqualError(t, err, test.errorMessage, "Command line:%v", test.commandLine) - } -} - -// Simulate main() using files -func TestRunInputFiles(t *testing.T) { - o, err := os.CreateTemp("", "sqlcmdmain") - assert.NoError(t, err, "os.CreateTemp") - defer os.Remove(o.Name()) - defer o.Close() - args = newArguments() - args.InputFile = []string{"testdata/select100.sql", "testdata/select100.sql"} - args.OutputFile = o.Name() - if canTestAzureAuth() { - args.UseAad = true - } - vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) - vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0") - setVars(vars, &args) - - exitCode, err := run(vars, &args) - assert.NoError(t, err, "run") - assert.Equal(t, 0, exitCode, "exitCode") - bytes, err := os.ReadFile(o.Name()) - if assert.NoError(t, err, "os.ReadFile") { - assert.Equal(t, "100"+sqlcmd.SqlcmdEol+sqlcmd.SqlcmdEol+oneRowAffected+sqlcmd.SqlcmdEol+"100"+sqlcmd.SqlcmdEol+sqlcmd.SqlcmdEol+oneRowAffected+sqlcmd.SqlcmdEol, string(bytes), "Incorrect output from run") - } -} - -func TestUnicodeOutput(t *testing.T) { - o, err := os.CreateTemp("", "sqlcmdmain") - assert.NoError(t, err, "os.CreateTemp") - defer os.Remove(o.Name()) - defer o.Close() - args = newArguments() - args.InputFile = []string{"testdata/selectutf8.txt"} - args.OutputFile = o.Name() - args.UnicodeOutputFile = true - if canTestAzureAuth() { - args.UseAad = true - } - vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) - setVars(vars, &args) - - exitCode, err := run(vars, &args) - assert.NoError(t, err, "run") - assert.Equal(t, 0, exitCode, "exitCode") - bytes, err := os.ReadFile(o.Name()) - if assert.NoError(t, err, "os.ReadFile") { - outfile := `testdata/unicodeout_linux.txt` - if runtime.GOOS == "windows" { - outfile = `testdata/unicodeout.txt` - } - expectedBytes, err := os.ReadFile(outfile) - if assert.NoErrorf(t, err, "Unable to open %s", outfile) { - assert.Equalf(t, expectedBytes, bytes, "unicode output bytes should match %s", outfile) - } - } -} - -func TestUnicodeInput(t *testing.T) { - testfiles := []string{ - `testdata/selectutf8.txt`, - `testdata/selectutf8_bom.txt`, - `testdata/selectunicode_BE.txt`, - `testdata/selectunicode_LE.txt`, - } - - for _, test := range testfiles { - for _, unicodeOutput := range []bool{true, false} { - var outfile string - if unicodeOutput { - outfile = `testdata/unicodeout_linux.txt` - if runtime.GOOS == "windows" { - outfile = `testdata/unicodeout.txt` - } - } else { - outfile = `testdata/utf8out_linux.txt` - if runtime.GOOS == "windows" { - outfile = `testdata/utf8out.txt` - } - } - o, err := os.CreateTemp("", "sqlcmdmain") - assert.NoError(t, err, "os.CreateTemp") - defer os.Remove(o.Name()) - defer o.Close() - args = newArguments() - args.InputFile = []string{test} - args.OutputFile = o.Name() - args.UnicodeOutputFile = unicodeOutput - if canTestAzureAuth() { - args.UseAad = true - } - vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) - setVars(vars, &args) - exitCode, err := run(vars, &args) - assert.NoError(t, err, "run") - assert.Equal(t, 0, exitCode, "exitCode") - bytes, err := os.ReadFile(o.Name()) - if assert.NoError(t, err, "os.ReadFile") { - expectedBytes, err := os.ReadFile(outfile) - if assert.NoErrorf(t, err, "Unable to open %s", outfile) { - assert.Equalf(t, expectedBytes, bytes, "input file: <%s> output bytes should match <%s>", test, outfile) - } - } - } - } -} - -func TestQueryAndExit(t *testing.T) { - o, err := os.CreateTemp("", "sqlcmdmain") - assert.NoError(t, err, "os.CreateTemp") - defer os.Remove(o.Name()) - defer o.Close() - args = newArguments() - args.Query = "SELECT '$(VAR1) $(VAR2)'" - args.OutputFile = o.Name() - args.Variables = map[string]string{"var2": "val2"} - if canTestAzureAuth() { - args.UseAad = true - } - vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) - vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0") - vars.Set("VAR1", "100") - setVars(vars, &args) - - exitCode, err := run(vars, &args) - assert.NoError(t, err, "run") - assert.Equal(t, 0, exitCode, "exitCode") - bytes, err := os.ReadFile(o.Name()) - if assert.NoError(t, err, "os.ReadFile") { - assert.Equal(t, "100 val2"+sqlcmd.SqlcmdEol+sqlcmd.SqlcmdEol+oneRowAffected+sqlcmd.SqlcmdEol, string(bytes), "Incorrect output from run") - } -} - -func TestAzureAuth(t *testing.T) { - - if !canTestAzureAuth() { - t.Skip("Server name is not an Azure DB name") - } - o, err := os.CreateTemp("", "sqlcmdmain") - assert.NoError(t, err, "os.CreateTemp") - defer os.Remove(o.Name()) - defer o.Close() - args = newArguments() - args.Query = "SELECT 'AZURE'" - args.OutputFile = o.Name() - args.UseAad = true - vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) - vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0") - setVars(vars, &args) - exitCode, err := run(vars, &args) - assert.NoError(t, err, "run") - assert.Equal(t, 0, exitCode, "exitCode") - bytes, err := os.ReadFile(o.Name()) - if assert.NoError(t, err, "os.ReadFile") { - assert.Equal(t, "AZURE"+sqlcmd.SqlcmdEol+sqlcmd.SqlcmdEol+oneRowAffected+sqlcmd.SqlcmdEol, string(bytes), "Incorrect output from run") - } -} - -// Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set -func canTestAzureAuth() bool { - server := os.Getenv(sqlcmd.SQLCMDSERVER) - userName := os.Getenv(sqlcmd.SQLCMDUSER) - return strings.Contains(server, ".database.windows.net") && userName == "" -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +package main + +import ( + "os" + "runtime" + "strings" + "testing" + + "github.com/alecthomas/kong" + "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const oneRowAffected = "(1 row affected)" + +func newKong(t *testing.T, cli interface{}, options ...kong.Option) *kong.Kong { + t.Helper() + options = append([]kong.Option{ + kong.Name("test"), + kong.NoDefaultHelp(), + kong.Exit(func(int) { + t.Helper() + t.Fatalf("unexpected exit()") + }), + }, options...) + parser, err := kong.New(cli, options...) + require.NoError(t, err) + return parser +} + +func TestValidCommandLineToArgsConversion(t *testing.T) { + type cmdLineTest struct { + commandLine []string + check func(SQLCmdArguments) bool + } + + // These tests only cover compatibility with the native sqlcmd, which only supports the short flags + // The long flag names are up for debate. + commands := []cmdLineTest{ + {[]string{}, func(args SQLCmdArguments) bool { + return args.Server == "" && !args.UseTrustedConnection && args.UserName == "" + }}, + {[]string{"-c", "MYGO", "-C", "-E", "-i", "file1", "-o", "outfile", "-i", "file2"}, func(args SQLCmdArguments) bool { + return args.BatchTerminator == "MYGO" && args.TrustServerCertificate && len(args.InputFile) == 2 && strings.HasSuffix(args.OutputFile, "outfile") + }}, + {[]string{"-U", "someuser", "-d", "somedatabase", "-S", "someserver"}, func(args SQLCmdArguments) bool { + return args.BatchTerminator == "GO" && !args.TrustServerCertificate && args.UserName == "someuser" && args.DatabaseName == "somedatabase" && args.Server == "someserver" + }}, + // native sqlcmd allows both -q and -Q but only runs the -Q query and exits. We could make them mutually exclusive if desired. + {[]string{"-q", "select 1", "-Q", "select 2"}, func(args SQLCmdArguments) bool { + return args.Server == "" && args.InitialQuery == "select 1" && args.Query == "select 2" + }}, + {[]string{"-S", "someserver/someinstance"}, func(args SQLCmdArguments) bool { + return args.Server == "someserver/someinstance" + }}, + {[]string{"-S", "tcp:someserver,10245"}, func(args SQLCmdArguments) bool { + return args.Server == "tcp:someserver,10245" && !args.DisableVariableSubstitution + }}, + {[]string{"-X", "-x"}, func(args SQLCmdArguments) bool { + return args.DisableCmdAndWarn && args.DisableVariableSubstitution + }}, + // Notice no "" around the value with a space in it. It seems quotes get stripped out somewhere before Parse when invoking on a real command line + {[]string{"-v", "x=y", "-v", `y=a space`}, func(args SQLCmdArguments) bool { + return args.LoginTimeout == -1 && args.Variables["x"] == "y" && args.Variables["y"] == "a space" + }}, + {[]string{"-a", "550", "-l", "45", "-H", "mystation", "-K", "ReadOnly", "-N", "true"}, func(args SQLCmdArguments) bool { + return args.PacketSize == 550 && args.LoginTimeout == 45 && args.WorkstationName == "mystation" && args.ApplicationIntent == "ReadOnly" && args.EncryptConnection == "true" + }}, + {[]string{"-b", "-m", "15", "-V", "20"}, func(args SQLCmdArguments) bool { + return args.ExitOnError && args.ErrorLevel == 15 && args.ErrorSeverityLevel == 20 + }}, + {[]string{"-F", "vert"}, func(args SQLCmdArguments) bool { + return args.Format == "vert" + }}, + {[]string{"-r", "1"}, func(args SQLCmdArguments) bool { + return args.ErrorsToStderr == 1 + }}, + {[]string{"-h", "2", "-?"}, func(args SQLCmdArguments) bool { + return args.Help && args.Headers == 2 + }}, + {[]string{"-u"}, func(args SQLCmdArguments) bool { + return args.UnicodeOutputFile + }}, + {[]string{"--version"}, func(args SQLCmdArguments) bool { + return args.Version + }}, + } + + for _, test := range commands { + arguments := &SQLCmdArguments{} + parser := newKong(t, arguments) + _, err := parser.Parse(test.commandLine) + msg := "" + if err != nil { + msg = err.Error() + } + if assert.Nil(t, err, "Unable to parse commandLine:%v\n%s", test.commandLine, msg) { + assert.True(t, test.check(*arguments), "Unexpected SqlCmdArguments from: %v\n%+v", test.commandLine, *arguments) + } + } +} + +func TestInvalidCommandLine(t *testing.T) { + type cmdLineTest struct { + commandLine []string + errorMessage string + } + + commands := []cmdLineTest{ + {[]string{"-E", "-U", "someuser"}, "--use-trusted-connection and --user-name can't be used together"}, + // the test prefix is a kong artifact https://github.com/alecthomas/kong/issues/221 + {[]string{"-a", "100"}, "test: '-a 100': Packet size has to be a number between 512 and 32767."}, + {[]string{"-F", "what"}, "--format must be one of \"horiz\",\"horizontal\",\"vert\",\"vertical\" but got \"what\""}, + {[]string{"-r", "5"}, `--errors-to-stderr must be one of "-1","0","1" but got '\x05'`}, + {[]string{"-h-4"}, "test: '-h -4': header value must be either -1 or a value between 1 and 2147483647"}, + } + + for _, test := range commands { + arguments := &SQLCmdArguments{} + parser := newKong(t, arguments) + _, err := parser.Parse(test.commandLine) + assert.EqualError(t, err, test.errorMessage, "Command line:%v", test.commandLine) + } +} + +// Simulate main() using files +func TestRunInputFiles(t *testing.T) { + o, err := os.CreateTemp("", "sqlcmdmain") + assert.NoError(t, err, "os.CreateTemp") + defer os.Remove(o.Name()) + defer o.Close() + args = newArguments() + args.InputFile = []string{"testdata/select100.sql", "testdata/select100.sql"} + args.OutputFile = o.Name() + if canTestAzureAuth() { + args.UseAad = true + } + vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0") + setVars(vars, &args) + + exitCode, err := run(vars, &args) + assert.NoError(t, err, "run") + assert.Equal(t, 0, exitCode, "exitCode") + bytes, err := os.ReadFile(o.Name()) + if assert.NoError(t, err, "os.ReadFile") { + assert.Equal(t, "100"+sqlcmd.SqlcmdEol+sqlcmd.SqlcmdEol+oneRowAffected+sqlcmd.SqlcmdEol+"100"+sqlcmd.SqlcmdEol+sqlcmd.SqlcmdEol+oneRowAffected+sqlcmd.SqlcmdEol, string(bytes), "Incorrect output from run") + } +} + +func TestUnicodeOutput(t *testing.T) { + o, err := os.CreateTemp("", "sqlcmdmain") + assert.NoError(t, err, "os.CreateTemp") + defer os.Remove(o.Name()) + defer o.Close() + args = newArguments() + args.InputFile = []string{"testdata/selectutf8.txt"} + args.OutputFile = o.Name() + args.UnicodeOutputFile = true + if canTestAzureAuth() { + args.UseAad = true + } + vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + setVars(vars, &args) + + exitCode, err := run(vars, &args) + assert.NoError(t, err, "run") + assert.Equal(t, 0, exitCode, "exitCode") + bytes, err := os.ReadFile(o.Name()) + if assert.NoError(t, err, "os.ReadFile") { + outfile := `testdata/unicodeout_linux.txt` + if runtime.GOOS == "windows" { + outfile = `testdata/unicodeout.txt` + } + expectedBytes, err := os.ReadFile(outfile) + if assert.NoErrorf(t, err, "Unable to open %s", outfile) { + assert.Equalf(t, expectedBytes, bytes, "unicode output bytes should match %s", outfile) + } + } +} + +func TestUnicodeInput(t *testing.T) { + testfiles := []string{ + `testdata/selectutf8.txt`, + `testdata/selectutf8_bom.txt`, + `testdata/selectunicode_BE.txt`, + `testdata/selectunicode_LE.txt`, + } + + for _, test := range testfiles { + for _, unicodeOutput := range []bool{true, false} { + var outfile string + if unicodeOutput { + outfile = `testdata/unicodeout_linux.txt` + if runtime.GOOS == "windows" { + outfile = `testdata/unicodeout.txt` + } + } else { + outfile = `testdata/utf8out_linux.txt` + if runtime.GOOS == "windows" { + outfile = `testdata/utf8out.txt` + } + } + o, err := os.CreateTemp("", "sqlcmdmain") + assert.NoError(t, err, "os.CreateTemp") + defer os.Remove(o.Name()) + defer o.Close() + args = newArguments() + args.InputFile = []string{test} + args.OutputFile = o.Name() + args.UnicodeOutputFile = unicodeOutput + if canTestAzureAuth() { + args.UseAad = true + } + vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + setVars(vars, &args) + exitCode, err := run(vars, &args) + assert.NoError(t, err, "run") + assert.Equal(t, 0, exitCode, "exitCode") + bytes, err := os.ReadFile(o.Name()) + if assert.NoError(t, err, "os.ReadFile") { + expectedBytes, err := os.ReadFile(outfile) + if assert.NoErrorf(t, err, "Unable to open %s", outfile) { + assert.Equalf(t, expectedBytes, bytes, "input file: <%s> output bytes should match <%s>", test, outfile) + } + } + } + } +} + +func TestQueryAndExit(t *testing.T) { + o, err := os.CreateTemp("", "sqlcmdmain") + assert.NoError(t, err, "os.CreateTemp") + defer os.Remove(o.Name()) + defer o.Close() + args = newArguments() + args.Query = "SELECT '$(VAR1) $(VAR2)'" + args.OutputFile = o.Name() + args.Variables = map[string]string{"var2": "val2"} + if canTestAzureAuth() { + args.UseAad = true + } + vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0") + vars.Set("VAR1", "100") + setVars(vars, &args) + + exitCode, err := run(vars, &args) + assert.NoError(t, err, "run") + assert.Equal(t, 0, exitCode, "exitCode") + bytes, err := os.ReadFile(o.Name()) + if assert.NoError(t, err, "os.ReadFile") { + assert.Equal(t, "100 val2"+sqlcmd.SqlcmdEol+sqlcmd.SqlcmdEol+oneRowAffected+sqlcmd.SqlcmdEol, string(bytes), "Incorrect output from run") + } +} + +func TestAzureAuth(t *testing.T) { + + if !canTestAzureAuth() { + t.Skip("Server name is not an Azure DB name") + } + o, err := os.CreateTemp("", "sqlcmdmain") + assert.NoError(t, err, "os.CreateTemp") + defer os.Remove(o.Name()) + defer o.Close() + args = newArguments() + args.Query = "SELECT 'AZURE'" + args.OutputFile = o.Name() + args.UseAad = true + vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0") + setVars(vars, &args) + exitCode, err := run(vars, &args) + assert.NoError(t, err, "run") + assert.Equal(t, 0, exitCode, "exitCode") + bytes, err := os.ReadFile(o.Name()) + if assert.NoError(t, err, "os.ReadFile") { + assert.Equal(t, "AZURE"+sqlcmd.SqlcmdEol+sqlcmd.SqlcmdEol+oneRowAffected+sqlcmd.SqlcmdEol, string(bytes), "Incorrect output from run") + } +} + +// Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set +func canTestAzureAuth() bool { + server := os.Getenv(sqlcmd.SQLCMDSERVER) + userName := os.Getenv(sqlcmd.SQLCMDUSER) + return strings.Contains(server, ".database.windows.net") && userName == "" +} From c1107a690d153bdeaa44e90e47c2055c8f464b38 Mon Sep 17 00:00:00 2001 From: stuartpa Date: Wed, 7 Sep 2022 07:11:10 -0700 Subject: [PATCH 2/8] Use VERSION_TAG --- cmd/sqlcmd/main_test.go | 86 ----------------------------------------- 1 file changed, 86 deletions(-) diff --git a/cmd/sqlcmd/main_test.go b/cmd/sqlcmd/main_test.go index ccf06ade..283bd525 100644 --- a/cmd/sqlcmd/main_test.go +++ b/cmd/sqlcmd/main_test.go @@ -256,89 +256,3 @@ func TestQueryAndExit(t *testing.T) { assert.Equal(t, "100 val2"+sqlcmd.SqlcmdEol+sqlcmd.SqlcmdEol+oneRowAffected+sqlcmd.SqlcmdEol, string(bytes), "Incorrect output from run") } } - -<<<<<<< HEAD -======= -// Test to verify fix for issue: https://github.com/microsoft/go-sqlcmd/issues/98 -// 1. Verify when -b is passed in (ExitOnError), we don't always get an error (even when input is good) -// 2, Verify when the input is actually bad, we do get an error -func TestExitOnError(t *testing.T) { - args = newArguments() - args.InputFile = []string{"testdata/select100.sql"} - args.ErrorsToStderr = 0 - args.ExitOnError = true - if canTestAzureAuth() { - args.UseAad = true - } - - vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) - setVars(vars, &args) - - exitCode, err := run(vars, &args) - assert.NoError(t, err, "run") - assert.Equal(t, 0, exitCode, "exitCode") - - args.InputFile = []string{"testdata/bad.sql"} - - vars = sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) - setVars(vars, &args) - - exitCode, err = run(vars, &args) - assert.NoError(t, err, "run") - assert.Equal(t, 1, exitCode, "exitCode") - - t.Logf("Test Completed") // Needs some output to stdout to count as a test -} - ->>>>>>> main -func TestAzureAuth(t *testing.T) { - - if !canTestAzureAuth() { - t.Skip("Server name is not an Azure DB name") - } - o, err := os.CreateTemp("", "sqlcmdmain") - assert.NoError(t, err, "os.CreateTemp") - defer os.Remove(o.Name()) - defer o.Close() - args = newArguments() - args.Query = "SELECT 'AZURE'" - args.OutputFile = o.Name() - args.UseAad = true - vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) - vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0") - setVars(vars, &args) - exitCode, err := run(vars, &args) - assert.NoError(t, err, "run") - assert.Equal(t, 0, exitCode, "exitCode") - bytes, err := os.ReadFile(o.Name()) - if assert.NoError(t, err, "os.ReadFile") { - assert.Equal(t, "AZURE"+sqlcmd.SqlcmdEol+sqlcmd.SqlcmdEol+oneRowAffected+sqlcmd.SqlcmdEol, string(bytes), "Incorrect output from run") - } -} - -<<<<<<< HEAD -======= -func TestMissingInputFile(t *testing.T) { - args = newArguments() - args.InputFile = []string{"testdata/missingFile.sql"} - - if canTestAzureAuth() { - args.UseAad = true - } - - vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) - setVars(vars, &args) - - exitCode, err := run(vars, &args) - assert.Error(t, err, "run") - assert.Contains(t, err.Error(), "Error occurred while opening or operating on file", "Unexpected error: "+err.Error()) - assert.Equal(t, 1, exitCode, "exitCode") -} - ->>>>>>> main -// Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set -func canTestAzureAuth() bool { - server := os.Getenv(sqlcmd.SQLCMDSERVER) - userName := os.Getenv(sqlcmd.SQLCMDUSER) - return strings.Contains(server, ".database.windows.net") && userName == "" -} From 677001a4dcc61d17547c1a31a5c2547e858d2fa7 Mon Sep 17 00:00:00 2001 From: stuartpa Date: Wed, 7 Sep 2022 07:24:22 -0700 Subject: [PATCH 3/8] Merge conflict --- cmd/sqlcmd/main_test.go | 80 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/cmd/sqlcmd/main_test.go b/cmd/sqlcmd/main_test.go index 283bd525..27170bce 100644 --- a/cmd/sqlcmd/main_test.go +++ b/cmd/sqlcmd/main_test.go @@ -256,3 +256,83 @@ func TestQueryAndExit(t *testing.T) { assert.Equal(t, "100 val2"+sqlcmd.SqlcmdEol+sqlcmd.SqlcmdEol+oneRowAffected+sqlcmd.SqlcmdEol, string(bytes), "Incorrect output from run") } } + +// Test to verify fix for issue: https://github.com/microsoft/go-sqlcmd/issues/98 +// 1. Verify when -b is passed in (ExitOnError), we don't always get an error (even when input is good) +// 2, Verify when the input is actually bad, we do get an error +func TestExitOnError(t *testing.T) { + args = newArguments() + args.InputFile = []string{"testdata/select100.sql"} + args.ErrorsToStderr = 0 + args.ExitOnError = true + if canTestAzureAuth() { + args.UseAad = true + } + + vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + setVars(vars, &args) + + exitCode, err := run(vars, &args) + assert.NoError(t, err, "run") + assert.Equal(t, 0, exitCode, "exitCode") + + args.InputFile = []string{"testdata/bad.sql"} + + vars = sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + setVars(vars, &args) + + exitCode, err = run(vars, &args) + assert.NoError(t, err, "run") + assert.Equal(t, 1, exitCode, "exitCode") + + t.Logf("Test Completed") // Needs some output to stdout to count as a test +} + +func TestAzureAuth(t *testing.T) { + + if !canTestAzureAuth() { + t.Skip("Server name is not an Azure DB name") + } + o, err := os.CreateTemp("", "sqlcmdmain") + assert.NoError(t, err, "os.CreateTemp") + defer os.Remove(o.Name()) + defer o.Close() + args = newArguments() + args.Query = "SELECT 'AZURE'" + args.OutputFile = o.Name() + args.UseAad = true + vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0") + setVars(vars, &args) + exitCode, err := run(vars, &args) + assert.NoError(t, err, "run") + assert.Equal(t, 0, exitCode, "exitCode") + bytes, err := os.ReadFile(o.Name()) + if assert.NoError(t, err, "os.ReadFile") { + assert.Equal(t, "AZURE"+sqlcmd.SqlcmdEol+sqlcmd.SqlcmdEol+oneRowAffected+sqlcmd.SqlcmdEol, string(bytes), "Incorrect output from run") + } +} + +func TestMissingInputFile(t *testing.T) { + args = newArguments() + args.InputFile = []string{"testdata/missingFile.sql"} + + if canTestAzureAuth() { + args.UseAad = true + } + + vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + setVars(vars, &args) + + exitCode, err := run(vars, &args) + assert.Error(t, err, "run") + assert.Contains(t, err.Error(), "Error occurred while opening or operating on file", "Unexpected error: "+err.Error()) + assert.Equal(t, 1, exitCode, "exitCode") +} + +// Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set +func canTestAzureAuth() bool { + server := os.Getenv(sqlcmd.SQLCMDSERVER) + userName := os.Getenv(sqlcmd.SQLCMDUSER) + return strings.Contains(server, ".database.windows.net") && userName == "" +} From a0216c7b6a946ff51d595d8d16d2b071f0cfdf7d Mon Sep 17 00:00:00 2001 From: stuartpa Date: Wed, 7 Sep 2022 09:48:22 -0700 Subject: [PATCH 4/8] Refactor --- build/azure-pipelines/build-common.yml | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/build/azure-pipelines/build-common.yml b/build/azure-pipelines/build-common.yml index 5dcf43c7..62e91668 100644 --- a/build/azure-pipelines/build-common.yml +++ b/build/azure-pipelines/build-common.yml @@ -7,17 +7,11 @@ parameters: default: - name: ArtifactName type: string +- name: VersionTag + type: string + default: $(Build.BuildNumber) steps: -- task: PowerShell@2 - displayName: Set last tag to variable - inputs: - targetType: 'inline' - script: | - $VERSION_TAG = git describe --tags (git rev-list --tags --max-count=1) - Write-Host("##vso[task.setvariable variable=VERSION_TAG]$VERSION_TAG") - Write-Host($VERSION_TAG) - - task: GoTool@0 inputs: version: '1.18' @@ -54,7 +48,7 @@ steps: displayName: 'Go: build sqlcmd' inputs: command: 'build' - arguments: '-o $(Build.BinariesDirectory) -ldflags="-X main.version=$(VERSION_TAG)"' + arguments: '-o $(Build.BinariesDirectory) -ldflags="-X main.version=${{ parameters.VersionTag }}"' workingDirectory: '$(Build.SourcesDirectory)/cmd/sqlcmd' env: GOOS: ${{ parameters.OS }} From 75c7dfedf813e21607454c8e5a330b113d9b5c0f Mon Sep 17 00:00:00 2001 From: stuartpa Date: Wed, 7 Sep 2022 10:10:45 -0700 Subject: [PATCH 5/8] Pass version through --- build/azure-pipelines/build-product.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/build/azure-pipelines/build-product.yml b/build/azure-pipelines/build-product.yml index 790ec5c9..67758551 100644 --- a/build/azure-pipelines/build-product.yml +++ b/build/azure-pipelines/build-product.yml @@ -45,6 +45,7 @@ stages: OS: $(os) Arch: $(arch) ArtifactName: $(artifact) + VersionTag: $(VERSION_TAG) - stage: CreatePackages displayName: Create packages to publish From 334326a73438840106c7bb20a4ba3c27d71b5dab Mon Sep 17 00:00:00 2001 From: stuartpa Date: Thu, 8 Sep 2022 03:30:38 -0700 Subject: [PATCH 6/8] Use output variable for version tag --- build/azure-pipelines/build-product.yml | 35 +++++++++++++------------ 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/build/azure-pipelines/build-product.yml b/build/azure-pipelines/build-product.yml index 67758551..07e8a1d9 100644 --- a/build/azure-pipelines/build-product.yml +++ b/build/azure-pipelines/build-product.yml @@ -9,7 +9,7 @@ stages: - stage: Compile displayName: Compile sqlcmd on all 3 platforms jobs: - - job: Compile_sqlcmd + - job: Sqlcmd strategy: matrix: linux: @@ -40,6 +40,15 @@ stages: pool: vmImage: $(imageName) steps: + - task: PowerShell@2 + displayName: Set last tag to variable + name: getVersion + inputs: + targetType: 'inline' + script: | + $VERSION_TAG = git describe --tags (git rev-list --tags --max-count=1) + Write-Host("##vso[task.setvariable variable=VERSION_TAG]$VERSION_TAG") + Write-Host($VERSION_TAG) - template: build-common.yml parameters: OS: $(os) @@ -53,21 +62,13 @@ stages: - job: Sign_and_pack pool: vmImage: 'windows-latest' + variables: + versionTag: $[ stageDependencies.Compile.Sqlcmd.outputs['getVersion.VERSION_TAG'] ] steps: - - task: PowerShell@2 - displayName: Set last tag to variable - inputs: - targetType: 'inline' - script: | - $VERSION_TAG = git describe --tags (git rev-list --tags --max-count=1) - Write-Host("##vso[task.setvariable variable=VERSION_TAG]$VERSION_TAG") - Write-Host($VERSION_TAG) - - task: DownloadPipelineArtifact@2 inputs: buildType: 'current' targetPath: '$(Pipeline.Workspace)' - - task: EsrpCodeSigning@1 displayName: Sign Windows binary inputs: @@ -127,7 +128,7 @@ stages: rootFolderOrFile: '$(Pipeline.Workspace)\SqlcmdWindowsAmd64\Sqlcmd.exe' includeRootFolder: false archiveType: 'zip' - archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(VERSION_TAG)-windows-x64.zip' + archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(versionTag)-windows-x64.zip' - task: ArchiveFiles@2 displayName: Zip Windows arm binary @@ -135,7 +136,7 @@ stages: rootFolderOrFile: '$(Pipeline.Workspace)\SqlcmdWindowsArm\Sqlcmd.exe' includeRootFolder: false archiveType: 'zip' - archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(VERSION_TAG)-windows-arm.zip' + archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(versionTag)-windows-arm.zip' - task: ArchiveFiles@2 displayName: Tar Linux amd64 binary @@ -144,7 +145,7 @@ stages: includeRootFolder: false archiveType: 'tar' tarCompression: 'bz2' - archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(VERSION_TAG)-linux-x64.tar.bz2' + archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(versionTag)-linux-x64.tar.bz2' - task: ArchiveFiles@2 displayName: Tar Darwin binary @@ -153,7 +154,7 @@ stages: includeRootFolder: false archiveType: 'tar' tarCompression: 'bz2' - archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(VERSION_TAG)-darwin-x64.tar.bz2' + archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(versionTag)-darwin-x64.tar.bz2' - task: ArchiveFiles@2 displayName: Tar Linux arm64 binary @@ -162,7 +163,7 @@ stages: includeRootFolder: false archiveType: 'tar' tarCompression: 'bz2' - archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(VERSION_TAG)-linux-arm64.tar.bz2' + archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(versionTag)-linux-arm64.tar.bz2' - task: PublishPipelineArtifact@1 displayName: 'Publish release archives' @@ -177,6 +178,6 @@ stages: action: 'create' target: '$(Build.SourceVersion)' tagSource: 'userSpecifiedTag' - tag: '$(VERSION_TAG)' + tag: '$(versionTag)' changeLogCompareToRelease: 'lastFullRelease' changeLogType: 'commitBased' From 119c90ff4d572202b0d21940c006d6fa1829df69 Mon Sep 17 00:00:00 2001 From: stuartpa Date: Thu, 8 Sep 2022 03:48:29 -0700 Subject: [PATCH 7/8] Remove extra 'v' --- cmd/sqlcmd/main.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/sqlcmd/main.go b/cmd/sqlcmd/main.go index 48290fef..b89605d5 100644 --- a/cmd/sqlcmd/main.go +++ b/cmd/sqlcmd/main.go @@ -13,7 +13,7 @@ import ( "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" ) -var version = "Local-build" // overridden in pipeline builds with: -ldflags="-X main.version=$(Build.BuildNumber)" +var version = "Local-build" // overridden in pipeline builds with: -ldflags="-X main.version=$(VersionTag)" // SQLCmdArguments defines the command line arguments for sqlcmd // The exhaustive list is at https://docs.microsoft.com/sql/tools/sqlcmd-utility?view=sql-server-ver15 @@ -108,7 +108,7 @@ func (a SQLCmdArguments) authenticationMethod(hasPassword bool) string { func main() { ctx := kong.Parse(&args, kong.NoDefaultHelp()) if args.Version { - ctx.Printf("v%v", version) + ctx.Printf("%v", version) os.Exit(0) } if args.Help { From 2b9065d4f0490fd2ffe42c831bd6c024e9945666 Mon Sep 17 00:00:00 2001 From: stuartpa Date: Thu, 15 Sep 2022 09:55:56 -0700 Subject: [PATCH 8/8] Merge changes --- build/azure-pipelines/build-product.yml | 2 +- release/windows/choco/tools/VERIFICATION.txt | 9 +-------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/build/azure-pipelines/build-product.yml b/build/azure-pipelines/build-product.yml index e4cd05ef..6913a742 100644 --- a/build/azure-pipelines/build-product.yml +++ b/build/azure-pipelines/build-product.yml @@ -183,7 +183,7 @@ stages: includeRootFolder: false archiveType: 'tar' tarCompression: 'bz2' - archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(VERSION_TAG)-linux-s390x.tar.bz2' + archiveFile: '$(Build.ArtifactStagingDirectory)/sqlcmd-$(versionTag)-linux-s390x.tar.bz2' - task: PublishPipelineArtifact@1 displayName: 'Publish release archives' diff --git a/release/windows/choco/tools/VERIFICATION.txt b/release/windows/choco/tools/VERIFICATION.txt index f4fe34a7..7a1a6530 100644 --- a/release/windows/choco/tools/VERIFICATION.txt +++ b/release/windows/choco/tools/VERIFICATION.txt @@ -1,11 +1,4 @@ - -Note: Include this file if including binaries you have the right to distribute. -Otherwise delete. this file. If you are the software author, you can change this -mention you are the author of the software. - -===DELETE ABOVE THIS LINE AND THIS LINE=== - -VERIFICATION +VERIFICATION Verification is intended to assist the Chocolatey moderators and community in verifying that this package's contents are trustworthy.