2525
2626from bigframes import dataframe , dtypes , series , session
2727from bigframes import pandas as bpd
28+ from bigframes .bigquery ._operations import obj as bq_obj
2829from bigframes .bigquery ._operations import utils as bq_utils
2930from bigframes .core import convert
3031from bigframes .core .compile .sqlglot import sql as sg_sql
@@ -842,9 +843,12 @@ def classify(
842843 input : PROMPT_TYPE ,
843844 categories : tuple [str , ...] | list [str ],
844845 * ,
845- examples : list [tuple [str , str ]] | None = None ,
846+ examples : list [tuple [str , str ]]
847+ | list [tuple [str , list [str ] | tuple [str , ...]]]
848+ | None = None ,
846849 connection_id : str | None = None ,
847850 endpoint : str | None = None ,
851+ output_mode : Literal ["single" , "multi" ] | None = None ,
848852 optimization_mode : Literal ["minimize_cost" , "maximize_quality" ] | None = None ,
849853 max_error_ratio : float | None = None ,
850854) -> series .Series :
@@ -870,17 +874,21 @@ def classify(
870874 or pandas Series.
871875 categories (tuple[str, ...] | list[str]):
872876 Categories to classify the input into.
873- examples (list[tuple[str, str]], optional):
877+ examples (list[tuple[str, str]] | list[tuple[str, list[str] | tuple[str, ...]]] , optional):
874878 An array that contains representative examples of input strings and the output category
875- that you expect. You can provide examples to help the model understand your
876- intended threshold for a condition with nuanced or subjective logic. We recommend providing at most 5 examples.
879+ that you expect. If ``output_mode`` is ``multi``, each example output must be a list or tuple of strings.
880+ You can provide examples to help the model understand your intended threshold for a condition with nuanced
881+ or subjective logic. We recommend providing at most 5 examples.
877882 connection_id (str, optional):
878883 Specifies the connection to use to communicate with the model. For example, ``myproject.us.myconnection``.
879884 If not provided, the query uses your end-user credential.
880885 endpoint (str, optional):
881886 A STRING value that specifies the Vertex AI endpoint to use for the model. You can specify any
882887 generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically
883888 identifies and uses the full endpoint of the model.
889+ output_mode (Literal["single", "multi"], optional):
890+ A STRING value that indicates whether a single input can be classified into multiple categories.
891+ Supported values are ``single`` and ``multi``.
884892 optimization_mode (Literal["minimize_cost", "maximize_quality"], optional):
885893 A STRING value that specifies the optimization strategy to use. Supported values are ``minimize_cost``
886894 and ``maximize_quality``.
@@ -890,20 +898,27 @@ def classify(
890898 This argument isn't supported when ``optimization_mode`` is set to ``minimize_cost``.
891899
892900 Returns:
893- bigframes.series.Series: A new series of strings.
901+ bigframes.series.Series: A new series of strings (or a series of arrays of strings if ``output_mode`` is specified) .
894902 """
895903
896904 prompt_context , series_list = _separate_context_and_series (input )
897905 assert len (series_list ) > 0
898906
899- example_tuples = tuple (examples ) if examples is not None else None
907+ if examples is not None :
908+ example_tuples : Any = tuple (
909+ (ex [0 ], tuple (ex [1 ]) if isinstance (ex [1 ], (list , tuple )) else ex [1 ])
910+ for ex in examples
911+ )
912+ else :
913+ example_tuples = None
900914
901915 operator = ai_ops .AIClassify (
902916 prompt_context = tuple (prompt_context ),
903917 categories = tuple (categories ),
904918 examples = example_tuples ,
905919 connection_id = connection_id ,
906920 endpoint = endpoint ,
921+ output_mode = output_mode ,
907922 optimization_mode = _upper_optional (optimization_mode ),
908923 max_error_ratio = max_error_ratio ,
909924 )
@@ -1169,7 +1184,7 @@ def _separate_context_and_series(
11691184 if isinstance (prompt , series .Series ):
11701185 if prompt .dtype == dtypes .OBJ_REF_DTYPE :
11711186 # Multi-model support
1172- return [None ], [prompt . blob . read_url ( )]
1187+ return [None ], [bq_obj . get_access_url ( prompt , mode = "R" )]
11731188 return [None ], [prompt ]
11741189
11751190 prompt_context : List [str | None ] = []
@@ -1206,7 +1221,7 @@ def _convert_series(
12061221
12071222 if result .dtype == dtypes .OBJ_REF_DTYPE :
12081223 # Support multimodal
1209- return result . blob . read_url ( )
1224+ return bq_obj . get_access_url ( result , mode = "R" )
12101225 return result
12111226
12121227
0 commit comments