From 833e85aa25cdeba8f02f43f8eb36127be2273926 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Wed, 13 May 2026 20:13:38 +0000 Subject: [PATCH 1/3] feat(bigframes): support output_mode for ai.classify --- .../bigframes/bigquery/_operations/ai.py | 26 ++++++++++---- .../ibis_compiler/scalar_op_registry.py | 14 ++++---- .../compile/sqlglot/expressions/ai_ops.py | 13 ++++--- .../bigframes/bigframes/operations/ai_ops.py | 7 +++- .../tests/system/small/bigquery/test_ai.py | 9 +++++ .../out.sql | 8 +++++ .../test_ai_classify_with_output_mode/out.sql | 7 ++++ .../sqlglot/expressions/test_ai_ops.py | 35 +++++++++++++++++++ .../ibis/expr/operations/ai_ops.py | 3 ++ 9 files changed, 103 insertions(+), 19 deletions(-) create mode 100644 packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_multi_with_list_examples/out.sql create mode 100644 packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_with_output_mode/out.sql diff --git a/packages/bigframes/bigframes/bigquery/_operations/ai.py b/packages/bigframes/bigframes/bigquery/_operations/ai.py index d92cf3fbd978..4eb24f25206f 100644 --- a/packages/bigframes/bigframes/bigquery/_operations/ai.py +++ b/packages/bigframes/bigframes/bigquery/_operations/ai.py @@ -842,9 +842,12 @@ def classify( input: PROMPT_TYPE, categories: tuple[str, ...] | list[str], *, - examples: list[tuple[str, str]] | None = None, + examples: list[tuple[str, str]] + | list[tuple[str, list[str] | tuple[str, ...]]] + | None = None, connection_id: str | None = None, endpoint: str | None = None, + output_mode: Literal["single", "multi"] | None = None, optimization_mode: Literal["minimize_cost", "maximize_quality"] | None = None, max_error_ratio: float | None = None, ) -> series.Series: @@ -870,10 +873,11 @@ def classify( or pandas Series. categories (tuple[str, ...] | list[str]): Categories to classify the input into. - examples (list[tuple[str, str]], optional): + examples (list[tuple[str, str]] | list[tuple[str, list[str] | tuple[str, ...]]], optional): An array that contains representative examples of input strings and the output category - that you expect. You can provide examples to help the model understand your - intended threshold for a condition with nuanced or subjective logic. We recommend providing at most 5 examples. + that you expect. If ``output_mode`` is ``multi``, each example output must be a list or tuple of strings. + You can provide examples to help the model understand your intended threshold for a condition with nuanced + or subjective logic. We recommend providing at most 5 examples. connection_id (str, optional): Specifies the connection to use to communicate with the model. For example, ``myproject.us.myconnection``. If not provided, the query uses your end-user credential. @@ -881,6 +885,9 @@ def classify( A STRING value that specifies the Vertex AI endpoint to use for the model. You can specify any generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and uses the full endpoint of the model. + output_mode (Literal["single", "multi"], optional): + A STRING value that indicates whether a single input can be classified into multiple categories. + Supported values are ``single`` and ``multi``. optimization_mode (Literal["minimize_cost", "maximize_quality"], optional): A STRING value that specifies the optimization strategy to use. Supported values are ``minimize_cost`` and ``maximize_quality``. @@ -890,13 +897,19 @@ def classify( This argument isn't supported when ``optimization_mode`` is set to ``minimize_cost``. Returns: - bigframes.series.Series: A new series of strings. + bigframes.series.Series: A new series of strings (or a series of arrays of strings if ``output_mode`` is specified). """ prompt_context, series_list = _separate_context_and_series(input) assert len(series_list) > 0 - example_tuples = tuple(examples) if examples is not None else None + if examples is not None: + example_tuples = tuple( + (ex[0], tuple(ex[1]) if isinstance(ex[1], (list, tuple)) else ex[1]) + for ex in examples + ) + else: + example_tuples = None operator = ai_ops.AIClassify( prompt_context=tuple(prompt_context), @@ -904,6 +917,7 @@ def classify( examples=example_tuples, connection_id=connection_id, endpoint=endpoint, + output_mode=output_mode, optimization_mode=_upper_optional(optimization_mode), max_error_ratio=max_error_ratio, ) diff --git a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index 8116c4864599..64c4e75a85e5 100644 --- a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -1999,6 +1999,7 @@ def ai_classify( _construct_examples(op.examples), # type: ignore op.connection_id, # type: ignore op.endpoint, # type: ignore + op.output_mode, # type: ignore op.optimization_mode, # type: ignore op.max_error_ratio, # type: ignore ).to_expr() @@ -2045,7 +2046,7 @@ def _construct_prompt( def _construct_examples( - examples: tuple[tuple[str, str]] | None, + examples: tuple[tuple[str, str | tuple[str, ...]], ...] | None, ) -> ibis_types.ArrayValue | None: if examples is None: return None @@ -2053,12 +2054,11 @@ def _construct_examples( results: list[ibis_types.StructValue] = [] for example in examples: - ibis_example = ibis.struct( - { - "_field_1": example[0], - "_field_2": example[1], - } - ) + value = example[1] + if isinstance(example[1], (list, tuple)): + value = list(example[1]) + + ibis_example = ibis.struct({"_field_1": example[0], "_field_2": value}) results.append(ibis_example) return ibis.array(results) diff --git a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py index ed1c59f92e66..7863d61efc10 100644 --- a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -140,12 +140,15 @@ def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]: ) ) elif field == "examples": - example_expressions = [ - sge.Tuple( - expressions=[sge.Literal.string(key), sge.Literal.string(val)] + example_expressions = [] + for key, val in value: + if isinstance(val, (list, tuple)): + val_expr = sge.array(*[sge.Literal.string(v) for v in val]) + else: + val_expr = sge.Literal.string(val) + example_expressions.append( + sge.Tuple(expressions=[sge.Literal.string(key), val_expr]) ) - for key, val in value - ] args.append( sge.Kwarg(this=field, expression=sge.array(*example_expressions)) ) diff --git a/packages/bigframes/bigframes/operations/ai_ops.py b/packages/bigframes/bigframes/operations/ai_ops.py index 9d1fdbc4e130..ad2b9850577e 100644 --- a/packages/bigframes/bigframes/operations/ai_ops.py +++ b/packages/bigframes/bigframes/operations/ai_ops.py @@ -160,13 +160,18 @@ class AIClassify(base_ops.NaryOp): prompt_context: Tuple[str | None, ...] categories: tuple[str, ...] - examples: tuple[tuple[str, str], ...] | None = None + examples: ( + tuple[tuple[str, str], ...] | tuple[tuple[str, tuple[str, ...]], ...] | None + ) = None connection_id: str | None = None endpoint: str | None = None + output_mode: str | None = None optimization_mode: str | None = None max_error_ratio: float | None = None def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + if self.output_mode is not None: + return dtypes.list_type(dtypes.STRING_DTYPE) return dtypes.STRING_DTYPE diff --git a/packages/bigframes/tests/system/small/bigquery/test_ai.py b/packages/bigframes/tests/system/small/bigquery/test_ai.py index 39c74d42b2dc..fdfcefb5f4cf 100644 --- a/packages/bigframes/tests/system/small/bigquery/test_ai.py +++ b/packages/bigframes/tests/system/small/bigquery/test_ai.py @@ -336,6 +336,15 @@ def test_ai_classify_with_examples(session): assert result.dtype == dtypes.STRING_DTYPE +def test_ai_classify_output_mode(session, bq_connection): + s = bpd.Series(["cat", "orchid"], session=session) + + result = bbq.ai.classify(s, ["animal", "plant"], output_mode="multi", examples=[("dog", ["animal"])]) + + assert len(result) == len(s) + assert result.dtype == dtypes.list_type(dtypes.STRING_DTYPE) + + def test_ai_classify_multi_model(session, bq_connection): df = session.from_glob_path( "gs://bigframes-dev-testing/a_multimodal/images/*", diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_multi_with_list_examples/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_multi_with_list_examples/out.sql new file mode 100644 index 000000000000..a4a7f783da97 --- /dev/null +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_multi_with_list_examples/out.sql @@ -0,0 +1,8 @@ +SELECT + AI.CLASSIFY( + input => (`string_col`), + categories => ['greeting', 'rejection'], + examples => [('hi', ['greeting', 'positive']), ('bye', ['rejection', 'negative'])], + output_mode => 'multi' + ) AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_with_output_mode/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_with_output_mode/out.sql new file mode 100644 index 000000000000..fb3c6af8b0b0 --- /dev/null +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_with_output_mode/out.sql @@ -0,0 +1,7 @@ +SELECT + AI.CLASSIFY( + input => (`string_col`), + categories => ['greeting', 'rejection'], + output_mode => 'multi' + ) AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py index d2797cd3eda1..57c524908607 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py @@ -363,6 +363,41 @@ def test_ai_classify_with_params(scalar_types_df: dataframe.DataFrame, snapshot) snapshot.assert_match(sql, "out.sql") +def test_ai_classify_with_output_mode(scalar_types_df: dataframe.DataFrame, snapshot): + col_name = "string_col" + + op = ops.AIClassify( + prompt_context=(None,), + categories=("greeting", "rejection"), + output_mode="multi", + ) + + sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"]) + + snapshot.assert_match(sql, "out.sql") + + +def test_ai_classify_multi_with_list_examples( + scalar_types_df: dataframe.DataFrame, snapshot +): + col_name = "string_col" + + examples = ( + ("hi", ("greeting", "positive")), + ("bye", ("rejection", "negative")), + ) + op = ops.AIClassify( + prompt_context=(None,), + categories=("greeting", "rejection"), + examples=examples, + output_mode="multi", + ) + + sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"]) + + snapshot.assert_match(sql, "out.sql") + + @pytest.mark.parametrize("connection_id", [None, CONNECTION_ID]) def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot, connection_id): col_name = "string_col" diff --git a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py index 7a6a31c4b72a..fcd97c6f61b2 100644 --- a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py +++ b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py @@ -158,6 +158,7 @@ class AIClassify(Value): examples: Optional[Value] connection_id: Optional[Value[dt.String]] endpoint: Optional[Value[dt.String]] + output_mode: Optional[Value[dt.String]] optimization_mode: Optional[Value[dt.String]] max_error_ratio: Optional[Value[dt.Float64]] @@ -165,6 +166,8 @@ class AIClassify(Value): @attribute def dtype(self) -> dt.DataType: + if self.output_mode is not None: + return dt.Array(dt.string) return dt.string From a390803937c8c028a3afc8b39399765bf37f310b Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Wed, 13 May 2026 21:46:47 +0000 Subject: [PATCH 2/3] fix lint --- packages/bigframes/tests/system/small/bigquery/test_ai.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/bigframes/tests/system/small/bigquery/test_ai.py b/packages/bigframes/tests/system/small/bigquery/test_ai.py index fdfcefb5f4cf..9dc0461a2448 100644 --- a/packages/bigframes/tests/system/small/bigquery/test_ai.py +++ b/packages/bigframes/tests/system/small/bigquery/test_ai.py @@ -339,7 +339,9 @@ def test_ai_classify_with_examples(session): def test_ai_classify_output_mode(session, bq_connection): s = bpd.Series(["cat", "orchid"], session=session) - result = bbq.ai.classify(s, ["animal", "plant"], output_mode="multi", examples=[("dog", ["animal"])]) + result = bbq.ai.classify( + s, ["animal", "plant"], output_mode="multi", examples=[("dog", ["animal"])] + ) assert len(result) == len(s) assert result.dtype == dtypes.list_type(dtypes.STRING_DTYPE) From f858cd45c5977f36c03c6c9e0f2f94afb5f09ce0 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Wed, 13 May 2026 22:07:54 +0000 Subject: [PATCH 3/3] fix mypy --- packages/bigframes/bigframes/bigquery/_operations/ai.py | 2 +- .../core/compile/ibis_compiler/scalar_op_registry.py | 4 ++-- .../bigframes/core/compile/sqlglot/expressions/ai_ops.py | 4 +++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/packages/bigframes/bigframes/bigquery/_operations/ai.py b/packages/bigframes/bigframes/bigquery/_operations/ai.py index 4eb24f25206f..8382cb613bbb 100644 --- a/packages/bigframes/bigframes/bigquery/_operations/ai.py +++ b/packages/bigframes/bigframes/bigquery/_operations/ai.py @@ -904,7 +904,7 @@ def classify( assert len(series_list) > 0 if examples is not None: - example_tuples = tuple( + example_tuples: Any = tuple( (ex[0], tuple(ex[1]) if isinstance(ex[1], (list, tuple)) else ex[1]) for ex in examples ) diff --git a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index 64c4e75a85e5..926f220370b0 100644 --- a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -16,7 +16,7 @@ import functools import typing -from typing import cast +from typing import cast, Any import bigframes_vendored.ibis.expr.api as ibis_api import bigframes_vendored.ibis.expr.datatypes as ibis_dtypes @@ -2054,7 +2054,7 @@ def _construct_examples( results: list[ibis_types.StructValue] = [] for example in examples: - value = example[1] + value: Any = example[1] if isinstance(example[1], (list, tuple)): value = list(example[1]) diff --git a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py index 7863d61efc10..12a6b9859a2a 100644 --- a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -143,7 +143,9 @@ def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]: example_expressions = [] for key, val in value: if isinstance(val, (list, tuple)): - val_expr = sge.array(*[sge.Literal.string(v) for v in val]) + val_expr: sge.Array | sge.Literal = sge.array( + *[sge.Literal.string(v) for v in val] + ) else: val_expr = sge.Literal.string(val) example_expressions.append(