diff --git a/cmd/sqlcmd/main.go b/cmd/sqlcmd/main.go index 37a85f48..d39c839e 100644 --- a/cmd/sqlcmd/main.go +++ b/cmd/sqlcmd/main.go @@ -230,6 +230,10 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { s := sqlcmd.New(line, wd, vars) s.UnicodeOutputFile = args.UnicodeOutputFile + if args.DisableCmdAndWarn { + s.Cmd.DisableSysCommands(false) + } + if args.BatchTerminator != "GO" { err = s.Cmd.SetBatchTerminator(args.BatchTerminator) if err != nil { diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index f9eec1bf..9deb3d09 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -94,9 +94,17 @@ func newCommands() Commands { name: "EXEC", }, } - } +// DisableSysCommands disables the ED and :!! commands. +// When exitOnCall is true, running those commands will exit the process. +func (c Commands) DisableSysCommands(exitOnCall bool) { + f := warnDisabled + if exitOnCall { + f = errorDisabled + } + c["EXEC"].action = f +} func (c Commands) matchCommand(line string) (*Command, []string) { for _, cmd := range c { matchedCommand := cmd.regex.FindStringSubmatch(line) @@ -107,6 +115,17 @@ func (c Commands) matchCommand(line string) (*Command, []string) { return nil, nil } +func warnDisabled(s *Sqlcmd, args []string, line uint) error { + _, _ = s.GetError().Write([]byte(ErrCommandsDisabled.Error() + SqlcmdEol)) + return nil +} + +func errorDisabled(s *Sqlcmd, args []string, line uint) error { + _, _ = s.GetError().Write([]byte(ErrCommandsDisabled.Error() + SqlcmdEol)) + s.Exitcode = 1 + return ErrExitRequested +} + func batchTerminatorRegex(terminator string) string { return fmt.Sprintf(`(?im)^[\t ]*?%s(?:[ ]+(.*$)|$)`, regexp.QuoteMeta(terminator)) } diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index 9299cac7..5a59444c 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -267,3 +267,23 @@ func TestExecCommand(t *testing.T) { assert.Equal(t, buf.buf.String(), "hello"+SqlcmdEol, "echo output should be in sqlcmd output") } } + +func TestDisableSysCommandBlocksExec(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + s.Cmd.DisableSysCommands(false) + c := []string{"set nocount on", ":!! echo hello", "select 100", "go"} + err := runSqlCmd(t, s, c) + if assert.NoError(t, err, ":!! with warning should not raise error") { + assert.Contains(t, buf.buf.String(), ErrCommandsDisabled.Error()+SqlcmdEol+"100"+SqlcmdEol) + assert.Equal(t, 0, s.Exitcode, "ExitCode after warning") + } + buf.buf.Reset() + s.Cmd.DisableSysCommands(true) + err = runSqlCmd(t, s, c) + if assert.NoError(t, err, ":!! with error should not return error") { + assert.Contains(t, buf.buf.String(), ErrCommandsDisabled.Error()+SqlcmdEol) + assert.NotContains(t, buf.buf.String(), "100", "query should not run when syscommand disabled") + assert.Equal(t, 1, s.Exitcode, "ExitCode after error") + } +} diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go index d0ebb491..1bd95d81 100644 --- a/pkg/sqlcmd/sqlcmd.go +++ b/pkg/sqlcmd/sqlcmd.go @@ -32,6 +32,8 @@ var ( ErrNeedPassword = errors.New("need password") // ErrCtrlC indicates execution was ended by ctrl-c or ctrl-break ErrCtrlC = errors.New(WarningPrefix + "The last operation was terminated because the user pressed CTRL+C") + // ErrCommandsDisabled indicates system commands and startup script are disabled + ErrCommandsDisabled = errors.New(ErrorPrefix + "ED and !! commands, startup script, and environment variables are disabled.") ) const maxLineBuffer = 2 * 1024 * 1024 // 2Mb