Skip to content

Commit a5dbf82

Browse files
committed
Merge branch 'main' into shuowei-anywidget-extraneous-output
2 parents 402b88f + f93911c commit a5dbf82

50 files changed

Lines changed: 746 additions & 3451 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

packages/bigframes/bigframes/bigquery/_operations/ai.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from bigframes import dataframe, dtypes, series, session
2727
from bigframes import pandas as bpd
28+
from bigframes.bigquery._operations import obj as bq_obj
2829
from bigframes.bigquery._operations import utils as bq_utils
2930
from bigframes.core import convert
3031
from bigframes.core.compile.sqlglot import sql as sg_sql
@@ -842,9 +843,12 @@ def classify(
842843
input: PROMPT_TYPE,
843844
categories: tuple[str, ...] | list[str],
844845
*,
845-
examples: list[tuple[str, str]] | None = None,
846+
examples: list[tuple[str, str]]
847+
| list[tuple[str, list[str] | tuple[str, ...]]]
848+
| None = None,
846849
connection_id: str | None = None,
847850
endpoint: str | None = None,
851+
output_mode: Literal["single", "multi"] | None = None,
848852
optimization_mode: Literal["minimize_cost", "maximize_quality"] | None = None,
849853
max_error_ratio: float | None = None,
850854
) -> series.Series:
@@ -870,17 +874,21 @@ def classify(
870874
or pandas Series.
871875
categories (tuple[str, ...] | list[str]):
872876
Categories to classify the input into.
873-
examples (list[tuple[str, str]], optional):
877+
examples (list[tuple[str, str]] | list[tuple[str, list[str] | tuple[str, ...]]], optional):
874878
An array that contains representative examples of input strings and the output category
875-
that you expect. You can provide examples to help the model understand your
876-
intended threshold for a condition with nuanced or subjective logic. We recommend providing at most 5 examples.
879+
that you expect. If ``output_mode`` is ``multi``, each example output must be a list or tuple of strings.
880+
You can provide examples to help the model understand your intended threshold for a condition with nuanced
881+
or subjective logic. We recommend providing at most 5 examples.
877882
connection_id (str, optional):
878883
Specifies the connection to use to communicate with the model. For example, ``myproject.us.myconnection``.
879884
If not provided, the query uses your end-user credential.
880885
endpoint (str, optional):
881886
A STRING value that specifies the Vertex AI endpoint to use for the model. You can specify any
882887
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically
883888
identifies and uses the full endpoint of the model.
889+
output_mode (Literal["single", "multi"], optional):
890+
A STRING value that indicates whether a single input can be classified into multiple categories.
891+
Supported values are ``single`` and ``multi``.
884892
optimization_mode (Literal["minimize_cost", "maximize_quality"], optional):
885893
A STRING value that specifies the optimization strategy to use. Supported values are ``minimize_cost``
886894
and ``maximize_quality``.
@@ -890,20 +898,27 @@ def classify(
890898
This argument isn't supported when ``optimization_mode`` is set to ``minimize_cost``.
891899
892900
Returns:
893-
bigframes.series.Series: A new series of strings.
901+
bigframes.series.Series: A new series of strings (or a series of arrays of strings if ``output_mode`` is specified).
894902
"""
895903

896904
prompt_context, series_list = _separate_context_and_series(input)
897905
assert len(series_list) > 0
898906

899-
example_tuples = tuple(examples) if examples is not None else None
907+
if examples is not None:
908+
example_tuples: Any = tuple(
909+
(ex[0], tuple(ex[1]) if isinstance(ex[1], (list, tuple)) else ex[1])
910+
for ex in examples
911+
)
912+
else:
913+
example_tuples = None
900914

901915
operator = ai_ops.AIClassify(
902916
prompt_context=tuple(prompt_context),
903917
categories=tuple(categories),
904918
examples=example_tuples,
905919
connection_id=connection_id,
906920
endpoint=endpoint,
921+
output_mode=output_mode,
907922
optimization_mode=_upper_optional(optimization_mode),
908923
max_error_ratio=max_error_ratio,
909924
)
@@ -1169,7 +1184,7 @@ def _separate_context_and_series(
11691184
if isinstance(prompt, series.Series):
11701185
if prompt.dtype == dtypes.OBJ_REF_DTYPE:
11711186
# Multi-model support
1172-
return [None], [prompt.blob.read_url()]
1187+
return [None], [bq_obj.get_access_url(prompt, mode="R")]
11731188
return [None], [prompt]
11741189

11751190
prompt_context: List[str | None] = []
@@ -1206,7 +1221,7 @@ def _convert_series(
12061221

12071222
if result.dtype == dtypes.OBJ_REF_DTYPE:
12081223
# Support multimodal
1209-
return result.blob.read_url()
1224+
return bq_obj.get_access_url(result, mode="R")
12101225
return result
12111226

12121227

0 commit comments

Comments
 (0)