diff --git a/cmd/sqlcmd/sqlcmd.go b/cmd/sqlcmd/sqlcmd.go index 5abc0860..766fd018 100644 --- a/cmd/sqlcmd/sqlcmd.go +++ b/cmd/sqlcmd/sqlcmd.go @@ -83,6 +83,7 @@ type SQLCmdArguments struct { ChangePasswordAndExit string TraceFile string ServerNameOverride string + RawErrors bool // Keep Help at the end of the list Help bool } @@ -481,6 +482,7 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) { rootCmd.Flags().IntVar(&args.DriverLoggingLevel, "driver-logging-level", 0, localizer.Sprintf("Level of mssql driver messages to print")) rootCmd.Flags().BoolVarP(&args.ExitOnError, "exit-on-error", "b", false, localizer.Sprintf("Specifies that sqlcmd exits and returns a %s value when an error occurs", localizer.DosErrorLevel)) rootCmd.Flags().IntVarP(&args.ErrorLevel, "error-level", "m", 0, localizer.Sprintf("Controls which error messages are sent to %s. Messages that have severity level greater than or equal to this level are sent", localizer.StdoutName)) + rootCmd.Flags().BoolVarP(&args.RawErrors, "raw-errors", "j", false, localizer.Sprintf("Do not strip the \"mssql: \" prefix from error messages")) //Need to decide on short of Header , as "h" is already used in help command in Cobra rootCmd.Flags().IntVarP(&args.Headers, "headers", "h", 0, localizer.Sprintf("Specifies the number of rows to print between the column headings. Use -h-1 to specify that headers not be printed")) @@ -862,7 +864,7 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { } s.Connect = &connectConfig - s.Format = sqlcmd.NewSQLCmdDefaultFormatter(args.TrimSpaces, args.getControlCharacterBehavior()) + s.Format = sqlcmd.NewSQLCmdDefaultFormatter(args.TrimSpaces, args.getControlCharacterBehavior(), sqlcmd.WithRawErrors(args.RawErrors)) if args.OutputFile != "" { err = s.RunCommand(s.Cmd["OUT"], []string{args.OutputFile}) if err != nil { diff --git a/cmd/sqlcmd/sqlcmd_test.go b/cmd/sqlcmd/sqlcmd_test.go index 4554998a..5f6f2740 100644 --- a/cmd/sqlcmd/sqlcmd_test.go +++ b/cmd/sqlcmd/sqlcmd_test.go @@ -123,6 +123,12 @@ func TestValidCommandLineToArgsConversion(t *testing.T) { {[]string{"-N", "true", "-J", "/path/to/cert2.pem"}, func(args SQLCmdArguments) bool { return args.EncryptConnection == "true" && args.ServerCertificate == "/path/to/cert2.pem" }}, + {[]string{"-j"}, func(args SQLCmdArguments) bool { + return args.RawErrors + }}, + {[]string{"--raw-errors"}, func(args SQLCmdArguments) bool { + return args.RawErrors + }}, } for _, test := range commands { diff --git a/pkg/sqlcmd/format.go b/pkg/sqlcmd/format.go index 55bd2e25..3a016b62 100644 --- a/pkg/sqlcmd/format.go +++ b/pkg/sqlcmd/format.go @@ -85,16 +85,35 @@ type sqlCmdFormatterType struct { maxColNameLen int colorizer color.Colorizer xml bool + rawErrors bool } -// NewSQLCmdDefaultFormatter returns a Formatter that mimics the original ODBC-based sqlcmd formatter -func NewSQLCmdDefaultFormatter(removeTrailingSpaces bool, ccb ControlCharacterBehavior) Formatter { - return &sqlCmdFormatterType{ +// FormatterOption customizes the default formatter built by NewSQLCmdDefaultFormatter. +// Use the provided With* constructors (e.g. WithRawErrors) to supply options. +type FormatterOption func(*sqlCmdFormatterType) + +// WithRawErrors controls AddError prefix handling: when raw is true, AddError +// keeps the "mssql: " prefix that go-mssqldb adds to error text instead of +// stripping it. +func WithRawErrors(raw bool) FormatterOption { + return func(f *sqlCmdFormatterType) { f.rawErrors = raw } +} + +// NewSQLCmdDefaultFormatter returns a Formatter that mimics the original ODBC-based sqlcmd formatter. +// Any FormatterOption values passed via opts (e.g. WithRawErrors) are applied to the new formatter. +func NewSQLCmdDefaultFormatter(removeTrailingSpaces bool, ccb ControlCharacterBehavior, opts ...FormatterOption) Formatter { + f := &sqlCmdFormatterType{ removeTrailingSpaces: removeTrailingSpaces, format: "horizontal", colorizer: color.New(false), ccb: ccb, } + for _, opt := range opts { + if opt != nil { + opt(f) + } + } + return f } // Adds the given string to the current line, wrapping it based on the screen width setting @@ -228,7 +247,9 @@ func (f *sqlCmdFormatterType) AddError(err error) { } else { b.WriteString(localizer.Sprintf("Msg %#v, Level %d, State %d, Server %s, Line %#v%s", e.Number, e.Class, e.State, e.ServerName, e.LineNo, SqlcmdEol)) } - msg = strings.TrimPrefix(msg, "mssql: ") + if !f.rawErrors { + msg = strings.TrimPrefix(msg, "mssql: ") + } } } if print { diff --git a/pkg/sqlcmd/format_test.go b/pkg/sqlcmd/format_test.go index f4bee464..f51f2cbc 100644 --- a/pkg/sqlcmd/format_test.go +++ b/pkg/sqlcmd/format_test.go @@ -8,6 +8,7 @@ import ( "strings" "testing" + mssql "github.com/microsoft/go-mssqldb" "github.com/microsoft/go-sqlcmd/internal/color" "github.com/stretchr/testify/assert" ) @@ -162,3 +163,28 @@ func TestFormatterXmlMode(t *testing.T) { assert.NoError(t, err, "runSqlCmd returned error") assert.Equal(t, ``+SqlcmdEol, buf.buf.String()) } + +func TestAddErrorStripsMssqlPrefixByDefault(t *testing.T) { + out, errOut := new(strings.Builder), new(strings.Builder) + f := NewSQLCmdDefaultFormatter(false, ControlIgnore) + f.BeginBatch("", InitializeVariables(false), out, errOut) + + f.AddError(mssql.Error{Number: 50000, State: 1, Class: 16, Message: "Something failed", ServerName: "server", LineNo: 7}) + + got := errOut.String() + assert.Contains(t, got, "Msg 50000, Level 16, State 1, Server server, Line 7") + assert.Contains(t, got, "Something failed") + assert.NotContains(t, got, "mssql:") +} + +func TestAddErrorWithRawErrorsKeepsMssqlPrefix(t *testing.T) { + out, errOut := new(strings.Builder), new(strings.Builder) + f := NewSQLCmdDefaultFormatter(false, ControlIgnore, WithRawErrors(true)) + f.BeginBatch("", InitializeVariables(false), out, errOut) + + f.AddError(mssql.Error{Number: 50000, State: 1, Class: 16, Message: "Something failed", ServerName: "server", LineNo: 7}) + + got := errOut.String() + assert.Contains(t, got, "Msg 50000, Level 16, State 1, Server server, Line 7") + assert.Contains(t, got, "mssql: Something failed") +}