Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions experimental/aitools/cmd/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ type batchResultError struct {
// On context cancellation (Ctrl+C or parent context), still-running statements
// are cancelled server-side via CancelExecution. Statements that finished
// before cancellation are left as-is.
func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, warehouseID string, sqls []string, concurrency int) []batchResult {
//
// params, if non-nil, are bound on every statement. The same parameter set is
// reused across the batch, so callers must ensure each SQL uses only markers
// that are covered.
func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, warehouseID string, sqls []string, params []sql.StatementParameterListItem, concurrency int) []batchResult {
pollCtx, pollCancel := context.WithCancel(ctx)
defer pollCancel()

Expand Down Expand Up @@ -97,7 +101,7 @@ func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, ware
g.SetLimit(concurrency)
for i, sqlStr := range sqls {
g.Go(func() error {
results[i] = runOneBatchQuery(pollCtx, api, warehouseID, sqlStr, statementIDs, i)
results[i] = runOneBatchQuery(pollCtx, api, warehouseID, sqlStr, params, statementIDs, i)
completed.Add(1)
return nil
})
Expand All @@ -115,13 +119,14 @@ func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, ware

// runOneBatchQuery submits one SQL, polls to completion, and returns its
// batchResult. All errors are encoded into the result; never returns an error.
func runOneBatchQuery(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, sqlStr string, statementIDs []string, idx int) batchResult {
func runOneBatchQuery(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, sqlStr string, params []sql.StatementParameterListItem, statementIDs []string, idx int) batchResult {
start := time.Now()
result := batchResult{SQL: sqlStr}

resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{
WarehouseId: warehouseID,
Statement: sqlStr,
Parameters: params,
WaitTimeout: "0s",
OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue,
})
Expand Down
31 changes: 25 additions & 6 deletions experimental/aitools/cmd/batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func TestExecuteBatchAllSucceed(t *testing.T) {
}, nil).Once()
}

results := executeBatch(ctx, mockAPI, "wh-123", sqls, 8)
results := executeBatch(ctx, mockAPI, "wh-123", sqls, nil, 8)

require.Len(t, results, 3)
for i, r := range results {
Expand All @@ -86,6 +86,25 @@ func TestExecuteBatchAllSucceed(t *testing.T) {
}
}

func TestExecuteBatchPassesParametersToAllStatements(t *testing.T) {
ctx := cmdio.MockDiscard(t.Context())
mockAPI := mocksql.NewMockStatementExecutionInterface(t)

params := []sql.StatementParameterListItem{
{Name: "since", Type: "DATE", Value: "2026-01-01"},
}

mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool {
return assert.ObjectsAreEqual(params, req.Parameters)
})).Return(&sql.StatementResponse{
StatementId: "stmt",
Status: &sql.StatementStatus{State: sql.StatementStateSucceeded},
}, nil).Times(2)

results := executeBatch(ctx, mockAPI, "wh-1", []string{"SELECT 1 WHERE 1=1 AND :since IS NOT NULL", "SELECT 2 WHERE :since IS NOT NULL"}, params, 8)
require.Len(t, results, 2)
}

func TestExecuteBatchPartialFailure(t *testing.T) {
ctx := cmdio.MockDiscard(t.Context())
mockAPI := mocksql.NewMockStatementExecutionInterface(t)
Expand All @@ -112,7 +131,7 @@ func TestExecuteBatchPartialFailure(t *testing.T) {
},
}, nil).Once()

results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT 1", "SELECT bad"}, 8)
results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT 1", "SELECT bad"}, nil, 8)

require.Len(t, results, 2)
assert.Nil(t, results[0].Error)
Expand Down Expand Up @@ -141,7 +160,7 @@ func TestExecuteBatchSubmissionFailure(t *testing.T) {
return req.Statement == "SELECT broken"
})).Return(nil, errors.New("network unreachable")).Once()

results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT good", "SELECT broken"}, 8)
results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT good", "SELECT broken"}, nil, 8)

require.Len(t, results, 2)
assert.Nil(t, results[0].Error)
Expand All @@ -163,7 +182,7 @@ func TestExecuteBatchSetsOnWaitTimeoutContinue(t *testing.T) {
Status: &sql.StatementStatus{State: sql.StatementStateSucceeded},
}, nil).Times(2)

results := executeBatch(ctx, mockAPI, "wh-123", []string{"q1", "q2"}, 8)
results := executeBatch(ctx, mockAPI, "wh-123", []string{"q1", "q2"}, nil, 8)
require.Len(t, results, 2)
}

Expand Down Expand Up @@ -196,7 +215,7 @@ func TestExecuteBatchPreservesInputOrder(t *testing.T) {
}

sqls := []string{"SELECT 'slow'", "SELECT 'fast1'", "SELECT 'fast2'"}
results := executeBatch(ctx, mockAPI, "wh-1", sqls, 8)
results := executeBatch(ctx, mockAPI, "wh-1", sqls, nil, 8)

require.Len(t, results, 3)
for i, r := range results {
Expand Down Expand Up @@ -233,7 +252,7 @@ func TestExecuteBatchContextCancellationCancelsInFlight(t *testing.T) {

cancel()

results := executeBatch(ctx, mockAPI, "wh", []string{"q1", "q2", "q3"}, 8)
results := executeBatch(ctx, mockAPI, "wh", []string{"q1", "q2", "q3"}, nil, 8)

require.Len(t, results, 3)
for i, r := range results {
Expand Down
50 changes: 50 additions & 0 deletions experimental/aitools/cmd/params.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package aitools

import (
"fmt"
"strings"

"github.com/databricks/databricks-sdk-go/service/sql"
)

// parseParams converts --param flag values into SDK parameter list items for
// the Databricks SQL Statement Execution API. Each input is either
// "name=value" (defaults to STRING server-side) or "name:TYPE=value" (typed,
// e.g. "since:DATE=2026-01-01"). An empty value becomes NULL on the wire
// because StatementParameterListItem.Value uses omitempty.
//
// The Databricks API only supports named markers (`:name`), not positional
// `?`, and parameter names must be unique within a statement.
func parseParams(raw []string) ([]sql.StatementParameterListItem, error) {
if len(raw) == 0 {
return nil, nil
}

out := make([]sql.StatementParameterListItem, 0, len(raw))
seen := make(map[string]struct{}, len(raw))
for _, s := range raw {
lhs, value, ok := strings.Cut(s, "=")
if !ok {
return nil, fmt.Errorf("invalid --param %q: expected name=value or name:TYPE=value", s)
}

name, typ, _ := strings.Cut(lhs, ":")
name = strings.TrimSpace(name)
typ = strings.TrimSpace(typ)

if name == "" {
return nil, fmt.Errorf("invalid --param %q: name is empty", s)
}
if _, dup := seen[name]; dup {
return nil, fmt.Errorf("duplicate --param name %q", name)
}
seen[name] = struct{}{}

out = append(out, sql.StatementParameterListItem{
Name: name,
Type: typ,
Value: value,
})
}
return out, nil
}
129 changes: 129 additions & 0 deletions experimental/aitools/cmd/params_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package aitools

import (
"testing"

"github.com/databricks/databricks-sdk-go/service/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestParseParams(t *testing.T) {
tests := []struct {
name string
in []string
want []sql.StatementParameterListItem
}{
{
name: "nil input returns nil",
in: nil,
want: nil,
},
{
name: "empty input returns nil",
in: []string{},
want: nil,
},
{
name: "single string param defaults type to empty (server-side STRING)",
in: []string{"name=alice"},
want: []sql.StatementParameterListItem{
{Name: "name", Value: "alice"},
},
},
{
name: "typed param splits name and type on first colon",
in: []string{"since:DATE=2026-01-01"},
want: []sql.StatementParameterListItem{
{Name: "since", Type: "DATE", Value: "2026-01-01"},
},
},
{
name: "value can contain = and :",
in: []string{"clause=ts >= '2026-01-01T00:00:00'"},
want: []sql.StatementParameterListItem{
{Name: "clause", Value: "ts >= '2026-01-01T00:00:00'"},
},
},
{
name: "decimal type with parens preserved",
in: []string{"amount:DECIMAL(10,2)=42.50"},
want: []sql.StatementParameterListItem{
{Name: "amount", Type: "DECIMAL(10,2)", Value: "42.50"},
},
},
{
name: "empty value becomes NULL on the wire via omitempty",
in: []string{"opt="},
want: []sql.StatementParameterListItem{
{Name: "opt", Value: ""},
},
},
{
name: "whitespace around name and type is trimmed",
in: []string{" name : INT =42"},
want: []sql.StatementParameterListItem{
{Name: "name", Type: "INT", Value: "42"},
},
},
{
name: "multiple params preserve input order",
in: []string{"a=1", "b:INT=2", "c=three"},
want: []sql.StatementParameterListItem{
{Name: "a", Value: "1"},
{Name: "b", Type: "INT", Value: "2"},
{Name: "c", Value: "three"},
},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := parseParams(tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, got)
})
}
}

func TestParseParamsErrors(t *testing.T) {
tests := []struct {
name string
in []string
wantMsg string
}{
{
name: "no equals sign",
in: []string{"foo"},
wantMsg: "expected name=value",
},
{
name: "empty name",
in: []string{"=value"},
wantMsg: "name is empty",
},
{
name: "empty name with type",
in: []string{":INT=42"},
wantMsg: "name is empty",
},
{
name: "whitespace-only name",
in: []string{" =value"},
wantMsg: "name is empty",
},
{
name: "duplicate name",
in: []string{"name=alice", "name=bob"},
wantMsg: `duplicate --param name "name"`,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := parseParams(tc.in)
require.Error(t, err)
assert.Contains(t, err.Error(), tc.wantMsg)
})
}
}
32 changes: 26 additions & 6 deletions experimental/aitools/cmd/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ func newQueryCmd() *cobra.Command {
var filePaths []string
var outputFormat string
var concurrency int
var paramFlags []string
var params []sql.StatementParameterListItem

cmd := &cobra.Command{
Use: "query [SQL | file.sql]...",
Expand All @@ -91,19 +93,33 @@ or the DATABRICKS_WAREHOUSE_ID environment variable is configured.

For a single query, output is JSON in non-interactive contexts. In
interactive terminals it renders tables, and large results open an
interactive table browser. Use --output csv to export results as CSV.`,
interactive table browser. Use --output csv to export results as CSV.

Pass named parameters with --param. Use ":name" markers in the SQL and
"--param name=value" (string) or "--param name:TYPE=value" (typed, e.g.
DATE, INT) to bind values. Positional "?" markers are not supported. In
multi-query mode, the same parameter set is applied to every statement.`,
Example: ` databricks experimental aitools tools query "SELECT * FROM samples.nyctaxi.trips LIMIT 5"
databricks experimental aitools tools query --warehouse abc123 "SELECT 1"
databricks experimental aitools tools query --file report.sql
databricks experimental aitools tools query report.sql
databricks experimental aitools tools query --output csv "SELECT * FROM samples.nyctaxi.trips LIMIT 5"
databricks experimental aitools tools query --output json "SELECT 1" "SELECT 2" "SELECT 3"
databricks experimental aitools tools query --param name=alice "SELECT * FROM users WHERE name = :name"
databricks experimental aitools tools query --param since:DATE=2026-01-01 "SELECT * FROM events WHERE ts > :since"
echo "SELECT 1" | databricks experimental aitools tools query`,
Args: cobra.ArbitraryArgs,
PreRunE: func(cmd *cobra.Command, args []string) error {
if concurrency <= 0 {
return errInvalidBatchConcurrency
}

var err error
params, err = parseParams(paramFlags)
if err != nil {
return err
}

return root.MustWorkspaceClient(cmd, args)
},
RunE: func(cmd *cobra.Command, args []string) error {
Expand Down Expand Up @@ -139,10 +155,10 @@ interactive table browser. Use --output csv to export results as CSV.`,
}

if len(sqls) > 1 {
return runBatch(ctx, cmd, w.StatementExecution, wID, sqls, concurrency)
return runBatch(ctx, cmd, w.StatementExecution, wID, sqls, params, concurrency)
}

resp, err := executeAndPoll(ctx, w.StatementExecution, wID, sqls[0])
resp, err := executeAndPoll(ctx, w.StatementExecution, wID, sqls[0], params)
if err != nil {
return err
}
Expand Down Expand Up @@ -185,6 +201,7 @@ interactive table browser. Use --output csv to export results as CSV.`,
cmd.Flags().StringVarP(&warehouseID, "warehouse", "w", "", "SQL warehouse ID to use for execution")
cmd.Flags().StringSliceVarP(&filePaths, "file", "f", nil, "Path to a SQL file to execute (repeatable; pair with positional SQLs to run a batch)")
cmd.Flags().IntVar(&concurrency, "concurrency", defaultBatchConcurrency, "Maximum in-flight statements when running a batch of queries")
cmd.Flags().StringArrayVar(&paramFlags, "param", nil, "Named parameter, repeatable. Format: name=value (STRING) or name:TYPE=value (e.g. name:DATE=2026-01-01). Empty value is sent as NULL.")
// Local --output flag shadows the root command's persistent --output flag,
// adding csv support for this command only.
cmd.Flags().StringVarP(&outputFormat, "output", "o", string(sqlcli.OutputText), "Output format: text, json, or csv")
Expand Down Expand Up @@ -222,8 +239,10 @@ func resolveSQLs(ctx context.Context, cmd *cobra.Command, args, filePaths []stri
// without an extra error message) when any statement failed; the failure detail
// is already encoded in the printed JSON. The caller is responsible for
// rejecting incompatible output formats before invoking this.
func runBatch(ctx context.Context, cmd *cobra.Command, api sql.StatementExecutionInterface, warehouseID string, sqls []string, concurrency int) error {
results := executeBatch(ctx, api, warehouseID, sqls, concurrency)
//
// params, if non-nil, are applied to every statement in the batch.
func runBatch(ctx context.Context, cmd *cobra.Command, api sql.StatementExecutionInterface, warehouseID string, sqls []string, params []sql.StatementParameterListItem, concurrency int) error {
results := executeBatch(ctx, api, warehouseID, sqls, params, concurrency)
if err := renderBatchJSON(cmd.OutOrStdout(), results); err != nil {
return err
}
Expand Down Expand Up @@ -252,11 +271,12 @@ func resolveWarehouseID(ctx context.Context, w any, flagValue string) (string, e

// executeAndPoll submits a SQL statement asynchronously and polls until completion.
// It shows a spinner in interactive mode and supports Ctrl+C cancellation.
func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, statement string) (*sql.StatementResponse, error) {
func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, statement string, params []sql.StatementParameterListItem) (*sql.StatementResponse, error) {
// Submit asynchronously to get the statement ID immediately for cancellation.
resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{
WarehouseId: warehouseID,
Statement: statement,
Parameters: params,
WaitTimeout: "0s",
OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue,
})
Expand Down
Loading