Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions packages/bigframes/bigframes/bigquery/_operations/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,9 +842,12 @@ def classify(
input: PROMPT_TYPE,
categories: tuple[str, ...] | list[str],
*,
examples: list[tuple[str, str]] | None = None,
examples: list[tuple[str, str]]
| list[tuple[str, list[str] | tuple[str, ...]]]
| None = None,
connection_id: str | None = None,
endpoint: str | None = None,
output_mode: Literal["single", "multi"] | None = None,
optimization_mode: Literal["minimize_cost", "maximize_quality"] | None = None,
max_error_ratio: float | None = None,
) -> series.Series:
Expand All @@ -870,17 +873,21 @@ def classify(
or pandas Series.
categories (tuple[str, ...] | list[str]):
Categories to classify the input into.
examples (list[tuple[str, str]], optional):
examples (list[tuple[str, str]] | list[tuple[str, list[str] | tuple[str, ...]]], optional):
An array that contains representative examples of input strings and the output category
that you expect. You can provide examples to help the model understand your
intended threshold for a condition with nuanced or subjective logic. We recommend providing at most 5 examples.
that you expect. If ``output_mode`` is ``multi``, each example output must be a list or tuple of strings.
You can provide examples to help the model understand your intended threshold for a condition with nuanced
or subjective logic. We recommend providing at most 5 examples.
connection_id (str, optional):
Specifies the connection to use to communicate with the model. For example, ``myproject.us.myconnection``.
If not provided, the query uses your end-user credential.
endpoint (str, optional):
A STRING value that specifies the Vertex AI endpoint to use for the model. You can specify any
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically
identifies and uses the full endpoint of the model.
output_mode (Literal["single", "multi"], optional):
A STRING value that indicates whether a single input can be classified into multiple categories.
Supported values are ``single`` and ``multi``.
optimization_mode (Literal["minimize_cost", "maximize_quality"], optional):
A STRING value that specifies the optimization strategy to use. Supported values are ``minimize_cost``
and ``maximize_quality``.
Expand All @@ -890,20 +897,27 @@ def classify(
This argument isn't supported when ``optimization_mode`` is set to ``minimize_cost``.

Returns:
bigframes.series.Series: A new series of strings.
bigframes.series.Series: A new series of strings (or a series of arrays of strings if ``output_mode`` is specified).
Comment thread
sycai marked this conversation as resolved.
"""

prompt_context, series_list = _separate_context_and_series(input)
assert len(series_list) > 0

example_tuples = tuple(examples) if examples is not None else None
if examples is not None:
example_tuples: Any = tuple(
(ex[0], tuple(ex[1]) if isinstance(ex[1], (list, tuple)) else ex[1])
for ex in examples
)
else:
example_tuples = None

operator = ai_ops.AIClassify(
prompt_context=tuple(prompt_context),
categories=tuple(categories),
examples=example_tuples,
connection_id=connection_id,
endpoint=endpoint,
output_mode=output_mode,
optimization_mode=_upper_optional(optimization_mode),
max_error_ratio=max_error_ratio,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import functools
import typing
from typing import cast
from typing import cast, Any

import bigframes_vendored.ibis.expr.api as ibis_api
import bigframes_vendored.ibis.expr.datatypes as ibis_dtypes
Expand Down Expand Up @@ -1999,6 +1999,7 @@ def ai_classify(
_construct_examples(op.examples), # type: ignore
op.connection_id, # type: ignore
op.endpoint, # type: ignore
op.output_mode, # type: ignore
op.optimization_mode, # type: ignore
op.max_error_ratio, # type: ignore
).to_expr()
Expand Down Expand Up @@ -2045,20 +2046,19 @@ def _construct_prompt(


def _construct_examples(
examples: tuple[tuple[str, str]] | None,
examples: tuple[tuple[str, str | tuple[str, ...]], ...] | None,
) -> ibis_types.ArrayValue | None:
if examples is None:
return None

results: list[ibis_types.StructValue] = []

for example in examples:
ibis_example = ibis.struct(
{
"_field_1": example[0],
"_field_2": example[1],
}
)
value: Any = example[1]
if isinstance(example[1], (list, tuple)):
value = list(example[1])

ibis_example = ibis.struct({"_field_1": example[0], "_field_2": value})
results.append(ibis_example)

return ibis.array(results)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,17 @@ def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]:
)
)
elif field == "examples":
example_expressions = [
sge.Tuple(
expressions=[sge.Literal.string(key), sge.Literal.string(val)]
example_expressions = []
for key, val in value:
if isinstance(val, (list, tuple)):
val_expr: sge.Array | sge.Literal = sge.array(
*[sge.Literal.string(v) for v in val]
)
else:
val_expr = sge.Literal.string(val)
example_expressions.append(
sge.Tuple(expressions=[sge.Literal.string(key), val_expr])
)
for key, val in value
]
args.append(
sge.Kwarg(this=field, expression=sge.array(*example_expressions))
)
Expand Down
7 changes: 6 additions & 1 deletion packages/bigframes/bigframes/operations/ai_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,18 @@ class AIClassify(base_ops.NaryOp):

prompt_context: Tuple[str | None, ...]
categories: tuple[str, ...]
examples: tuple[tuple[str, str], ...] | None = None
examples: (
tuple[tuple[str, str], ...] | tuple[tuple[str, tuple[str, ...]], ...] | None
) = None
connection_id: str | None = None
endpoint: str | None = None
output_mode: str | None = None
optimization_mode: str | None = None
max_error_ratio: float | None = None

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
    """Return the result dtype of this op.

    A list-of-strings dtype is returned whenever ``output_mode`` is set
    (to any non-``None`` value); otherwise the plain string dtype is
    returned. ``input_types`` is accepted but not consulted.
    """
    if self.output_mode is not None:
        return dtypes.list_type(dtypes.STRING_DTYPE)
    return dtypes.STRING_DTYPE


Expand Down
11 changes: 11 additions & 0 deletions packages/bigframes/tests/system/small/bigquery/test_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,17 @@ def test_ai_classify_with_examples(session):
assert result.dtype == dtypes.STRING_DTYPE


def test_ai_classify_output_mode(session, bq_connection):
    """Classifying with ``output_mode="multi"`` and list-valued examples
    yields one list-of-strings result per input row."""
    inputs = bpd.Series(["cat", "orchid"], session=session)
    labels = ["animal", "plant"]
    few_shot = [("dog", ["animal"])]

    result = bbq.ai.classify(
        inputs,
        labels,
        output_mode="multi",
        examples=few_shot,
    )

    expected_dtype = dtypes.list_type(dtypes.STRING_DTYPE)
    assert len(result) == len(inputs)
    assert result.dtype == expected_dtype


def test_ai_classify_multi_model(session, bq_connection):
df = session.from_glob_path(
"gs://bigframes-dev-testing/a_multimodal/images/*",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
SELECT
AI.CLASSIFY(
input => (`string_col`),
categories => ['greeting', 'rejection'],
examples => [('hi', ['greeting', 'positive']), ('bye', ['rejection', 'negative'])],
output_mode => 'multi'
) AS `result`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SELECT
AI.CLASSIFY(
input => (`string_col`),
categories => ['greeting', 'rejection'],
output_mode => 'multi'
) AS `result`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,41 @@ def test_ai_classify_with_params(scalar_types_df: dataframe.DataFrame, snapshot)
snapshot.assert_match(sql, "out.sql")


def test_ai_classify_with_output_mode(scalar_types_df: dataframe.DataFrame, snapshot):
    """Snapshot the SQL generated for AIClassify when only ``output_mode`` is set."""
    classify_op = ops.AIClassify(
        prompt_context=(None,),
        categories=("greeting", "rejection"),
        output_mode="multi",
    )
    expression = classify_op.as_expr("string_col")

    sql = utils._apply_ops_to_sql(scalar_types_df, [expression], ["result"])

    snapshot.assert_match(sql, "out.sql")


def test_ai_classify_multi_with_list_examples(
    scalar_types_df: dataframe.DataFrame, snapshot
):
    """Snapshot the SQL generated for AIClassify with multi-label examples."""
    classify_op = ops.AIClassify(
        prompt_context=(None,),
        categories=("greeting", "rejection"),
        examples=(
            ("hi", ("greeting", "positive")),
            ("bye", ("rejection", "negative")),
        ),
        output_mode="multi",
    )
    expression = classify_op.as_expr("string_col")

    sql = utils._apply_ops_to_sql(scalar_types_df, [expression], ["result"])

    snapshot.assert_match(sql, "out.sql")


@pytest.mark.parametrize("connection_id", [None, CONNECTION_ID])
def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot, connection_id):
col_name = "string_col"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,16 @@ class AIClassify(Value):
examples: Optional[Value]
connection_id: Optional[Value[dt.String]]
endpoint: Optional[Value[dt.String]]
output_mode: Optional[Value[dt.String]]
optimization_mode: Optional[Value[dt.String]]
max_error_ratio: Optional[Value[dt.Float64]]

shape = rlz.shape_like("input")

@attribute
def dtype(self) -> dt.DataType:
    """Return the result datatype of the node.

    An array-of-strings type is returned whenever ``output_mode`` is set
    (not ``None``); otherwise the plain string type is returned.
    """
    if self.output_mode is not None:
        return dt.Array(dt.string)
    return dt.string


Expand Down
Loading