Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,23 @@ import (
"fmt"
"net/url"
"strings"
"time"

"cloud.google.com/go/bigquery"
"github.com/sirupsen/logrus"
"google.golang.org/api/option"
)

const (
// accountIDParam is the DSN query parameter
// name for the BigQuery job label account ID.
accountIDParam = "account_id"

// defaultAccountID is used when the account_id
// parameter is not set in the DSN.
defaultAccountID = "UNSPECIFIED"
)

type BigQueryDriver struct {
}

Expand All @@ -21,6 +33,13 @@ type bigQueryConfig struct {
scopes []string
endpoint string
disableAuth bool
accountID string
// jobTimeout is the server-side timeout for BQ
// jobs. It applies only to job execution time,
// not queue/pending time. Set via the
// "job_timeout" DSN query parameter (e.g.,
// ?job_timeout=5m). BQ floors minimum to 1s.
jobTimeout time.Duration
}

func (b BigQueryDriver) Open(uri string) (driver.Conn, error) {
Expand Down Expand Up @@ -84,12 +103,30 @@ func configFromUri(uri string) (*bigQueryConfig, error) {
datasetName = fields[len(fields)-1]
}

accountID := u.Query().Get(accountIDParam)
if accountID == "" {
accountID = defaultAccountID
}

config := &bigQueryConfig{
projectID: u.Hostname(),
dataSet: datasetName,
scopes: getScopes(u.Query()),
endpoint: u.Query().Get("endpoint"),
disableAuth: u.Query().Get("disable_auth") == "true",
accountID: accountID,
}

if v := u.Query().Get("job_timeout"); v != "" {
d, err := time.ParseDuration(v)
if err != nil {
logrus.Warnf(
"bq driver: invalid job_timeout %q: %v",
v, err,
)
} else {
config.jobTimeout = d
}
}

if len(fields) == 2 {
Expand Down
33 changes: 31 additions & 2 deletions driver/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,22 @@ func (statement *bigQueryStatement) ExecContext(ctx context.Context, args []driv
return nil, err
}

rowIterator, err := query.Read(ctx)
// Split into Run + Read (instead of query.Read)
// to get a job handle for server-side cancellation
// when the caller's context is done.
job, err := query.Run(ctx)
if err != nil {
return nil, err
}

rowIterator, err := job.Read(ctx)
if err != nil {
if ctx.Err() != nil {
go job.Cancel(context.Background())
}
return nil, err
}

return &bigQueryResult{rowIterator}, nil
}

Expand Down Expand Up @@ -88,8 +99,19 @@ func (statement *bigQueryStatement) QueryContext(ctx context.Context, args []dri
return nil, err
}

rowIterator, err := query.Read(context.Background())
// Split into Run + Read (instead of query.Read)
// to get a job handle for server-side cancellation
// when the caller's context is done.
job, err := query.Run(ctx)
if err != nil {
return nil, err
}

rowIterator, err := job.Read(ctx)
if err != nil {
if ctx.Err() != nil {
go job.Cancel(context.Background())
}
return nil, err
}

Expand Down Expand Up @@ -151,6 +173,13 @@ func (statement bigQueryStatement) buildQuery(args []driver.Value) (*bigquery.Qu
return nil, err
}
query.DefaultDatasetID = statement.connection.config.dataSet
query.Labels = map[string]string{
accountIDParam: statement.connection.config.accountID,
}
if statement.connection.config.jobTimeout > 0 {
query.JobTimeout = statement.connection.config.jobTimeout
}

query.Parameters, err = statement.buildParameters(args)
if err != nil {
return nil, err
Expand Down