diff --git a/packages/bigframes/bigframes/bigquery/_operations/ai.py b/packages/bigframes/bigframes/bigquery/_operations/ai.py index 907d2e462295..82a9be19b7de 100644 --- a/packages/bigframes/bigframes/bigquery/_operations/ai.py +++ b/packages/bigframes/bigframes/bigquery/_operations/ai.py @@ -32,7 +32,7 @@ from bigframes.core.logging import log_adapter from bigframes.ml import base as ml_base from bigframes.ml import core as ml_core -from bigframes.operations import ai_ops, output_schemas +from bigframes.operations import ai_ops, googlesql, output_schemas PROMPT_TYPE = Union[ str, @@ -114,9 +114,6 @@ def generate( * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful. """ - prompt_context, series_list = _separate_context_and_series(prompt) - assert len(series_list) > 0 - if output_schema is None: output_schema_str = None else: @@ -126,17 +123,21 @@ def generate( # Validate user input output_schemas.parse_sql_fields(output_schema_str) - operator = ai_ops.AIGenerate( - prompt_context=tuple(prompt_context), - connection_id=connection_id, - endpoint=endpoint, - request_type=_upper_optional(request_type), - model_params=json.dumps(model_params) if model_params else None, - output_schema=output_schema_str, + prompt_struct = _construct_prompt_struct(prompt) + + op = googlesql.AIGenerateOp(output_schema=output_schema_str) + return googlesql.apply_op( + op, + args=(prompt_struct,), + kwargs={ + "connection_id": connection_id, + "endpoint": endpoint, + "request_type": _upper_optional(request_type), + "model_params": json.dumps(model_params) if model_params else None, + "output_schema": output_schema_str, + }, ) - return series_list[0]._apply_nary_op(operator, series_list[1:]) - @log_adapter.method_logger(custom_base_name="bigquery_ai") def generate_bool( @@ -201,19 +202,19 @@ def generate_bool( * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful. """ - prompt_context, series_list = _separate_context_and_series(prompt) - assert len(series_list) > 0 - - operator = ai_ops.AIGenerateBool( - prompt_context=tuple(prompt_context), - connection_id=connection_id, - endpoint=endpoint, - request_type=_upper_optional(request_type), - model_params=json.dumps(model_params) if model_params else None, + prompt_struct = _construct_prompt_struct(prompt) + + return googlesql.apply_op( + googlesql.AI_GENERATE_BOOL, + args=(prompt_struct,), + kwargs={ + "connection_id": connection_id, + "endpoint": endpoint, + "request_type": _upper_optional(request_type), + "model_params": json.dumps(model_params) if model_params else None, + }, ) - return series_list[0]._apply_nary_op(operator, series_list[1:]) - @log_adapter.method_logger(custom_base_name="bigquery_ai") def generate_int( @@ -275,19 +276,19 @@ def generate_int( * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful. """ - prompt_context, series_list = _separate_context_and_series(prompt) - assert len(series_list) > 0 - - operator = ai_ops.AIGenerateInt( - prompt_context=tuple(prompt_context), - connection_id=connection_id, - endpoint=endpoint, - request_type=_upper_optional(request_type), - model_params=json.dumps(model_params) if model_params else None, + prompt_struct = _construct_prompt_struct(prompt) + + return googlesql.apply_op( + googlesql.AI_GENERATE_INT, + args=(prompt_struct,), + kwargs={ + "connection_id": connection_id, + "endpoint": endpoint, + "request_type": _upper_optional(request_type), + "model_params": json.dumps(model_params) if model_params else None, + }, ) - return series_list[0]._apply_nary_op(operator, series_list[1:]) - @log_adapter.method_logger(custom_base_name="bigquery_ai") def generate_double( @@ -349,19 +350,19 @@ def generate_double( * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful. """ - prompt_context, series_list = _separate_context_and_series(prompt) - assert len(series_list) > 0 - - operator = ai_ops.AIGenerateDouble( - prompt_context=tuple(prompt_context), - connection_id=connection_id, - endpoint=endpoint, - request_type=_upper_optional(request_type), - model_params=json.dumps(model_params) if model_params else None, + prompt_struct = _construct_prompt_struct(prompt) + + return googlesql.apply_op( + googlesql.AI_GENERATE_DOUBLE, + args=(prompt_struct,), + kwargs={ + "connection_id": connection_id, + "endpoint": endpoint, + "request_type": _upper_optional(request_type), + "model_params": json.dumps(model_params) if model_params else None, + }, ) - return series_list[0]._apply_nary_op(operator, series_list[1:]) - @log_adapter.method_logger(custom_base_name="bigquery_ai") def generate_embedding( @@ -751,24 +752,19 @@ def embed( * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful. """ - operator = ai_ops.AIEmbed( - endpoint=endpoint, - model=model, - task_type=_upper_optional(task_type), - title=title, - model_params=json.dumps(model_params) if model_params else None, - connection_id=connection_id, + return googlesql.apply_op( + googlesql.AI_EMBED, + args=(content,), + kwargs={ + "endpoint": endpoint, + "model": model, + "task_type": _upper_optional(task_type), + "title": title, + "model_params": json.dumps(model_params) if model_params else None, + "connection_id": connection_id, + }, ) - if isinstance(content, str): - return series.Series([content])._apply_unary_op(operator) - elif isinstance(content, pd.Series): - return series.Series(content)._apply_unary_op(operator) - elif isinstance(content, series.Series): - return content._apply_unary_op(operator) - else: - raise ValueError(f"Unsupported 'content' parameter type: {type(content)}") - @log_adapter.method_logger(custom_base_name="bigquery_ai") def if_( @@ -824,19 +820,19 @@ def if_( bigframes.series.Series: A new series of bools. """ - prompt_context, series_list = _separate_context_and_series(prompt) - assert len(series_list) > 0 - - operator = ai_ops.AIIf( - prompt_context=tuple(prompt_context), - connection_id=connection_id, - endpoint=endpoint, - optimization_mode=_upper_optional(optimization_mode), - max_error_ratio=max_error_ratio, + prompt_struct = _construct_prompt_struct(prompt) + + return googlesql.apply_op( + googlesql.AI_IF, + args=(prompt_struct,), + kwargs={ + "connection_id": connection_id, + "endpoint": endpoint, + "optimization_mode": _upper_optional(optimization_mode), + "max_error_ratio": max_error_ratio, + }, ) - return series_list[0]._apply_nary_op(operator, series_list[1:]) - @log_adapter.method_logger(custom_base_name="bigquery_ai") def classify( @@ -901,30 +897,30 @@ def classify( bigframes.series.Series: A new series of strings (or a series of arrays of strings if ``output_mode`` is specified). """ - prompt_context, series_list = _separate_context_and_series(input) - assert len(series_list) > 0 - if examples is not None: - example_tuples: Any = tuple( - (ex[0], tuple(ex[1]) if isinstance(ex[1], (list, tuple)) else ex[1]) + formatted_examples = [ + { + "input": ex[0], + "output": list(ex[1]) if isinstance(ex[1], (list, tuple)) else ex[1], + } for ex in examples - ) + ] else: - example_tuples = None - - operator = ai_ops.AIClassify( - prompt_context=tuple(prompt_context), - categories=tuple(categories), - examples=example_tuples, - connection_id=connection_id, - endpoint=endpoint, - output_mode=output_mode, - optimization_mode=_upper_optional(optimization_mode), - max_error_ratio=max_error_ratio, + formatted_examples = None + + return googlesql.apply_op( + googlesql.AI_CLASSIFY, + args=(input, categories), + kwargs={ + "examples": formatted_examples, + "connection_id": connection_id, + "endpoint": endpoint, + "output_mode": output_mode, + "optimization_mode": _upper_optional(optimization_mode), + "max_error_ratio": max_error_ratio, + }, ) - return series_list[0]._apply_nary_op(operator, series_list[1:]) - @log_adapter.method_logger(custom_base_name="bigquery_ai") def score( @@ -970,18 +966,18 @@ def score( bigframes.series.Series: A new series of double (float) values. """ - prompt_context, series_list = _separate_context_and_series(prompt) - assert len(series_list) > 0 + prompt_struct = _construct_prompt_struct(prompt) - operator = ai_ops.AIScore( - prompt_context=tuple(prompt_context), - connection_id=connection_id, - endpoint=endpoint, - max_error_ratio=max_error_ratio, + return googlesql.apply_op( + googlesql.AI_SCORE, + args=(prompt_struct,), + kwargs={ + "connection_id": connection_id, + "endpoint": endpoint, + "max_error_ratio": max_error_ratio, + }, ) - return series_list[0]._apply_nary_op(operator, series_list[1:]) - @log_adapter.method_logger(custom_base_name="bigquery_ai") def similarity( @@ -1026,36 +1022,18 @@ def similarity( bigframes.series.Series: A new series of FLOAT64 values representing the cosine similarity. """ - operator = ai_ops.AISimilarity( - endpoint=endpoint, - model=model, - model_params=json.dumps(model_params) if model_params else None, - connection_id=connection_id, + return googlesql.apply_op( + googlesql.AI_SIMILARITY, + kwargs={ + "content1": content1, + "content2": content2, + "endpoint": endpoint, + "model": model, + "model_params": json.dumps(model_params) if model_params else None, + "connection_id": connection_id, + }, ) - # Find a unifying session for the subsequent operations. - bf_session = None - if isinstance(content1, series.Series): - bf_session = content1._session - elif isinstance(content2, series.Series): - bf_session = content2._session - - if isinstance(content1, str) and isinstance(content2, str): - content1 = series.Series([content1], session=bf_session) - return content1._apply_binary_op(content2, operator) - elif isinstance(content1, str): - # content2 must be a series - content2 = convert.to_bf_series( - content2, default_index=None, session=bf_session - ) - return content2._apply_binary_op(content1, operator) - else: - # content1 must be a series. - content1 = convert.to_bf_series( - content1, default_index=None, session=bf_session - ) - return content1._apply_binary_op(content2, operator) - @log_adapter.method_logger(custom_base_name="bigquery_ai") def forecast( @@ -1246,3 +1224,24 @@ def _upper_optional(value: str | None) -> str | None: if value is None: return None return value.upper() + + +def _construct_prompt_struct(prompt: PROMPT_TYPE) -> series.Series: + prompt_context, series_list = _separate_context_and_series(prompt) + assert len(series_list) > 0 + + prompt_elements = [] + series_idx = 0 + for elem in prompt_context: + if elem is None: + prompt_elements.append(series_list[series_idx]) + series_idx += 1 + else: + prompt_elements.append(elem) + + import bigframes.operations.generic_ops as generic_ops + struct_names = tuple(f"_field_{i+1}" for i in range(len(prompt_elements))) + return googlesql.apply_op( + generic_ops.StructOp(column_names=struct_names), + args=prompt_elements, + ) diff --git a/packages/bigframes/bigframes/core/align.py b/packages/bigframes/bigframes/core/align.py new file mode 100644 index 000000000000..815f29282bc2 --- /dev/null +++ b/packages/bigframes/bigframes/core/align.py @@ -0,0 +1,167 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import typing +from typing import Any, Dict, Sequence, Tuple, Union + +import bigframes.core.expression as ex +import bigframes.core.identifiers as ids +from bigframes.core.blocks import Block + + +def align_n( + objects: Sequence[Any], + how: typing.Literal["inner", "left", "outer", "right", "cross"] = "outer", +) -> Tuple[Sequence[Any], Block]: + """Aligns a list of mixed Series-like, Index-like, and scalar values. + + Returns the list of mapped Expression objects corresponding to the aligned columns/constants, + along with the joined Block representing the aligned data. + """ + import bigframes.core.indexes as bf_indexes + import bigframes.series as bf_series + + # Find the first Series or Index object to serve as reference + ref_obj = None + for obj in objects: + if isinstance(obj, (bf_series.Series, bf_indexes.Index)): + ref_obj = obj + break + + if ref_obj is None: + raise ValueError("At least one input must be a BigFrames Series or Index object.") + + block = ref_obj._block + series_to_expr = {id(ref_obj): ex.deref(ref_obj._value_column)} + + # Collect all other unique Series and Index objects + all_block_likes = [] + for obj in objects: + if isinstance(obj, (bf_series.Series, bf_indexes.Index)) and id(obj) not in series_to_expr: + all_block_likes.append(obj) + series_to_expr[id(obj)] = None + + # Join Series/Index objects one by one + for s in all_block_likes: + ( + block, + ( + get_column_left, + get_column_right, + ), + ) = block.join(s._block, how=how) + + # Remap existing expressions + rebindings = { + ids.ColumnId(old): ids.ColumnId(new) + for old, new in get_column_left.items() + } + for oid, expr in series_to_expr.items(): + if expr is not None: + series_to_expr[oid] = expr.remap_column_refs(rebindings) + + # Assign expression for newly joined Series/Index + new_col_id = get_column_right[s._value_column] + series_to_expr[id(s)] = ex.deref(new_col_id) + + # Build the final list of aligned expressions/scalars + final_exprs = [] + for obj in objects: + if isinstance(obj, (bf_series.Series, bf_indexes.Index)): + final_exprs.append(series_to_expr[id(obj)]) + else: + final_exprs.append(obj) + + return final_exprs, block + + +def apply_op( + op: Any, # Any ops.NaryOp type to prevent circular import + args: Sequence[Any] = (), + kwargs: Dict[str, Any] = {}, +) -> Any: + """Applies an operation to a mix of Series, Index, literal, and other values, with necessary alignment.""" + import bigframes.core.convert as bf_convert + import bigframes.core.indexes as bf_indexes + import bigframes.series as bf_series + + # Find a reference block-like object in the inputs + ref_obj = None + for arg in args: + if isinstance(arg, (bf_series.Series, bf_indexes.Index)): + ref_obj = arg + break + if ref_obj is None: + for val in kwargs.values(): + if isinstance(val, (bf_series.Series, bf_indexes.Index)): + ref_obj = val + break + + if ref_obj is None: + raise ValueError("At least one input must be a BigFrames Series or Index.") + + session = ref_obj._block.session + ref_index = ref_obj.index + + # Convert inputs that are list-like or pandas Series/Index to BigFrames Series + def convert_input(val): + if isinstance(val, (bf_series.Series, bf_indexes.Index)): + return val + elif bf_convert.can_convert_to_series(val): + return bf_convert.to_bf_series(val, ref_index, session) + else: + return val + + converted_args = [convert_input(arg) for arg in args] + converted_kwargs = {k: convert_input(v) for k, v in kwargs.items()} + + # Collect all inputs for alignment + alignment_inputs = [] + for arg in converted_args: + alignment_inputs.append(arg) + for val in converted_kwargs.values(): + alignment_inputs.append(val) + + # Perform core alignment + aligned_inputs, block = align_n(alignment_inputs, how="outer") + + # Map the aligned expressions back to args and kwargs, wrapping any remaining raw scalars as ex.const + final_args = [] + cursor = 0 + for arg in converted_args: + expr = aligned_inputs[cursor] + if not isinstance(expr, ex.Expression): + expr = ex.const(expr) + final_args.append(expr) + cursor += 1 + + final_kwargs = {} + for k, v in converted_kwargs.items(): + expr = aligned_inputs[cursor] + if not isinstance(expr, ex.Expression): + expr = ex.const(expr) + final_kwargs[k] = expr + cursor += 1 + + # Apply the operation and construct the result + expr = op.as_expr(*final_args, **final_kwargs) + block, result_id = block.project_expr(expr) + + # Depending on the type of the reference object, return Series or Index + if isinstance(ref_obj, bf_series.Series): + return bf_series.Series(block.select_column(result_id)) + else: + return bf_indexes.Index(block.select_column(result_id)) diff --git a/packages/bigframes/bigframes/core/blocks.py b/packages/bigframes/bigframes/core/blocks.py index b9a246fc0360..77b7b8fb1272 100644 --- a/packages/bigframes/bigframes/core/blocks.py +++ b/packages/bigframes/bigframes/core/blocks.py @@ -3553,3 +3553,4 @@ def _resolve_index_col( else: # Joining with value columns only. Existing indices will be discarded. return [] + diff --git a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index 926f220370b0..2feeef27a040 100644 --- a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -1912,156 +1912,10 @@ def struct_op_impl( return ibis_types.struct(data) -@scalar_op_compiler.register_nary_op(ops.AIGenerate, pass_op=True) -def ai_generate( - *values: ibis_types.Value, op: ops.AIGenerate -) -> ibis_types.StructValue: - return ai_ops.AIGenerate( - _construct_prompt(values, op.prompt_context), # type: ignore - op.connection_id, # type: ignore - op.endpoint, # type: ignore - op.request_type, # type: ignore - op.model_params, # type: ignore - op.output_schema, # type: ignore - ).to_expr() - - -@scalar_op_compiler.register_nary_op(ops.AIGenerateBool, pass_op=True) -def ai_generate_bool( - *values: ibis_types.Value, op: ops.AIGenerateBool -) -> ibis_types.StructValue: - return ai_ops.AIGenerateBool( - _construct_prompt(values, op.prompt_context), # type: ignore - op.connection_id, # type: ignore - op.endpoint, # type: ignore - op.request_type, # type: ignore - op.model_params, # type: ignore - ).to_expr() - - -@scalar_op_compiler.register_nary_op(ops.AIGenerateInt, pass_op=True) -def ai_generate_int( - *values: ibis_types.Value, op: ops.AIGenerateInt -) -> ibis_types.StructValue: - return ai_ops.AIGenerateInt( - _construct_prompt(values, op.prompt_context), # type: ignore - op.connection_id, # type: ignore - op.endpoint, # type: ignore - op.request_type, # type: ignore - op.model_params, # type: ignore - ).to_expr() - - -@scalar_op_compiler.register_nary_op(ops.AIGenerateDouble, pass_op=True) -def ai_generate_double( - *values: ibis_types.Value, op: ops.AIGenerateDouble -) -> ibis_types.StructValue: - return ai_ops.AIGenerateDouble( - _construct_prompt(values, op.prompt_context), # type: ignore - op.connection_id, # type: ignore - op.endpoint, # type: ignore - op.request_type, # type: ignore - op.model_params, # type: ignore - ).to_expr() - - -@scalar_op_compiler.register_unary_op(ops.AIEmbed, pass_op=True) -def ai_embed(value: ibis_types.Value, op: ops.AIEmbed) -> ibis_types.StructValue: - return ai_ops.AIEmbed( - value, # type: ignore - connection_id=op.connection_id, # type: ignore - endpoint=op.endpoint, # type: ignore - model=op.model, # type: ignore - task_type=op.task_type, # type: ignore - title=op.title, # type: ignore - model_params=op.model_params, # type: ignore - ).to_expr() - - -@scalar_op_compiler.register_nary_op(ops.AIIf, pass_op=True) -def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.StructValue: - return ai_ops.AIIf( - _construct_prompt(values, op.prompt_context), # type: ignore - op.connection_id, # type: ignore - op.endpoint, # type: ignore - op.optimization_mode, # type: ignore - op.max_error_ratio, # type: ignore - ).to_expr() - - -@scalar_op_compiler.register_nary_op(ops.AIClassify, pass_op=True) -def ai_classify( - *values: ibis_types.Value, op: ops.AIClassify -) -> ibis_types.StructValue: - return ai_ops.AIClassify( - _construct_prompt(values, op.prompt_context), # type: ignore - op.categories, # type: ignore - _construct_examples(op.examples), # type: ignore - op.connection_id, # type: ignore - op.endpoint, # type: ignore - op.output_mode, # type: ignore - op.optimization_mode, # type: ignore - op.max_error_ratio, # type: ignore - ).to_expr() - - -@scalar_op_compiler.register_nary_op(ops.AIScore, pass_op=True) -def ai_score(*values: ibis_types.Value, op: ops.AIScore) -> ibis_types.StructValue: - return ai_ops.AIScore( - _construct_prompt(values, op.prompt_context), # type: ignore - op.connection_id, # type: ignore - op.endpoint, # type: ignore - op.max_error_ratio, # type: ignore - ).to_expr() - - -@scalar_op_compiler.register_binary_op(ops.AISimilarity, pass_op=True) -def ai_similarity( - content1: ibis_types.Value, content2: ibis_types.Value, op: ops.AISimilarity -) -> ibis_types.Value: - return ai_ops.AISimilarity( - content1, # type: ignore - content2, # type: ignore - op.endpoint, # type: ignore - op.model, # type: ignore - op.model_params, # type: ignore - op.connection_id, # type: ignore - ).to_expr() - - -def _construct_prompt( - col_refs: tuple[ibis_types.Value], prompt_context: tuple[str | None] -) -> ibis_types.StructValue: - prompt: dict[str, ibis_types.Value | str] = {} - column_ref_idx = 0 - - for idx, elem in enumerate(prompt_context): - if elem is None: - prompt[f"_field_{idx + 1}"] = col_refs[column_ref_idx] - column_ref_idx += 1 - else: - prompt[f"_field_{idx + 1}"] = elem - - return ibis.struct(prompt) - - -def _construct_examples( - examples: tuple[tuple[str, str | tuple[str, ...]], ...] | None, -) -> ibis_types.ArrayValue | None: - if examples is None: - return None - results: list[ibis_types.StructValue] = [] - for example in examples: - value: Any = example[1] - if isinstance(example[1], (list, tuple)): - value = list(example[1]) - ibis_example = ibis.struct({"_field_1": example[0], "_field_2": value}) - results.append(ibis_example) - return ibis.array(results) @scalar_op_compiler.register_nary_op(ops.RowKey, pass_op=True) diff --git a/packages/bigframes/bigframes/core/compile/sqlglot/__init__.py b/packages/bigframes/bigframes/core/compile/sqlglot/__init__.py index fa515e4f15a2..4ebc56c48d2d 100644 --- a/packages/bigframes/bigframes/core/compile/sqlglot/__init__.py +++ b/packages/bigframes/bigframes/core/compile/sqlglot/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import annotations -import bigframes.core.compile.sqlglot.expressions.ai_ops # noqa: F401 + import bigframes.core.compile.sqlglot.expressions.array_ops # noqa: F401 import bigframes.core.compile.sqlglot.expressions.blob_ops # noqa: F401 import bigframes.core.compile.sqlglot.expressions.bool_ops # noqa: F401 diff --git a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py deleted file mode 100644 index 12a6b9859a2a..000000000000 --- a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from dataclasses import asdict -from typing import Any - -import bigframes_vendored.sqlglot.expressions as sge - -from bigframes import operations as ops -from bigframes.core.compile.sqlglot import expression_compiler -from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr - -register_nary_op = expression_compiler.expression_compiler.register_nary_op -register_binary_op = expression_compiler.expression_compiler.register_binary_op -register_unary_op = expression_compiler.expression_compiler.register_unary_op - - -@register_nary_op(ops.AIGenerate, pass_op=True) -def _(*exprs: TypedExpr, op: ops.AIGenerate) -> sge.Expression: - args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op) - - return sge.func("AI.GENERATE", *args) - - -@register_nary_op(ops.AIGenerateBool, pass_op=True) -def _(*exprs: TypedExpr, op: ops.AIGenerateBool) -> sge.Expression: - args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op) - - return sge.func("AI.GENERATE_BOOL", *args) - - -@register_nary_op(ops.AIGenerateInt, pass_op=True) -def _(*exprs: TypedExpr, op: ops.AIGenerateInt) -> sge.Expression: - args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op) - - return sge.func("AI.GENERATE_INT", *args) - - -@register_nary_op(ops.AIGenerateDouble, pass_op=True) -def _(*exprs: TypedExpr, op: ops.AIGenerateDouble) -> sge.Expression: - args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op) - - return sge.func("AI.GENERATE_DOUBLE", *args) - - -@register_unary_op(ops.AIEmbed, pass_op=True) -def _(expr: TypedExpr, op: ops.AIEmbed) -> sge.Expression: - args: list[Any] = [expr.expr] + _construct_named_args(op) - - return sge.func("AI.EMBED", *args) - - -@register_nary_op(ops.AIIf, pass_op=True) -def _(*exprs: TypedExpr, op: ops.AIIf) -> sge.Expression: - args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op) - - return sge.func("AI.IF", *args) - - -@register_nary_op(ops.AIClassify, pass_op=True) -def _(*exprs: TypedExpr, op: ops.AIClassify) -> sge.Expression: - args = [ - _construct_prompt(exprs, op.prompt_context, param_name="input"), - ] + _construct_named_args(op) - - return sge.func("AI.CLASSIFY", *args) - - -@register_nary_op(ops.AIScore, pass_op=True) -def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression: - args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op) - - return sge.func("AI.SCORE", *args) - - -@register_binary_op(ops.AISimilarity, pass_op=True) -def _(content1: TypedExpr, content2: TypedExpr, op: ops.AISimilarity) -> sge.Expression: - args = [ - sge.Kwarg(this="content1", expression=content1.expr), - sge.Kwarg(this="content2", expression=content2.expr), - ] + _construct_named_args(op) - - return sge.func("AI.SIMILARITY", *args) - - -def _construct_prompt( - exprs: tuple[TypedExpr, ...], - prompt_context: tuple[str | None, ...], - param_name: str = "prompt", -) -> sge.Kwarg: - prompt: list[str | sge.Expression] = [] - column_ref_idx = 0 - - for elem in prompt_context: - if elem is None: - prompt.append(exprs[column_ref_idx].expr) - column_ref_idx += 1 - else: - prompt.append(sge.Literal.string(elem)) - - return sge.Kwarg(this=param_name, expression=sge.Tuple(expressions=prompt)) - - -def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]: - args = [] - - op_args = asdict(op) - - for field, value in op_args.items(): - if value is None or field == "prompt_context": - continue - - if field == "categories": - category_literals = [sge.Literal.string(cat) for cat in value] - categories_arg = sge.Kwarg( - this="categories", expression=sge.array(*category_literals) - ) - args.append(categories_arg) - elif field == "model_params": - # model_params is a JSON string, so we need to use the JSON function to pass it as a named argument. - args.append( - sge.Kwarg( - this="model_params", - # sge.JSON requires the SQLGlot version to be at least 25.18.0 - # PARSE_JSON won't work as the function requires a JSON literal. - expression=sge.JSON(this=sge.Literal.string(value)), - ) - ) - elif field == "examples": - example_expressions = [] - for key, val in value: - if isinstance(val, (list, tuple)): - val_expr: sge.Array | sge.Literal = sge.array( - *[sge.Literal.string(v) for v in val] - ) - else: - val_expr = sge.Literal.string(val) - example_expressions.append( - sge.Tuple(expressions=[sge.Literal.string(key), val_expr]) - ) - args.append( - sge.Kwarg(this=field, expression=sge.array(*example_expressions)) - ) - else: - args.append(sge.Kwarg(this=field, expression=sge.convert(value))) - - return args diff --git a/packages/bigframes/bigframes/core/expression.py b/packages/bigframes/bigframes/core/expression.py index 6c27dfc120b6..e658eeee64ae 100644 --- a/packages/bigframes/bigframes/core/expression.py +++ b/packages/bigframes/bigframes/core/expression.py @@ -32,10 +32,27 @@ import bigframes.operations +class FrozenDict(dict): + def __hash__(self) -> int: + return hash(tuple((k, make_hashable(v)) for k, v in sorted(self.items()))) + + +def make_hashable(val: typing.Any) -> typing.Hashable: + if isinstance(val, list): + return tuple(make_hashable(x) for x in val) + elif isinstance(val, dict): + return FrozenDict({k: make_hashable(v) for k, v in val.items()}) + elif isinstance(val, tuple): + return tuple(make_hashable(x) for x in val) + else: + return val + + def const( - value: typing.Hashable, dtype: dtypes.ExpressionType = None + value: typing.Any, dtype: dtypes.ExpressionType = None ) -> ScalarConstantExpression: - return ScalarConstantExpression(value, dtype or dtypes.infer_literal_type(value)) + hashable_value = make_hashable(value) + return ScalarConstantExpression(hashable_value, dtype or dtypes.infer_literal_type(hashable_value)) def deref(name: str) -> DerefOp: diff --git a/packages/bigframes/bigframes/operations/__init__.py b/packages/bigframes/bigframes/operations/__init__.py index dd036bec5a26..4ad91629c06f 100644 --- a/packages/bigframes/bigframes/operations/__init__.py +++ b/packages/bigframes/bigframes/operations/__init__.py @@ -14,17 +14,7 @@ from __future__ import annotations -from bigframes.operations.ai_ops import ( - AIClassify, - AIEmbed, - AIGenerate, - AIGenerateBool, - AIGenerateDouble, - AIGenerateInt, - AIIf, - AIScore, - AISimilarity, -) + from bigframes.operations.array_ops import ( ArrayIndexOp, ArrayMapOp, @@ -126,7 +116,7 @@ geo_x_op, geo_y_op, ) -from bigframes.operations.googlesql import GoogleSqlScalarOp +from bigframes.operations.googlesql import apply_op, GoogleSqlScalarOp from bigframes.operations.json_ops import ( JSONExtract, JSONExtractArray, @@ -426,16 +416,6 @@ "GeoStDistanceOp", "GeoStLengthOp", "GeoStRegionStatsOp", - # AI ops - "AIClassify", - "AIGenerate", - "AIGenerateBool", - "AIGenerateDouble", - "AIGenerateInt", - "AIEmbed", - "AIIf", - "AIScore", - "AISimilarity", # Numpy ops mapping "NUMPY_TO_BINOP", "NUMPY_TO_OP", @@ -444,4 +424,5 @@ "ArrayMapOp", # GoogleSql "GoogleSqlScalarOp", + "apply_op", ] diff --git a/packages/bigframes/bigframes/operations/ai_ops.py b/packages/bigframes/bigframes/operations/ai_ops.py deleted file mode 100644 index ad2b9850577e..000000000000 --- a/packages/bigframes/bigframes/operations/ai_ops.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import dataclasses -from typing import ClassVar, Tuple - -import pandas as pd -import pyarrow as pa - -from bigframes import dtypes -from bigframes.operations import base_ops, output_schemas - - -@dataclasses.dataclass(frozen=True) -class AIGenerate(base_ops.NaryOp): - name: ClassVar[str] = "ai_generate" - - prompt_context: Tuple[str | None, ...] - connection_id: str | None = None - endpoint: str | None = None - request_type: str | None = None - model_params: str | None = None - output_schema: str | None = None - - def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: - if self.output_schema is None: - output_fields = (pa.field("result", pa.string()),) - else: - output_fields = output_schemas.parse_sql_fields(self.output_schema) - - return pd.ArrowDtype( - pa.struct( - ( - *output_fields, - pa.field("full_response", dtypes.JSON_ARROW_TYPE), - pa.field("status", pa.string()), - ) - ) - ) - - -@dataclasses.dataclass(frozen=True) -class AIGenerateBool(base_ops.NaryOp): - name: ClassVar[str] = "ai_generate_bool" - - prompt_context: Tuple[str | None, ...] - connection_id: str | None = None - endpoint: str | None = None - request_type: str | None = None - model_params: str | None = None - - def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: - return pd.ArrowDtype( - pa.struct( - ( - pa.field("result", pa.bool_()), - pa.field("full_response", dtypes.JSON_ARROW_TYPE), - pa.field("status", pa.string()), - ) - ) - ) - - -@dataclasses.dataclass(frozen=True) -class AIGenerateInt(base_ops.NaryOp): - name: ClassVar[str] = "ai_generate_int" - - prompt_context: Tuple[str | None, ...] - connection_id: str | None = None - endpoint: str | None = None - request_type: str | None = None - model_params: str | None = None - - def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: - return pd.ArrowDtype( - pa.struct( - ( - pa.field("result", pa.int64()), - pa.field("full_response", dtypes.JSON_ARROW_TYPE), - pa.field("status", pa.string()), - ) - ) - ) - - -@dataclasses.dataclass(frozen=True) -class AIGenerateDouble(base_ops.NaryOp): - name: ClassVar[str] = "ai_generate_double" - - prompt_context: Tuple[str | None, ...] - connection_id: str | None = None - endpoint: str | None = None - request_type: str | None = None - model_params: str | None = None - - def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: - return pd.ArrowDtype( - pa.struct( - ( - pa.field("result", pa.float64()), - pa.field("full_response", dtypes.JSON_ARROW_TYPE), - pa.field("status", pa.string()), - ) - ) - ) - - -@dataclasses.dataclass(frozen=True) -class AIEmbed(base_ops.UnaryOp): - name: ClassVar[str] = "ai_embed" - - endpoint: str | None = None - model: str | None = None - task_type: str | None = None - title: str | None = None - model_params: str | None = None - connection_id: str | None = None - - def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: - return pd.ArrowDtype( - pa.struct( - ( - pa.field("result", pa.list_(pa.float64())), - pa.field("status", pa.string()), - ) - ) - ) - - -@dataclasses.dataclass(frozen=True) -class AIIf(base_ops.NaryOp): - name: ClassVar[str] = "ai_if" - - prompt_context: Tuple[str | None, ...] - connection_id: str | None = None - endpoint: str | None = None - optimization_mode: str | None = None - max_error_ratio: float | None = None - - def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: - return dtypes.BOOL_DTYPE - - -@dataclasses.dataclass(frozen=True) -class AIClassify(base_ops.NaryOp): - name: ClassVar[str] = "ai_classify" - - prompt_context: Tuple[str | None, ...] - categories: tuple[str, ...] - examples: ( - tuple[tuple[str, str], ...] | tuple[tuple[str, tuple[str, ...]], ...] | None - ) = None - connection_id: str | None = None - endpoint: str | None = None - output_mode: str | None = None - optimization_mode: str | None = None - max_error_ratio: float | None = None - - def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: - if self.output_mode is not None: - return dtypes.list_type(dtypes.STRING_DTYPE) - return dtypes.STRING_DTYPE - - -@dataclasses.dataclass(frozen=True) -class AIScore(base_ops.NaryOp): - name: ClassVar[str] = "ai_score" - - prompt_context: Tuple[str | None, ...] - connection_id: str | None = None - endpoint: str | None = None - max_error_ratio: float | None = None - - def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: - return dtypes.FLOAT_DTYPE - - -@dataclasses.dataclass(frozen=True) -class AISimilarity(base_ops.BinaryOp): - name: ClassVar[str] = "ai_similarity" - - endpoint: str | None = None - model: str | None = None - model_params: str | None = None - connection_id: str | None = None - - def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: - return dtypes.FLOAT_DTYPE diff --git a/packages/bigframes/bigframes/operations/googlesql.py b/packages/bigframes/bigframes/operations/googlesql.py index 0100784bda1d..9821109e4def 100644 --- a/packages/bigframes/bigframes/operations/googlesql.py +++ b/packages/bigframes/bigframes/operations/googlesql.py @@ -20,10 +20,16 @@ from enum import Enum, auto from typing import Callable, Iterable +import pandas as pd + import bigframes.operations as ops import bigframes.operations.type as op_typing from bigframes import dtypes +from bigframes.operations.base_ops import _convert_expr_input +if typing.TYPE_CHECKING: + # Avoids circular dependency + import bigframes.core.expression @dataclasses.dataclass(frozen=True) class ArgSpec: @@ -62,6 +68,43 @@ def deterministic(self) -> bool: def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: return self.signature(*input_types) + def as_expr(self, *args: str | ex.Expression, **kwargs: Any) -> ex.Expression: + import bigframes.core.expression as ex + + def wrap_input(expr, index: int): + if isinstance(expr, ex.Expression): + return expr + + is_const_only = False + if index < len(self.args): + is_const_only = self.args[index].const_only + + if isinstance(expr, str) and not is_const_only: + return ex.deref(expr) + else: + return ex.const(expr) + + name_to_index = {arg_spec.arg_name: i for i, arg_spec in enumerate(self.args) if arg_spec.arg_name is not None} + + # Keep this in sync with output_type and compilers + inputs: list[ex.Expression] = [] + + for i, expr in enumerate(args): + inputs.append(wrap_input(expr, i)) + + for name, expr in kwargs.items(): + if name not in name_to_index: + raise ValueError(f"Argument '{name}' is not valid for this operation.") + index = name_to_index[name] + if index >= len(inputs): + inputs.extend([ex.OmittedArg()] * (index - len(inputs) + 1)) + inputs[index] = wrap_input(expr, index) + + return ex.OpExpression( + self, + tuple(inputs), + ) + RAND = GoogleSqlScalarOp( @@ -107,3 +150,208 @@ def _check_simplify_inputs( is_deterministic=True, signature=_check_simplify_inputs, ) + +def _ai_classify_output_type(*input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + output_mode = input_types[5] if len(input_types) > 5 else None + if output_mode is not None: + return dtypes.list_type(dtypes.STRING_DTYPE) + return dtypes.STRING_DTYPE + +AI_CLASSIFY = GoogleSqlScalarOp( + sql_name="AI.CLASSIFY", + args=( + ArgSpec(arg_name="input"), + ArgSpec(arg_name="categories"), + ArgSpec(arg_name="examples", optional=True, const_only=True), + ArgSpec(arg_name="connection_id", optional=True, const_only=True), + ArgSpec(arg_name="endpoint", optional=True, const_only=True), + ArgSpec(arg_name="output_mode", optional=True, const_only=True), + ArgSpec(arg_name="optimization_mode", optional=True, const_only=True), + ArgSpec(arg_name="max_error_ratio", optional=True, const_only=True), + ), + signature=_ai_classify_output_type, +) + + +@dataclasses.dataclass(frozen=True) +class AIGenerateOp(GoogleSqlScalarOp): + sql_name: str = "AI.GENERATE" + args: tuple[ArgSpec, ...] = ( + ArgSpec(arg_name="prompt"), + ArgSpec(arg_name="connection_id", optional=True, const_only=True), + ArgSpec(arg_name="endpoint", optional=True, const_only=True), + ArgSpec(arg_name="request_type", optional=True, const_only=True), + ArgSpec(arg_name="model_params", optional=True, const_only=True), + ArgSpec(arg_name="output_schema", optional=True, const_only=True), + ) + signature: typing.Callable[..., dtypes.ExpressionType] = lambda: dtypes.STRING_DTYPE + output_schema: str | None = None + + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + import pyarrow as pa + if self.output_schema is None: + output_fields = (pa.field("result", pa.string()),) + else: + from bigframes.operations import output_schemas + output_fields = output_schemas.parse_sql_fields(self.output_schema) + + return pd.ArrowDtype( + pa.struct( + ( + *output_fields, + pa.field("full_response", dtypes.JSON_ARROW_TYPE), + pa.field("status", pa.string()), + ) + ) + ) + + +def _ai_generate_bool_output_type(*input_types): + import pyarrow as pa + return pd.ArrowDtype( + pa.struct( + ( + pa.field("result", pa.bool_()), + pa.field("full_response", dtypes.JSON_ARROW_TYPE), + pa.field("status", pa.string()), + ) + ) + ) + +AI_GENERATE_BOOL = GoogleSqlScalarOp( + sql_name="AI.GENERATE_BOOL", + args=( + ArgSpec(arg_name="prompt"), + ArgSpec(arg_name="connection_id", optional=True, const_only=True), + ArgSpec(arg_name="endpoint", optional=True, const_only=True), + ArgSpec(arg_name="request_type", optional=True, const_only=True), + ArgSpec(arg_name="model_params", optional=True, const_only=True), + ), + signature=_ai_generate_bool_output_type, +) + + +def _ai_generate_int_output_type(*input_types): + import pyarrow as pa + return pd.ArrowDtype( + pa.struct( + ( + pa.field("result", pa.int64()), + pa.field("full_response", dtypes.JSON_ARROW_TYPE), + pa.field("status", pa.string()), + ) + ) + ) + +AI_GENERATE_INT = GoogleSqlScalarOp( + sql_name="AI.GENERATE_INT", + args=( + ArgSpec(arg_name="prompt"), + ArgSpec(arg_name="connection_id", optional=True, const_only=True), + ArgSpec(arg_name="endpoint", optional=True, const_only=True), + ArgSpec(arg_name="request_type", optional=True, const_only=True), + ArgSpec(arg_name="model_params", optional=True, const_only=True), + ), + signature=_ai_generate_int_output_type, +) + + +def _ai_generate_double_output_type(*input_types): + import pyarrow as pa + return pd.ArrowDtype( + pa.struct( + ( + pa.field("result", pa.float64()), + pa.field("full_response", dtypes.JSON_ARROW_TYPE), + pa.field("status", pa.string()), + ) + ) + ) + +AI_GENERATE_DOUBLE = GoogleSqlScalarOp( + sql_name="AI.GENERATE_DOUBLE", + args=( + ArgSpec(arg_name="prompt"), + ArgSpec(arg_name="connection_id", optional=True, const_only=True), + ArgSpec(arg_name="endpoint", optional=True, const_only=True), + ArgSpec(arg_name="request_type", optional=True, const_only=True), + ArgSpec(arg_name="model_params", optional=True, const_only=True), + ), + signature=_ai_generate_double_output_type, +) + + +def _ai_embed_output_type(*input_types): + import pyarrow as pa + return pd.ArrowDtype( + pa.struct( + ( + pa.field("result", pa.list_(pa.float64())), + pa.field("status", pa.string()), + ) + ) + ) + +AI_EMBED = GoogleSqlScalarOp( + sql_name="AI.EMBED", + args=( + ArgSpec(arg_name="content"), + ArgSpec(arg_name="endpoint", optional=True, const_only=True), + ArgSpec(arg_name="model", optional=True, const_only=True), + ArgSpec(arg_name="task_type", optional=True, const_only=True), + ArgSpec(arg_name="title", optional=True, const_only=True), + ArgSpec(arg_name="model_params", optional=True, const_only=True), + ArgSpec(arg_name="connection_id", optional=True, const_only=True), + ), + signature=_ai_embed_output_type, +) + + +AI_IF = GoogleSqlScalarOp( + sql_name="AI.IF", + args=( + ArgSpec(arg_name="prompt"), + ArgSpec(arg_name="connection_id", optional=True, const_only=True), + ArgSpec(arg_name="endpoint", optional=True, const_only=True), + ArgSpec(arg_name="optimization_mode", optional=True, const_only=True), + ArgSpec(arg_name="max_error_ratio", optional=True, const_only=True), + ), + signature=lambda *args: dtypes.BOOL_DTYPE, +) + + +AI_SCORE = GoogleSqlScalarOp( + sql_name="AI.SCORE", + args=( + ArgSpec(arg_name="prompt"), + ArgSpec(arg_name="connection_id", optional=True, const_only=True), + ArgSpec(arg_name="endpoint", optional=True, const_only=True), + ArgSpec(arg_name="max_error_ratio", optional=True, const_only=True), + ), + signature=lambda *args: dtypes.FLOAT_DTYPE, +) + + +AI_SIMILARITY = GoogleSqlScalarOp( + sql_name="AI.SIMILARITY", + args=( + ArgSpec(arg_name="content1"), + ArgSpec(arg_name="content2"), + ArgSpec(arg_name="endpoint", optional=True, const_only=True), + ArgSpec(arg_name="model", optional=True, const_only=True), + ArgSpec(arg_name="model_params", optional=True, const_only=True), + ArgSpec(arg_name="connection_id", optional=True, const_only=True), + ), + signature=lambda *args: dtypes.FLOAT_DTYPE, +) + + + +def apply_op( + op: ops.NaryOp, + args: typing.Sequence[typing.Any] = (), + kwargs: typing.Dict[str, typing.Any] = {}, +) -> typing.Any: + """Applies an operation to a mix of Series-like, literal, and other values, with necessary alignment.""" + import bigframes.core.align as align + return align.apply_op(op, args=args, kwargs=kwargs) \ No newline at end of file diff --git a/packages/bigframes/bigframes/series.py b/packages/bigframes/bigframes/series.py index 87c03395c753..208316de44e1 100644 --- a/packages/bigframes/bigframes/series.py +++ b/packages/bigframes/bigframes/series.py @@ -2769,40 +2769,20 @@ def _align_n( typing.Sequence[Union[ex.ScalarConstantExpression, ex.DerefOp]], blocks.Block, ]: - if ignore_self: - value_ids: List[Union[ex.ScalarConstantExpression, ex.DerefOp]] = [] - else: - value_ids = [ex.deref(self._value_column)] - - block = self._block - for other in others: - if isinstance(other, Series): - ( - block, - ( - get_column_left, - get_column_right, - ), - ) = block.join(other._block, how=how) - rebindings = { - ids.ColumnId(old): ids.ColumnId(new) - for old, new in get_column_left.items() - } - remapped_value_ids = ( - value.remap_column_refs(rebindings) for value in value_ids - ) - value_ids = [ - *remapped_value_ids, # type: ignore - ex.deref(get_column_right[other._value_column]), - ] + inputs = others if ignore_self else [self, *others] + import bigframes.core.align as align + aligned_exprs, block = align.align_n(inputs, how=how) + + # Post-wrap raw scalars into ex.const with proper type coercion + final_exprs = [] + for expr in aligned_exprs: + if isinstance(expr, ex.Expression): + final_exprs.append(expr) else: - # Will throw if can't interpret as scalar. dtype = typing.cast(bigframes.dtypes.Dtype, self._dtype) - value_ids = [ - *value_ids, - ex.const(other, dtype=dtype if cast_scalars else None), - ] - return (value_ids, block) + final_exprs.append(ex.const(expr, dtype=dtype if cast_scalars else None)) + + return final_exprs, block def _throw_if_null_index(self, opname: __builtins__.str): if len(self._block.index_columns) == 0: diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/None/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/None/out.sql index 6771527318fa..4cd48b99a574 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/None/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/None/out.sql @@ -1,3 +1,7 @@ SELECT - AI.CLASSIFY(input => (`string_col`), categories => ['greeting', 'rejection']) AS `result` + AI.CLASSIFY( + input => `string_col`, + categories => ['greeting', 'rejection'], + connection_id => NULL + ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/bigframes-dev.us.bigframes-default-connection/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/bigframes-dev.us.bigframes-default-connection/out.sql index 63c31d94566d..b205d54e745b 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/bigframes-dev.us.bigframes-default-connection/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/bigframes-dev.us.bigframes-default-connection/out.sql @@ -1,6 +1,6 @@ SELECT AI.CLASSIFY( - input => (`string_col`), + input => `string_col`, categories => ['greeting', 'rejection'], connection_id => 'bigframes-dev.us.bigframes-default-connection' ) AS `result` diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_multi_with_list_examples/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_multi_with_list_examples/out.sql index a4a7f783da97..c289bddd2ac4 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_multi_with_list_examples/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_multi_with_list_examples/out.sql @@ -1,8 +1,11 @@ SELECT AI.CLASSIFY( - input => (`string_col`), + input => `string_col`, categories => ['greeting', 'rejection'], - examples => [('hi', ['greeting', 'positive']), ('bye', ['rejection', 'negative'])], + examples => [ + STRUCT('hi' AS `input`, ['greeting', 'positive'] AS `output`), + STRUCT('bye' AS `input`, ['rejection', 'negative'] AS `output`) + ], output_mode => 'multi' ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_with_output_mode/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_with_output_mode/out.sql index fb3c6af8b0b0..8c16d1c155e6 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_with_output_mode/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_with_output_mode/out.sql @@ -1,6 +1,6 @@ SELECT AI.CLASSIFY( - input => (`string_col`), + input => `string_col`, categories => ['greeting', 'rejection'], output_mode => 'multi' ) AS `result` diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_with_params/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_with_params/out.sql index 982b747f8927..ae0d288bc38c 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_with_params/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_with_params/out.sql @@ -1,8 +1,11 @@ SELECT AI.CLASSIFY( - input => (`string_col`), + input => `string_col`, categories => ['greeting', 'rejection'], - examples => [('hi', 'greeting'), ('bye', 'rejection')], + examples => [ + STRUCT('hi' AS `input`, 'greeting' AS `output`), + STRUCT('bye' AS `input`, 'rejection' AS `output`) + ], endpoint => 'gemini-2.5-flash', max_error_ratio => 0.1 ) AS `result` diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed/out.sql index 9c18a7cd532f..260bf7368b2e 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed/out.sql @@ -1,3 +1,3 @@ SELECT - AI.EMBED(`string_col`, endpoint => 'text-embedding-005') AS `result` + AI.EMBED(content => `string_col`, endpoint => 'text-embedding-005') AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_connection_id/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_connection_id/out.sql index 0968a101b22a..8053215b11de 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_connection_id/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_connection_id/out.sql @@ -1,6 +1,6 @@ SELECT AI.EMBED( - `string_col`, + content => `string_col`, endpoint => 'text-embedding-005', connection_id => 'bigframes-dev.us.bigframes-default-connection' ) AS `result` diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model/out.sql index 4c3c76f87b61..8a8b9aad0ddd 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model/out.sql @@ -1,3 +1,3 @@ SELECT - AI.EMBED(`string_col`, model => 'embeddinggemma-300m') AS `result` + AI.EMBED(content => `string_col`, model => 'embeddinggemma-300m') AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_task_type_and_title/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_task_type_and_title/out.sql index 9e4db995871b..417b8dcc4b05 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_task_type_and_title/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_task_type_and_title/out.sql @@ -1,9 +1,9 @@ SELECT AI.EMBED( - `string_col`, + content => `string_col`, endpoint => 'text-embedding-005', task_type => 'RETRIEVAL_DOCUMENT', title => 'My Document', - model_params => JSON '{"outputDimensionality": 256}' + model_params => '{"outputDimensionality": 256}' ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate/out.sql index 9593347238f8..22556c43b2dc 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate/out.sql @@ -1,6 +1,6 @@ SELECT AI.GENERATE( - prompt => (`string_col`, ' is the same as ', `string_col`), + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), endpoint => 'gemini-2.5-flash', request_type => 'SHARED' ) AS `result` diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql index aebccad12217..f3dd825c2f3a 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql @@ -1,6 +1,6 @@ SELECT AI.GENERATE_BOOL( - prompt => (`string_col`, ' is the same as ', `string_col`), + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), endpoint => 'gemini-2.5-flash' ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_connection_id/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_connection_id/out.sql index 8f501a2cc292..fd2791f66ce7 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_connection_id/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_connection_id/out.sql @@ -1,6 +1,6 @@ SELECT AI.GENERATE_BOOL( - prompt => (`string_col`, ' is the same as ', `string_col`), + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), connection_id => 'bigframes-dev.us.bigframes-default-connection', endpoint => 'gemini-2.5-flash' ) AS `result` diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_model_param/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_model_param/out.sql index 985f5bb255d7..3ba4027bf1dc 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_model_param/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_model_param/out.sql @@ -1,6 +1,6 @@ SELECT AI.GENERATE_BOOL( - prompt => (`string_col`, ' is the same as ', `string_col`), - model_params => JSON '{}' + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), + model_params => '{}' ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double/out.sql index 3aed8986e179..e0ec997baf68 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double/out.sql @@ -1,6 +1,6 @@ SELECT AI.GENERATE_DOUBLE( - prompt => (`string_col`, ' is the same as ', `string_col`), + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), endpoint => 'gemini-2.5-flash' ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double_with_connection_id/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double_with_connection_id/out.sql index 19b8c18eec14..42cbbd2e3346 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double_with_connection_id/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double_with_connection_id/out.sql @@ -1,6 +1,6 @@ SELECT AI.GENERATE_DOUBLE( - prompt => (`string_col`, ' is the same as ', `string_col`), + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), connection_id => 'bigframes-dev.us.bigframes-default-connection', endpoint => 'gemini-2.5-flash' ) AS `result` diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double_with_model_param/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double_with_model_param/out.sql index 854acc386739..5e4ff22c6812 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double_with_model_param/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double_with_model_param/out.sql @@ -1,6 +1,6 @@ SELECT AI.GENERATE_DOUBLE( - prompt => (`string_col`, ' is the same as ', `string_col`), - model_params => JSON '{}' + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), + model_params => '{}' ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int/out.sql index 1ea5d0355cc9..0b0643f0d6c6 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int/out.sql @@ -1,6 +1,6 @@ SELECT AI.GENERATE_INT( - prompt => (`string_col`, ' is the same as ', `string_col`), + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), endpoint => 'gemini-2.5-flash' ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int_with_connection_id/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int_with_connection_id/out.sql index b99a8e9a207e..38f099fe9622 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int_with_connection_id/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int_with_connection_id/out.sql @@ -1,6 +1,6 @@ SELECT AI.GENERATE_INT( - prompt => (`string_col`, ' is the same as ', `string_col`), + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), connection_id => 'bigframes-dev.us.bigframes-default-connection', endpoint => 'gemini-2.5-flash' ) AS `result` diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int_with_model_param/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int_with_model_param/out.sql index fb3c9c001013..97c9099346ac 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int_with_model_param/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int_with_model_param/out.sql @@ -1,6 +1,6 @@ SELECT AI.GENERATE_INT( - prompt => (`string_col`, ' is the same as ', `string_col`), - model_params => JSON '{}' + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), + model_params => '{}' ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_connection_id/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_connection_id/out.sql index b122d97b0617..59527b132077 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_connection_id/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_connection_id/out.sql @@ -1,6 +1,6 @@ SELECT AI.GENERATE( - prompt => (`string_col`, ' is the same as ', `string_col`), + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), connection_id => 'bigframes-dev.us.bigframes-default-connection', endpoint => 'gemini-2.5-flash' ) AS `result` diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_model_param/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_model_param/out.sql index 9d818b8c0cc9..ea3a252ac5c9 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_model_param/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_model_param/out.sql @@ -1,6 +1,6 @@ SELECT AI.GENERATE( - prompt => (`string_col`, ' is the same as ', `string_col`), - model_params => JSON '{}' + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), + model_params => '{}' ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_output_schema/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_output_schema/out.sql index 44abe7085c4e..409413be0048 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_output_schema/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_output_schema/out.sql @@ -1,6 +1,7 @@ SELECT AI.GENERATE( - prompt => (`string_col`, ' is the same as ', `string_col`), + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), + connection_id => NULL, endpoint => 'gemini-2.5-flash', output_schema => 'x INT64, y FLOAT64' ) AS `result` diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/None/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/None/out.sql index 7696a12c5893..26e768d59401 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/None/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/None/out.sql @@ -1,6 +1,7 @@ SELECT AI.IF( - prompt => (`string_col`, ' is the same as ', `string_col`), + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), + connection_id => NULL, optimization_mode => 'MINIMIZE_COST', max_error_ratio => 0.5 ) AS `result` diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql index dc8707487b54..c9578158d8ac 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql @@ -1,6 +1,6 @@ SELECT AI.IF( - prompt => (`string_col`, ' is the same as ', `string_col`), + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), connection_id => 'bigframes-dev.us.bigframes-default-connection', optimization_mode => 'MINIMIZE_COST', max_error_ratio => 0.5 diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if_with_endpoint/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if_with_endpoint/out.sql index 5074584bd72d..608feab517f9 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if_with_endpoint/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if_with_endpoint/out.sql @@ -1,6 +1,6 @@ SELECT AI.IF( - prompt => (`string_col`, ' is the same as ', `string_col`), + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), endpoint => 'gemini-2.5-flash' ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/None/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/None/out.sql index 6a16276734ee..8270a9f2e300 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/None/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/None/out.sql @@ -1,3 +1,6 @@ SELECT - AI.SCORE(prompt => (`string_col`, ' is the same as ', `string_col`)) AS `result` + AI.SCORE( + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), + connection_id => NULL + ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/bigframes-dev.us.bigframes-default-connection/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/bigframes-dev.us.bigframes-default-connection/out.sql index 92de7cdcdc65..e32d9f1cd348 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/bigframes-dev.us.bigframes-default-connection/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/bigframes-dev.us.bigframes-default-connection/out.sql @@ -1,6 +1,6 @@ SELECT AI.SCORE( - prompt => (`string_col`, ' is the same as ', `string_col`), + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), connection_id => 'bigframes-dev.us.bigframes-default-connection' ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score_with_endpoint_and_max_error_ratio/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score_with_endpoint_and_max_error_ratio/out.sql index d65590d0b66d..a266b8bbf78a 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score_with_endpoint_and_max_error_ratio/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score_with_endpoint_and_max_error_ratio/out.sql @@ -1,6 +1,6 @@ SELECT AI.SCORE( - prompt => (`string_col`, ' is the same as ', `string_col`), + prompt => STRUCT(`string_col` AS _field_1, ' is the same as ' AS _field_2, `string_col` AS _field_3), endpoint => 'gemini-2.5-flash', max_error_ratio => 0.5 ) AS `result` diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity/None/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity/None/out.sql index 1df70aaf18e3..f96272f06813 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity/None/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity/None/out.sql @@ -1,3 +1,8 @@ SELECT - AI.SIMILARITY(content1 => `string_col`, content2 => `string_col`, endpoint => 'text-embedding-005') AS `result` + AI.SIMILARITY( + content1 => `string_col`, + content2 => `string_col`, + endpoint => 'text-embedding-005', + connection_id => NULL + ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity_with_model_param/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity_with_model_param/out.sql index 5173ac43bd96..ec1e68576551 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity_with_model_param/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity_with_model_param/out.sql @@ -3,6 +3,6 @@ SELECT content1 => `string_col`, content2 => `string_col`, endpoint => 'text-embedding-005', - model_params => JSON '{"outputDimensionality": 256}' + model_params => '{"outputDimensionality": 256}' ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py index 57c524908607..e2fc12711792 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py @@ -25,82 +25,90 @@ CONNECTION_ID = "bigframes-dev.us.bigframes-default-connection" +def _construct_prompt(col_name, context): + import bigframes.core.expression as ex + elements = [] + for elem in context: + if elem is None: + elements.append(col_name) + else: + elements.append(ex.const(elem)) + + struct_names = tuple(f"_field_{i+1}" for i in range(len(elements))) + return ops.StructOp(column_names=struct_names).as_expr(*elements) + + def test_ai_generate(scalar_types_df: dataframe.DataFrame, snapshot): col_name = "string_col" + prompt_expr = _construct_prompt(col_name, (None, " is the same as ", None)) - op = ops.AIGenerate( - prompt_context=(None, " is the same as ", None), + op = ops.googlesql.AIGenerateOp() + expr = op.as_expr( + prompt_expr, endpoint="gemini-2.5-flash", request_type="SHARED", ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") def test_ai_generate_with_connection_id(scalar_types_df: dataframe.DataFrame, snapshot): col_name = "string_col" + prompt_expr = _construct_prompt(col_name, (None, " is the same as ", None)) - op = ops.AIGenerate( - prompt_context=(None, " is the same as ", None), + op = ops.googlesql.AIGenerateOp() + expr = op.as_expr( + prompt_expr, connection_id=CONNECTION_ID, endpoint="gemini-2.5-flash", ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") def test_ai_generate_with_output_schema(scalar_types_df: dataframe.DataFrame, snapshot): col_name = "string_col" + prompt_expr = _construct_prompt(col_name, (None, " is the same as ", None)) - op = ops.AIGenerate( - prompt_context=(None, " is the same as ", None), + output_schema_str = "x INT64, y FLOAT64" + op = ops.googlesql.AIGenerateOp(output_schema=output_schema_str) + expr = op.as_expr( + prompt_expr, connection_id=None, endpoint="gemini-2.5-flash", - output_schema="x INT64, y FLOAT64", - ) - - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] + output_schema=output_schema_str, ) + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") def test_ai_generate_with_model_param(scalar_types_df: dataframe.DataFrame, snapshot): col_name = "string_col" + prompt_expr = _construct_prompt(col_name, (None, " is the same as ", None)) - op = ops.AIGenerate( - prompt_context=(None, " is the same as ", None), + op = ops.googlesql.AIGenerateOp() + expr = op.as_expr( + prompt_expr, model_params=json.dumps(dict()), ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") def test_ai_generate_bool(scalar_types_df: dataframe.DataFrame, snapshot): col_name = "string_col" + prompt_expr = _construct_prompt(col_name, (None, " is the same as ", None)) - op = ops.AIGenerateBool( - prompt_context=(None, " is the same as ", None), + expr = ops.googlesql.AI_GENERATE_BOOL.as_expr( + prompt_expr, endpoint="gemini-2.5-flash", ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") @@ -108,17 +116,15 @@ def test_ai_generate_bool_with_connection_id( scalar_types_df: dataframe.DataFrame, snapshot ): col_name = "string_col" + prompt_expr = _construct_prompt(col_name, (None, " is the same as ", None)) - op = ops.AIGenerateBool( - prompt_context=(None, " is the same as ", None), + expr = ops.googlesql.AI_GENERATE_BOOL.as_expr( + prompt_expr, connection_id=CONNECTION_ID, endpoint="gemini-2.5-flash", ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") @@ -126,32 +132,27 @@ def test_ai_generate_bool_with_model_param( scalar_types_df: dataframe.DataFrame, snapshot ): col_name = "string_col" + prompt_expr = _construct_prompt(col_name, (None, " is the same as ", None)) - op = ops.AIGenerateBool( - prompt_context=(None, " is the same as ", None), + expr = ops.googlesql.AI_GENERATE_BOOL.as_expr( + prompt_expr, model_params=json.dumps(dict()), ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") def test_ai_generate_int(scalar_types_df: dataframe.DataFrame, snapshot): col_name = "string_col" + prompt_expr = _construct_prompt(col_name, (None, " is the same as ", None)) - op = ops.AIGenerateInt( - # The prompt does not make semantic sense but we only care about syntax correctness. - prompt_context=(None, " is the same as ", None), + expr = ops.googlesql.AI_GENERATE_INT.as_expr( + prompt_expr, endpoint="gemini-2.5-flash", ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") @@ -159,18 +160,15 @@ def test_ai_generate_int_with_connection_id( scalar_types_df: dataframe.DataFrame, snapshot ): col_name = "string_col" + prompt_expr = _construct_prompt(col_name, (None, " is the same as ", None)) - op = ops.AIGenerateInt( - # The prompt does not make semantic sense but we only care about syntax correctness. - prompt_context=(None, " is the same as ", None), + expr = ops.googlesql.AI_GENERATE_INT.as_expr( + prompt_expr, connection_id=CONNECTION_ID, endpoint="gemini-2.5-flash", ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") @@ -178,33 +176,27 @@ def test_ai_generate_int_with_model_param( scalar_types_df: dataframe.DataFrame, snapshot ): col_name = "string_col" + prompt_expr = _construct_prompt(col_name, (None, " is the same as ", None)) - op = ops.AIGenerateInt( - # The prompt does not make semantic sense but we only care about syntax correctness. - prompt_context=(None, " is the same as ", None), + expr = ops.googlesql.AI_GENERATE_INT.as_expr( + prompt_expr, model_params=json.dumps(dict()), ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") def test_ai_generate_double(scalar_types_df: dataframe.DataFrame, snapshot): col_name = "string_col" + prompt_expr = _construct_prompt(col_name, (None, " is the same as ", None)) - op = ops.AIGenerateDouble( - # The prompt does not make semantic sense but we only care about syntax correctness. - prompt_context=(None, " is the same as ", None), + expr = ops.googlesql.AI_GENERATE_DOUBLE.as_expr( + prompt_expr, endpoint="gemini-2.5-flash", ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") @@ -212,18 +204,15 @@ def test_ai_generate_double_with_connection_id( scalar_types_df: dataframe.DataFrame, snapshot ): col_name = "string_col" + prompt_expr = _construct_prompt(col_name, (None, " is the same as ", None)) - op = ops.AIGenerateDouble( - # The prompt does not make semantic sense but we only care about syntax correctness. - prompt_context=(None, " is the same as ", None), + expr = ops.googlesql.AI_GENERATE_DOUBLE.as_expr( + prompt_expr, connection_id=CONNECTION_ID, endpoint="gemini-2.5-flash", ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") @@ -231,54 +220,51 @@ def test_ai_generate_double_with_model_param( scalar_types_df: dataframe.DataFrame, snapshot ): col_name = "string_col" + prompt_expr = _construct_prompt(col_name, (None, " is the same as ", None)) - op = ops.AIGenerateDouble( - # The prompt does not make semantic sense but we only care about syntax correctness. - prompt_context=(None, " is the same as ", None), + expr = ops.googlesql.AI_GENERATE_DOUBLE.as_expr( + prompt_expr, model_params=json.dumps(dict()), ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") def test_ai_embed(scalar_types_df: dataframe.DataFrame, snapshot): col_name = "string_col" - op = ops.AIEmbed( + expr = ops.googlesql.AI_EMBED.as_expr( + col_name, endpoint="text-embedding-005", ) - sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"]) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") def test_ai_embed_with_connection_id(scalar_types_df: dataframe.DataFrame, snapshot): col_name = "string_col" - op = ops.AIEmbed( + expr = ops.googlesql.AI_EMBED.as_expr( + col_name, endpoint="text-embedding-005", connection_id=CONNECTION_ID, ) - sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"]) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") def test_ai_embed_with_model(scalar_types_df: dataframe.DataFrame, snapshot): col_name = "string_col" - op = ops.AIEmbed( + expr = ops.googlesql.AI_EMBED.as_expr( + col_name, model="embeddinggemma-300m", ) - sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"]) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") @@ -287,48 +273,44 @@ def test_ai_embed_with_task_type_and_title( ): col_name = "string_col" - op = ops.AIEmbed( + expr = ops.googlesql.AI_EMBED.as_expr( + col_name, endpoint="text-embedding-005", task_type="RETRIEVAL_DOCUMENT", title="My Document", model_params=json.dumps({"outputDimensionality": 256}), ) - sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"]) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") @pytest.mark.parametrize("connection_id", [None, CONNECTION_ID]) def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot, connection_id): col_name = "string_col" + prompt_expr = _construct_prompt(col_name, (None, " is the same as ", None)) - op = ops.AIIf( - prompt_context=(None, " is the same as ", None), + expr = ops.googlesql.AI_IF.as_expr( + prompt_expr, connection_id=connection_id, optimization_mode="MINIMIZE_COST", max_error_ratio=0.5, ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") def test_ai_if_with_endpoint(scalar_types_df: dataframe.DataFrame, snapshot): col_name = "string_col" + prompt_expr = _construct_prompt(col_name, (None, " is the same as ", None)) - op = ops.AIIf( - prompt_context=(None, " is the same as ", None), + expr = ops.googlesql.AI_IF.as_expr( + prompt_expr, endpoint="gemini-2.5-flash", ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") @@ -336,44 +318,41 @@ def test_ai_if_with_endpoint(scalar_types_df: dataframe.DataFrame, snapshot): def test_ai_classify(scalar_types_df: dataframe.DataFrame, snapshot, connection_id): col_name = "string_col" - op = ops.AIClassify( - prompt_context=(None,), + expr = ops.googlesql.AI_CLASSIFY.as_expr( + col_name, categories=("greeting", "rejection"), connection_id=connection_id, ) - sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"]) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") def test_ai_classify_with_params(scalar_types_df: dataframe.DataFrame, snapshot): col_name = "string_col" - op = ops.AIClassify( - prompt_context=(None,), + expr = ops.googlesql.AI_CLASSIFY.as_expr( + col_name, categories=("greeting", "rejection"), - examples=(("hi", "greeting"), ("bye", "rejection")), + examples=[{"input": "hi", "output": "greeting"}, {"input": "bye", "output": "rejection"}], endpoint="gemini-2.5-flash", max_error_ratio=0.1, ) - sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"]) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") def test_ai_classify_with_output_mode(scalar_types_df: dataframe.DataFrame, snapshot): col_name = "string_col" - op = ops.AIClassify( - prompt_context=(None,), + expr = ops.googlesql.AI_CLASSIFY.as_expr( + col_name, categories=("greeting", "rejection"), output_mode="multi", ) - sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"]) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") @@ -382,35 +361,32 @@ def test_ai_classify_multi_with_list_examples( ): col_name = "string_col" - examples = ( - ("hi", ("greeting", "positive")), - ("bye", ("rejection", "negative")), - ) - op = ops.AIClassify( - prompt_context=(None,), + examples = [ + {"input": "hi", "output": ["greeting", "positive"]}, + {"input": "bye", "output": ["rejection", "negative"]}, + ] + expr = ops.googlesql.AI_CLASSIFY.as_expr( + col_name, categories=("greeting", "rejection"), examples=examples, output_mode="multi", ) - sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"]) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") @pytest.mark.parametrize("connection_id", [None, CONNECTION_ID]) def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot, connection_id): col_name = "string_col" + prompt_expr = _construct_prompt(col_name, (None, " is the same as ", None)) - op = ops.AIScore( - prompt_context=(None, " is the same as ", None), + expr = ops.googlesql.AI_SCORE.as_expr( + prompt_expr, connection_id=connection_id, ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") @@ -418,17 +394,15 @@ def test_ai_score_with_endpoint_and_max_error_ratio( scalar_types_df: dataframe.DataFrame, snapshot ): col_name = "string_col" + prompt_expr = _construct_prompt(col_name, (None, " is the same as ", None)) - op = ops.AIScore( - prompt_context=(None, " is the same as ", None), + expr = ops.googlesql.AI_SCORE.as_expr( + prompt_expr, endpoint="gemini-2.5-flash", max_error_ratio=0.5, ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") @@ -436,42 +410,39 @@ def test_ai_score_with_endpoint_and_max_error_ratio( def test_ai_similarity(scalar_types_df: dataframe.DataFrame, snapshot, connection_id): col_name = "string_col" - op = ops.AISimilarity( + expr = ops.googlesql.AI_SIMILARITY.as_expr( + content1=col_name, + content2=col_name, endpoint="text-embedding-005", connection_id=connection_id, ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") def test_ai_similarity_with_model(scalar_types_df: dataframe.DataFrame, snapshot): col_name = "string_col" - op = ops.AISimilarity( + expr = ops.googlesql.AI_SIMILARITY.as_expr( + content1=col_name, + content2=col_name, model="embeddinggemma-300m", ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") def test_ai_similarity_with_model_param(scalar_types_df: dataframe.DataFrame, snapshot): col_name = "string_col" - op = ops.AISimilarity( + expr = ops.googlesql.AI_SIMILARITY.as_expr( + content1=col_name, + content2=col_name, endpoint="text-embedding-005", model_params=json.dumps({"outputDimensionality": 256}), ) - sql = utils._apply_ops_to_sql( - scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] - ) - + sql = utils._apply_ops_to_sql(scalar_types_df, [expr], ["result"]) snapshot.assert_match(sql, "out.sql") diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_googlesql_apply_op.py b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_googlesql_apply_op.py new file mode 100644 index 000000000000..e7614bffe2b8 --- /dev/null +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_googlesql_apply_op.py @@ -0,0 +1,87 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import bigframes.dataframe as dataframe +from bigframes import dtypes +from bigframes.operations.googlesql import ArgSpec, GoogleSqlScalarOp, apply_op + +# Define standard Google SQL operations to test with +GREATEST_OP = GoogleSqlScalarOp( + sql_name="GREATEST", + args=(ArgSpec(arg_name="x"), ArgSpec(arg_name="y")), + signature=lambda x, y: dtypes.INT_DTYPE, +) + +CONCAT_OP = GoogleSqlScalarOp( + sql_name="CONCAT", + args=(ArgSpec(arg_name="a"), ArgSpec(arg_name="b")), + signature=lambda a, b: dtypes.STRING_DTYPE, +) + + +def test_apply_op_positional_series(scalar_types_df: dataframe.DataFrame): + s1 = scalar_types_df["int64_col"] + s2 = scalar_types_df["int64_too"] + + result = apply_op(GREATEST_OP, args=(s1, s2)) + + # Compile the resulting Series' underlying ArrayValue to SQL + array_value = result._block.expr + sql = array_value.session._executor.to_sql(array_value, enable_cache=False) + + assert "GREATEST(" in sql + assert "`int64_col`" in sql + assert "`int64_too`" in sql + + +def test_apply_op_keyword_series(scalar_types_df: dataframe.DataFrame): + s1 = scalar_types_df["int64_col"] + s2 = scalar_types_df["int64_too"] + + result = apply_op(GREATEST_OP, kwargs={"x": s1, "y": s2}) + + array_value = result._block.expr + sql = array_value.session._executor.to_sql(array_value, enable_cache=False) + + assert "GREATEST(" in sql + assert "x => `int64_col`" in sql + assert "y => `int64_too`" in sql + + +def test_apply_op_mixed_series_and_literal(scalar_types_df: dataframe.DataFrame): + s1 = scalar_types_df["int64_col"] + + result = apply_op(GREATEST_OP, args=(s1, 15)) + + array_value = result._block.expr + sql = array_value.session._executor.to_sql(array_value, enable_cache=False) + + assert "GREATEST(" in sql + assert "`int64_col`" in sql + assert "15" in sql + + +def test_apply_op_string_concat(scalar_types_df: dataframe.DataFrame): + s1 = scalar_types_df["string_col"] + + result = apply_op(CONCAT_OP, args=(s1, "world")) + + array_value = result._block.expr + sql = array_value.session._executor.to_sql(array_value, enable_cache=False) + + assert "CONCAT(" in sql + assert "`string_col`" in sql + assert "'world'" in sql diff --git a/packages/bigframes/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/packages/bigframes/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py index e47164f6c469..3fd9d8a787ad 100644 --- a/packages/bigframes/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/packages/bigframes/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -1140,9 +1140,6 @@ def visit_AIEmbed(self, op, **kwargs): def visit_AIIf(self, op, **kwargs): return sge.func("AI.IF", *self._compile_ai_args(**kwargs)) - def visit_AIClassify(self, op, **kwargs): - return sge.func("AI.CLASSIFY", *self._compile_ai_args(**kwargs)) - def visit_AIScore(self, op, **kwargs): return sge.func("AI.SCORE", *self._compile_ai_args(**kwargs)) diff --git a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py index fcd97c6f61b2..b585e0a6ae56 100644 --- a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py +++ b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py @@ -149,26 +149,7 @@ def dtype(self) -> dt.Struct: return dt.bool -@public -class AIClassify(Value): - """Generate categories based on the prompt""" - input: Value - categories: Value[dt.Array[dt.String]] - examples: Optional[Value] - connection_id: Optional[Value[dt.String]] - endpoint: Optional[Value[dt.String]] - output_mode: Optional[Value[dt.String]] - optimization_mode: Optional[Value[dt.String]] - max_error_ratio: Optional[Value[dt.Float64]] - - shape = rlz.shape_like("input") - - @attribute - def dtype(self) -> dt.DataType: - if self.output_mode is not None: - return dt.Array(dt.string) - return dt.string @public