From 6929ac2d66f91d8fdf347c999ee0caa41acdded7 Mon Sep 17 00:00:00 2001 From: simon Date: Wed, 27 May 2026 12:43:29 +0200 Subject: [PATCH 1/2] aitools: add --param flag for parameterized SQL queries Wires the Statement Execution API's named parameter support into the experimental aitools query and statement submit commands. Use ":name" markers in SQL and bind values with "--param name=value" or "--param name:TYPE=value" (typed). Empty value is sent as NULL. Co-authored-by: Isaac --- experimental/aitools/cmd/batch.go | 11 +- experimental/aitools/cmd/batch_test.go | 31 ++++- experimental/aitools/cmd/params.go | 50 +++++++ experimental/aitools/cmd/params_test.go | 129 +++++++++++++++++++ experimental/aitools/cmd/query.go | 28 +++- experimental/aitools/cmd/query_test.go | 30 ++++- experimental/aitools/cmd/statement_submit.go | 22 +++- experimental/aitools/cmd/statement_test.go | 24 +++- 8 files changed, 299 insertions(+), 26 deletions(-) create mode 100644 experimental/aitools/cmd/params.go create mode 100644 experimental/aitools/cmd/params_test.go diff --git a/experimental/aitools/cmd/batch.go b/experimental/aitools/cmd/batch.go index 38ecea531e6..f63599c6e82 100644 --- a/experimental/aitools/cmd/batch.go +++ b/experimental/aitools/cmd/batch.go @@ -53,7 +53,11 @@ type batchResultError struct { // On context cancellation (Ctrl+C or parent context), still-running statements // are cancelled server-side via CancelExecution. Statements that finished // before cancellation are left as-is. -func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, warehouseID string, sqls []string, concurrency int) []batchResult { +// +// params, if non-nil, are bound on every statement. The same parameter set is +// reused across the batch, so callers must ensure each SQL uses only markers +// that are covered. +func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, warehouseID string, sqls []string, params []sql.StatementParameterListItem, concurrency int) []batchResult { pollCtx, pollCancel := context.WithCancel(ctx) defer pollCancel() @@ -97,7 +101,7 @@ func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, ware g.SetLimit(concurrency) for i, sqlStr := range sqls { g.Go(func() error { - results[i] = runOneBatchQuery(pollCtx, api, warehouseID, sqlStr, statementIDs, i) + results[i] = runOneBatchQuery(pollCtx, api, warehouseID, sqlStr, params, statementIDs, i) completed.Add(1) return nil }) @@ -115,13 +119,14 @@ func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, ware // runOneBatchQuery submits one SQL, polls to completion, and returns its // batchResult. All errors are encoded into the result; never returns an error. -func runOneBatchQuery(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, sqlStr string, statementIDs []string, idx int) batchResult { +func runOneBatchQuery(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, sqlStr string, params []sql.StatementParameterListItem, statementIDs []string, idx int) batchResult { start := time.Now() result := batchResult{SQL: sqlStr} resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ WarehouseId: warehouseID, Statement: sqlStr, + Parameters: params, WaitTimeout: "0s", OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, }) diff --git a/experimental/aitools/cmd/batch_test.go b/experimental/aitools/cmd/batch_test.go index f6f468768f9..8288a43da5d 100644 --- a/experimental/aitools/cmd/batch_test.go +++ b/experimental/aitools/cmd/batch_test.go @@ -73,7 +73,7 @@ func TestExecuteBatchAllSucceed(t *testing.T) { }, nil).Once() } - results := executeBatch(ctx, mockAPI, "wh-123", sqls, 8) + results := executeBatch(ctx, mockAPI, "wh-123", sqls, nil, 8) require.Len(t, results, 3) for i, r := range results { @@ -86,6 +86,25 @@ func TestExecuteBatchAllSucceed(t *testing.T) { } } +func TestExecuteBatchPassesParametersToAllStatements(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + params := []sql.StatementParameterListItem{ + {Name: "since", Type: "DATE", Value: "2026-01-01"}, + } + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return assert.ObjectsAreEqual(params, req.Parameters) + })).Return(&sql.StatementResponse{ + StatementId: "stmt", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + }, nil).Times(2) + + results := executeBatch(ctx, mockAPI, "wh-1", []string{"SELECT 1 WHERE 1=1 AND :since IS NOT NULL", "SELECT 2 WHERE :since IS NOT NULL"}, params, 8) + require.Len(t, results, 2) +} + func TestExecuteBatchPartialFailure(t *testing.T) { ctx := cmdio.MockDiscard(t.Context()) mockAPI := mocksql.NewMockStatementExecutionInterface(t) @@ -112,7 +131,7 @@ func TestExecuteBatchPartialFailure(t *testing.T) { }, }, nil).Once() - results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT 1", "SELECT bad"}, 8) + results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT 1", "SELECT bad"}, nil, 8) require.Len(t, results, 2) assert.Nil(t, results[0].Error) @@ -141,7 +160,7 @@ func TestExecuteBatchSubmissionFailure(t *testing.T) { return req.Statement == "SELECT broken" })).Return(nil, errors.New("network unreachable")).Once() - results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT good", "SELECT broken"}, 8) + results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT good", "SELECT broken"}, nil, 8) require.Len(t, results, 2) assert.Nil(t, results[0].Error) @@ -163,7 +182,7 @@ func TestExecuteBatchSetsOnWaitTimeoutContinue(t *testing.T) { Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, }, nil).Times(2) - results := executeBatch(ctx, mockAPI, "wh-123", []string{"q1", "q2"}, 8) + results := executeBatch(ctx, mockAPI, "wh-123", []string{"q1", "q2"}, nil, 8) require.Len(t, results, 2) } @@ -196,7 +215,7 @@ func TestExecuteBatchPreservesInputOrder(t *testing.T) { } sqls := []string{"SELECT 'slow'", "SELECT 'fast1'", "SELECT 'fast2'"} - results := executeBatch(ctx, mockAPI, "wh-1", sqls, 8) + results := executeBatch(ctx, mockAPI, "wh-1", sqls, nil, 8) require.Len(t, results, 3) for i, r := range results { @@ -233,7 +252,7 @@ func TestExecuteBatchContextCancellationCancelsInFlight(t *testing.T) { cancel() - results := executeBatch(ctx, mockAPI, "wh", []string{"q1", "q2", "q3"}, 8) + results := executeBatch(ctx, mockAPI, "wh", []string{"q1", "q2", "q3"}, nil, 8) require.Len(t, results, 3) for i, r := range results { diff --git a/experimental/aitools/cmd/params.go b/experimental/aitools/cmd/params.go new file mode 100644 index 00000000000..5edc9e17fca --- /dev/null +++ b/experimental/aitools/cmd/params.go @@ -0,0 +1,50 @@ +package aitools + +import ( + "fmt" + "strings" + + "github.com/databricks/databricks-sdk-go/service/sql" +) + +// parseParams converts --param flag values into SDK parameter list items for +// the Databricks SQL Statement Execution API. Each input is either +// "name=value" (defaults to STRING server-side) or "name:TYPE=value" (typed, +// e.g. "since:DATE=2026-01-01"). An empty value becomes NULL on the wire +// because StatementParameterListItem.Value uses omitempty. +// +// The Databricks API only supports named markers (`:name`), not positional +// `?`, and parameter names must be unique within a statement. +func parseParams(raw []string) ([]sql.StatementParameterListItem, error) { + if len(raw) == 0 { + return nil, nil + } + + out := make([]sql.StatementParameterListItem, 0, len(raw)) + seen := make(map[string]struct{}, len(raw)) + for _, s := range raw { + lhs, value, ok := strings.Cut(s, "=") + if !ok { + return nil, fmt.Errorf("invalid --param %q: expected name=value or name:TYPE=value", s) + } + + name, typ, _ := strings.Cut(lhs, ":") + name = strings.TrimSpace(name) + typ = strings.TrimSpace(typ) + + if name == "" { + return nil, fmt.Errorf("invalid --param %q: name is empty", s) + } + if _, dup := seen[name]; dup { + return nil, fmt.Errorf("duplicate --param name %q", name) + } + seen[name] = struct{}{} + + out = append(out, sql.StatementParameterListItem{ + Name: name, + Type: typ, + Value: value, + }) + } + return out, nil +} diff --git a/experimental/aitools/cmd/params_test.go b/experimental/aitools/cmd/params_test.go new file mode 100644 index 00000000000..e4a56f8a878 --- /dev/null +++ b/experimental/aitools/cmd/params_test.go @@ -0,0 +1,129 @@ +package aitools + +import ( + "testing" + + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseParams(t *testing.T) { + tests := []struct { + name string + in []string + want []sql.StatementParameterListItem + }{ + { + name: "nil input returns nil", + in: nil, + want: nil, + }, + { + name: "empty input returns nil", + in: []string{}, + want: nil, + }, + { + name: "single string param defaults type to empty (server-side STRING)", + in: []string{"name=alice"}, + want: []sql.StatementParameterListItem{ + {Name: "name", Value: "alice"}, + }, + }, + { + name: "typed param splits name and type on first colon", + in: []string{"since:DATE=2026-01-01"}, + want: []sql.StatementParameterListItem{ + {Name: "since", Type: "DATE", Value: "2026-01-01"}, + }, + }, + { + name: "value can contain = and :", + in: []string{"clause=ts >= '2026-01-01T00:00:00'"}, + want: []sql.StatementParameterListItem{ + {Name: "clause", Value: "ts >= '2026-01-01T00:00:00'"}, + }, + }, + { + name: "decimal type with parens preserved", + in: []string{"amount:DECIMAL(10,2)=42.50"}, + want: []sql.StatementParameterListItem{ + {Name: "amount", Type: "DECIMAL(10,2)", Value: "42.50"}, + }, + }, + { + name: "empty value becomes NULL on the wire via omitempty", + in: []string{"opt="}, + want: []sql.StatementParameterListItem{ + {Name: "opt", Value: ""}, + }, + }, + { + name: "whitespace around name and type is trimmed", + in: []string{" name : INT =42"}, + want: []sql.StatementParameterListItem{ + {Name: "name", Type: "INT", Value: "42"}, + }, + }, + { + name: "multiple params preserve input order", + in: []string{"a=1", "b:INT=2", "c=three"}, + want: []sql.StatementParameterListItem{ + {Name: "a", Value: "1"}, + {Name: "b", Type: "INT", Value: "2"}, + {Name: "c", Value: "three"}, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := parseParams(tc.in) + require.NoError(t, err) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestParseParamsErrors(t *testing.T) { + tests := []struct { + name string + in []string + wantMsg string + }{ + { + name: "no equals sign", + in: []string{"foo"}, + wantMsg: "expected name=value", + }, + { + name: "empty name", + in: []string{"=value"}, + wantMsg: "name is empty", + }, + { + name: "empty name with type", + in: []string{":INT=42"}, + wantMsg: "name is empty", + }, + { + name: "whitespace-only name", + in: []string{" =value"}, + wantMsg: "name is empty", + }, + { + name: "duplicate name", + in: []string{"name=alice", "name=bob"}, + wantMsg: `duplicate --param name "name"`, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := parseParams(tc.in) + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantMsg) + }) + } +} diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index c6291733edb..7f4818afeb1 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -68,6 +68,7 @@ func newQueryCmd() *cobra.Command { var filePaths []string var outputFormat string var concurrency int + var paramFlags []string cmd := &cobra.Command{ Use: "query [SQL | file.sql]...", @@ -91,13 +92,19 @@ or the DATABRICKS_WAREHOUSE_ID environment variable is configured. For a single query, output is JSON in non-interactive contexts. In interactive terminals it renders tables, and large results open an -interactive table browser. Use --output csv to export results as CSV.`, +interactive table browser. Use --output csv to export results as CSV. + +Pass named parameters with --param. Use ":name" markers in the SQL and +"--param name=value" (string) or "--param name:TYPE=value" (typed, e.g. +DATE, INT) to bind values. Positional "?" markers are not supported.`, Example: ` databricks experimental aitools tools query "SELECT * FROM samples.nyctaxi.trips LIMIT 5" databricks experimental aitools tools query --warehouse abc123 "SELECT 1" databricks experimental aitools tools query --file report.sql databricks experimental aitools tools query report.sql databricks experimental aitools tools query --output csv "SELECT * FROM samples.nyctaxi.trips LIMIT 5" databricks experimental aitools tools query --output json "SELECT 1" "SELECT 2" "SELECT 3" + databricks experimental aitools tools query --param name=alice "SELECT * FROM users WHERE name = :name" + databricks experimental aitools tools query --param since:DATE=2026-01-01 "SELECT * FROM events WHERE ts > :since" echo "SELECT 1" | databricks experimental aitools tools query`, Args: cobra.ArbitraryArgs, PreRunE: func(cmd *cobra.Command, args []string) error { @@ -125,6 +132,11 @@ interactive table browser. Use --output csv to export results as CSV.`, return err } + params, err := parseParams(paramFlags) + if err != nil { + return err + } + // Reject incompatible flag combinations before any API call so the // user sees the real error instead of an auth/warehouse failure. if len(sqls) > 1 && format != sqlcli.OutputJSON { @@ -139,10 +151,10 @@ interactive table browser. Use --output csv to export results as CSV.`, } if len(sqls) > 1 { - return runBatch(ctx, cmd, w.StatementExecution, wID, sqls, concurrency) + return runBatch(ctx, cmd, w.StatementExecution, wID, sqls, params, concurrency) } - resp, err := executeAndPoll(ctx, w.StatementExecution, wID, sqls[0]) + resp, err := executeAndPoll(ctx, w.StatementExecution, wID, sqls[0], params) if err != nil { return err } @@ -185,6 +197,7 @@ interactive table browser. Use --output csv to export results as CSV.`, cmd.Flags().StringVarP(&warehouseID, "warehouse", "w", "", "SQL warehouse ID to use for execution") cmd.Flags().StringSliceVarP(&filePaths, "file", "f", nil, "Path to a SQL file to execute (repeatable; pair with positional SQLs to run a batch)") cmd.Flags().IntVar(&concurrency, "concurrency", defaultBatchConcurrency, "Maximum in-flight statements when running a batch of queries") + cmd.Flags().StringArrayVar(¶mFlags, "param", nil, "Named parameter, repeatable. Format: name=value (STRING) or name:TYPE=value (e.g. name:DATE=2026-01-01). Empty value is sent as NULL.") // Local --output flag shadows the root command's persistent --output flag, // adding csv support for this command only. cmd.Flags().StringVarP(&outputFormat, "output", "o", string(sqlcli.OutputText), "Output format: text, json, or csv") @@ -222,8 +235,10 @@ func resolveSQLs(ctx context.Context, cmd *cobra.Command, args, filePaths []stri // without an extra error message) when any statement failed; the failure detail // is already encoded in the printed JSON. The caller is responsible for // rejecting incompatible output formats before invoking this. -func runBatch(ctx context.Context, cmd *cobra.Command, api sql.StatementExecutionInterface, warehouseID string, sqls []string, concurrency int) error { - results := executeBatch(ctx, api, warehouseID, sqls, concurrency) +// +// params, if non-nil, are applied to every statement in the batch. +func runBatch(ctx context.Context, cmd *cobra.Command, api sql.StatementExecutionInterface, warehouseID string, sqls []string, params []sql.StatementParameterListItem, concurrency int) error { + results := executeBatch(ctx, api, warehouseID, sqls, params, concurrency) if err := renderBatchJSON(cmd.OutOrStdout(), results); err != nil { return err } @@ -252,11 +267,12 @@ func resolveWarehouseID(ctx context.Context, w any, flagValue string) (string, e // executeAndPoll submits a SQL statement asynchronously and polls until completion. // It shows a spinner in interactive mode and supports Ctrl+C cancellation. -func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, statement string) (*sql.StatementResponse, error) { +func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, statement string, params []sql.StatementParameterListItem) (*sql.StatementResponse, error) { // Submit asynchronously to get the statement ID immediately for cancellation. resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ WarehouseId: warehouseID, Statement: statement, + Parameters: params, WaitTimeout: "0s", OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, }) diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index 8458629c8e7..eaaa50b4967 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -59,12 +59,32 @@ func TestExecuteAndPollImmediateSuccess(t *testing.T) { Result: &sql.ResultData{DataArray: [][]string{{"1"}}}, }, nil) - resp, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT 1") + resp, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT 1", nil) require.NoError(t, err) assert.Equal(t, sql.StatementStateSucceeded, resp.Status.State) assert.Equal(t, "stmt-1", resp.StatementId) } +func TestExecuteAndPollPassesParameters(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + params := []sql.StatementParameterListItem{ + {Name: "name", Value: "alice"}, + {Name: "since", Type: "DATE", Value: "2026-01-01"}, + } + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return assert.ObjectsAreEqual(params, req.Parameters) + })).Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + }, nil) + + _, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT * FROM t WHERE name = :name AND ts > :since", params) + require.NoError(t, err) +} + func TestExecuteAndPollImmediateFailure(t *testing.T) { ctx := cmdio.MockDiscard(t.Context()) mockAPI := mocksql.NewMockStatementExecutionInterface(t) @@ -80,7 +100,7 @@ func TestExecuteAndPollImmediateFailure(t *testing.T) { }, }, nil) - _, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELCT 1") + _, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELCT 1", nil) require.Error(t, err) assert.Contains(t, err.Error(), "SYNTAX_ERROR") assert.Contains(t, err.Error(), "syntax error") @@ -109,7 +129,7 @@ func TestExecuteAndPollWithPolling(t *testing.T) { Result: &sql.ResultData{DataArray: [][]string{{"42"}}}, }, nil).Once() - resp, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT 42") + resp, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT 42", nil) require.NoError(t, err) assert.Equal(t, sql.StatementStateSucceeded, resp.Status.State) assert.Equal(t, [][]string{{"42"}}, resp.Result.DataArray) @@ -132,7 +152,7 @@ func TestExecuteAndPollFailsDuringPolling(t *testing.T) { }, }, nil).Once() - _, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT 1") + _, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT 1", nil) require.Error(t, err) assert.Contains(t, err.Error(), "RESOURCE_EXHAUSTED") } @@ -157,7 +177,7 @@ func TestExecuteAndPollCancelledContextCallsCancelExecution(t *testing.T) { cancel() - _, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT 1") + _, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT 1", nil) require.ErrorIs(t, err, root.ErrAlreadyPrinted) } diff --git a/experimental/aitools/cmd/statement_submit.go b/experimental/aitools/cmd/statement_submit.go index ac8bf424e5f..c578590b50f 100644 --- a/experimental/aitools/cmd/statement_submit.go +++ b/experimental/aitools/cmd/statement_submit.go @@ -14,10 +14,12 @@ import ( func newStatementSubmitCmd() *cobra.Command { var warehouseID string var filePath string + var paramFlags []string // resolved by PreRunE so input validation runs before any auth/profile // work and the documented "validates input before WorkspaceClient" claim // in the PR description is actually true. var sqlStatement string + var params []sql.StatementParameterListItem cmd := &cobra.Command{ Use: "submit [SQL | file.sql]", @@ -27,9 +29,14 @@ statement_id immediately, without waiting for results. The statement keeps running server-side. Harvest results with 'statement get ', inspect with 'statement status ', or stop -with 'statement cancel '.`, +with 'statement cancel '. + +Pass named parameters with --param. Use ":name" markers in the SQL and +"--param name=value" (string) or "--param name:TYPE=value" (typed) to +bind values.`, Example: ` databricks experimental aitools tools statement submit "SELECT pg_sleep(60)" --warehouse - databricks experimental aitools tools statement submit --file query.sql`, + databricks experimental aitools tools statement submit --file query.sql + databricks experimental aitools tools statement submit --param since:DATE=2026-01-01 "SELECT * FROM events WHERE ts > :since"`, Args: cobra.MaximumNArgs(1), PreRunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() @@ -47,6 +54,11 @@ with 'statement cancel '.`, } sqlStatement = sqls[0] + params, err = parseParams(paramFlags) + if err != nil { + return err + } + return root.MustWorkspaceClient(cmd, args) }, RunE: func(cmd *cobra.Command, args []string) error { @@ -57,7 +69,7 @@ with 'statement cancel '.`, return err } - info, err := submitStatement(ctx, w.StatementExecution, sqlStatement, wID) + info, err := submitStatement(ctx, w.StatementExecution, sqlStatement, wID, params) if err != nil { return err } @@ -67,15 +79,17 @@ with 'statement cancel '.`, cmd.Flags().StringVarP(&warehouseID, "warehouse", "w", "", "SQL warehouse ID to use for execution") cmd.Flags().StringVarP(&filePath, "file", "f", "", "Path to a SQL file to execute") + cmd.Flags().StringArrayVar(¶mFlags, "param", nil, "Named parameter, repeatable. Format: name=value (STRING) or name:TYPE=value (e.g. name:DATE=2026-01-01). Empty value is sent as NULL.") return cmd } // submitStatement issues an asynchronous ExecuteStatement and returns the handle. -func submitStatement(ctx context.Context, api sql.StatementExecutionInterface, statement, warehouseID string) (statementInfo, error) { +func submitStatement(ctx context.Context, api sql.StatementExecutionInterface, statement, warehouseID string, params []sql.StatementParameterListItem) (statementInfo, error) { resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ WarehouseId: warehouseID, Statement: statement, + Parameters: params, WaitTimeout: "0s", OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, }) diff --git a/experimental/aitools/cmd/statement_test.go b/experimental/aitools/cmd/statement_test.go index 9c2264daf2c..ff1e9fd4b25 100644 --- a/experimental/aitools/cmd/statement_test.go +++ b/experimental/aitools/cmd/statement_test.go @@ -29,13 +29,33 @@ func TestSubmitStatementReturnsHandle(t *testing.T) { Status: &sql.StatementStatus{State: sql.StatementStatePending}, }, nil).Once() - info, err := submitStatement(ctx, mockAPI, "SELECT 1", "wh-1") + info, err := submitStatement(ctx, mockAPI, "SELECT 1", "wh-1", nil) require.NoError(t, err) assert.Equal(t, "stmt-1", info.StatementID) assert.Equal(t, sql.StatementStatePending, info.State) assert.Equal(t, "wh-1", info.WarehouseID) } +func TestSubmitStatementPassesParameters(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + params := []sql.StatementParameterListItem{ + {Name: "since", Type: "DATE", Value: "2026-01-01"}, + } + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return assert.ObjectsAreEqual(params, req.Parameters) + })).Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + }, nil).Once() + + info, err := submitStatement(ctx, mockAPI, "SELECT * FROM events WHERE ts > :since", "wh-1", params) + require.NoError(t, err) + assert.Equal(t, "stmt-1", info.StatementID) +} + func TestSubmitStatementWrapsTransportError(t *testing.T) { ctx := cmdio.MockDiscard(t.Context()) mockAPI := mocksql.NewMockStatementExecutionInterface(t) @@ -43,7 +63,7 @@ func TestSubmitStatementWrapsTransportError(t *testing.T) { mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.Anything). Return(nil, errors.New("network unreachable")).Once() - _, err := submitStatement(ctx, mockAPI, "SELECT 1", "wh-1") + _, err := submitStatement(ctx, mockAPI, "SELECT 1", "wh-1", nil) require.Error(t, err) assert.Contains(t, err.Error(), "execute statement") assert.Contains(t, err.Error(), "network unreachable") From ecc1c173f60278d8d0689305012f9fdcf41d0d89 Mon Sep 17 00:00:00 2001 From: simon Date: Wed, 27 May 2026 13:08:53 +0200 Subject: [PATCH 2/2] aitools: validate query params before workspace setup --- experimental/aitools/cmd/query.go | 16 ++++++++++------ experimental/aitools/cmd/query_test.go | 10 ++++++++++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index 7f4818afeb1..873932b5158 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -69,6 +69,7 @@ func newQueryCmd() *cobra.Command { var outputFormat string var concurrency int var paramFlags []string + var params []sql.StatementParameterListItem cmd := &cobra.Command{ Use: "query [SQL | file.sql]...", @@ -96,7 +97,8 @@ interactive table browser. Use --output csv to export results as CSV. Pass named parameters with --param. Use ":name" markers in the SQL and "--param name=value" (string) or "--param name:TYPE=value" (typed, e.g. -DATE, INT) to bind values. Positional "?" markers are not supported.`, +DATE, INT) to bind values. Positional "?" markers are not supported. In +multi-query mode, the same parameter set is applied to every statement.`, Example: ` databricks experimental aitools tools query "SELECT * FROM samples.nyctaxi.trips LIMIT 5" databricks experimental aitools tools query --warehouse abc123 "SELECT 1" databricks experimental aitools tools query --file report.sql @@ -111,6 +113,13 @@ DATE, INT) to bind values. Positional "?" markers are not supported.`, if concurrency <= 0 { return errInvalidBatchConcurrency } + + var err error + params, err = parseParams(paramFlags) + if err != nil { + return err + } + return root.MustWorkspaceClient(cmd, args) }, RunE: func(cmd *cobra.Command, args []string) error { @@ -132,11 +141,6 @@ DATE, INT) to bind values. Positional "?" markers are not supported.`, return err } - params, err := parseParams(paramFlags) - if err != nil { - return err - } - // Reject incompatible flag combinations before any API call so the // user sees the real error instead of an auth/warehouse failure. if len(sqls) > 1 && format != sqlcli.OutputJSON { diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index eaaa50b4967..50197e12d92 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -646,6 +646,16 @@ func TestQueryCommandConcurrencyRejection(t *testing.T) { } } +func TestQueryCommandRejectsInvalidParamBeforeWorkspaceClient(t *testing.T) { + // Parameter validation runs in PreRunE before MustWorkspaceClient, so a + // malformed flag returns an actionable error without auth/profile work. + cmd := newQueryCmd() + cmd.SetArgs([]string{"--param", "bad", "SELECT 1"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "expected name=value") +} + func TestQueryCommandOutputFlagIsCaseInsensitive(t *testing.T) { cmd := newQueryCmd() cmd.PreRunE = nil