Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 128 additions & 129 deletions packages/bigframes/bigframes/bigquery/_operations/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from bigframes.core.logging import log_adapter
from bigframes.ml import base as ml_base
from bigframes.ml import core as ml_core
from bigframes.operations import ai_ops, output_schemas
from bigframes.operations import ai_ops, googlesql, output_schemas

PROMPT_TYPE = Union[
str,
Expand Down Expand Up @@ -114,9 +114,6 @@ def generate(
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
"""

prompt_context, series_list = _separate_context_and_series(prompt)
assert len(series_list) > 0

if output_schema is None:
output_schema_str = None
else:
Expand All @@ -126,17 +123,21 @@ def generate(
# Validate user input
output_schemas.parse_sql_fields(output_schema_str)

operator = ai_ops.AIGenerate(
prompt_context=tuple(prompt_context),
connection_id=connection_id,
endpoint=endpoint,
request_type=_upper_optional(request_type),
model_params=json.dumps(model_params) if model_params else None,
output_schema=output_schema_str,
prompt_struct = _construct_prompt_struct(prompt)

op = googlesql.AIGenerateOp(output_schema=output_schema_str)
return googlesql.apply_op(
op,
args=(prompt_struct,),
kwargs={
"connection_id": connection_id,
"endpoint": endpoint,
"request_type": _upper_optional(request_type),
"model_params": json.dumps(model_params) if model_params else None,
"output_schema": output_schema_str,
},
)

return series_list[0]._apply_nary_op(operator, series_list[1:])


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate_bool(
Expand Down Expand Up @@ -201,19 +202,19 @@ def generate_bool(
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
"""

prompt_context, series_list = _separate_context_and_series(prompt)
assert len(series_list) > 0

operator = ai_ops.AIGenerateBool(
prompt_context=tuple(prompt_context),
connection_id=connection_id,
endpoint=endpoint,
request_type=_upper_optional(request_type),
model_params=json.dumps(model_params) if model_params else None,
prompt_struct = _construct_prompt_struct(prompt)

return googlesql.apply_op(
googlesql.AI_GENERATE_BOOL,
args=(prompt_struct,),
kwargs={
"connection_id": connection_id,
"endpoint": endpoint,
"request_type": _upper_optional(request_type),
"model_params": json.dumps(model_params) if model_params else None,
},
)

return series_list[0]._apply_nary_op(operator, series_list[1:])


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate_int(
Expand Down Expand Up @@ -275,19 +276,19 @@ def generate_int(
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
"""

prompt_context, series_list = _separate_context_and_series(prompt)
assert len(series_list) > 0

operator = ai_ops.AIGenerateInt(
prompt_context=tuple(prompt_context),
connection_id=connection_id,
endpoint=endpoint,
request_type=_upper_optional(request_type),
model_params=json.dumps(model_params) if model_params else None,
prompt_struct = _construct_prompt_struct(prompt)

return googlesql.apply_op(
googlesql.AI_GENERATE_INT,
args=(prompt_struct,),
kwargs={
"connection_id": connection_id,
"endpoint": endpoint,
"request_type": _upper_optional(request_type),
"model_params": json.dumps(model_params) if model_params else None,
},
)

return series_list[0]._apply_nary_op(operator, series_list[1:])


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate_double(
Expand Down Expand Up @@ -349,19 +350,19 @@ def generate_double(
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
"""

prompt_context, series_list = _separate_context_and_series(prompt)
assert len(series_list) > 0

operator = ai_ops.AIGenerateDouble(
prompt_context=tuple(prompt_context),
connection_id=connection_id,
endpoint=endpoint,
request_type=_upper_optional(request_type),
model_params=json.dumps(model_params) if model_params else None,
prompt_struct = _construct_prompt_struct(prompt)

return googlesql.apply_op(
googlesql.AI_GENERATE_DOUBLE,
args=(prompt_struct,),
kwargs={
"connection_id": connection_id,
"endpoint": endpoint,
"request_type": _upper_optional(request_type),
"model_params": json.dumps(model_params) if model_params else None,
},
)

return series_list[0]._apply_nary_op(operator, series_list[1:])


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate_embedding(
Expand Down Expand Up @@ -751,24 +752,19 @@ def embed(
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
"""

operator = ai_ops.AIEmbed(
endpoint=endpoint,
model=model,
task_type=_upper_optional(task_type),
title=title,
model_params=json.dumps(model_params) if model_params else None,
connection_id=connection_id,
return googlesql.apply_op(
googlesql.AI_EMBED,
args=(content,),
kwargs={
"endpoint": endpoint,
"model": model,
"task_type": _upper_optional(task_type),
"title": title,
"model_params": json.dumps(model_params) if model_params else None,
"connection_id": connection_id,
},
)

if isinstance(content, str):
return series.Series([content])._apply_unary_op(operator)
elif isinstance(content, pd.Series):
return series.Series(content)._apply_unary_op(operator)
elif isinstance(content, series.Series):
return content._apply_unary_op(operator)
else:
raise ValueError(f"Unsupported 'content' parameter type: {type(content)}")


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def if_(
Expand Down Expand Up @@ -824,19 +820,19 @@ def if_(
bigframes.series.Series: A new series of bools.
"""

prompt_context, series_list = _separate_context_and_series(prompt)
assert len(series_list) > 0

operator = ai_ops.AIIf(
prompt_context=tuple(prompt_context),
connection_id=connection_id,
endpoint=endpoint,
optimization_mode=_upper_optional(optimization_mode),
max_error_ratio=max_error_ratio,
prompt_struct = _construct_prompt_struct(prompt)

return googlesql.apply_op(
googlesql.AI_IF,
args=(prompt_struct,),
kwargs={
"connection_id": connection_id,
"endpoint": endpoint,
"optimization_mode": _upper_optional(optimization_mode),
"max_error_ratio": max_error_ratio,
},
)

return series_list[0]._apply_nary_op(operator, series_list[1:])


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def classify(
Expand Down Expand Up @@ -901,30 +897,30 @@ def classify(
bigframes.series.Series: A new series of strings (or a series of arrays of strings if ``output_mode`` is specified).
"""

prompt_context, series_list = _separate_context_and_series(input)
assert len(series_list) > 0

if examples is not None:
example_tuples: Any = tuple(
(ex[0], tuple(ex[1]) if isinstance(ex[1], (list, tuple)) else ex[1])
formatted_examples = [
{
"input": ex[0],
"output": list(ex[1]) if isinstance(ex[1], (list, tuple)) else ex[1],
}
for ex in examples
)
]
else:
example_tuples = None

operator = ai_ops.AIClassify(
prompt_context=tuple(prompt_context),
categories=tuple(categories),
examples=example_tuples,
connection_id=connection_id,
endpoint=endpoint,
output_mode=output_mode,
optimization_mode=_upper_optional(optimization_mode),
max_error_ratio=max_error_ratio,
formatted_examples = None

return googlesql.apply_op(
googlesql.AI_CLASSIFY,
args=(input, categories),
kwargs={
"examples": formatted_examples,
"connection_id": connection_id,
"endpoint": endpoint,
"output_mode": output_mode,
"optimization_mode": _upper_optional(optimization_mode),
"max_error_ratio": max_error_ratio,
},
)

return series_list[0]._apply_nary_op(operator, series_list[1:])


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def score(
Expand Down Expand Up @@ -970,18 +966,18 @@ def score(
bigframes.series.Series: A new series of double (float) values.
"""

prompt_context, series_list = _separate_context_and_series(prompt)
assert len(series_list) > 0
prompt_struct = _construct_prompt_struct(prompt)

operator = ai_ops.AIScore(
prompt_context=tuple(prompt_context),
connection_id=connection_id,
endpoint=endpoint,
max_error_ratio=max_error_ratio,
return googlesql.apply_op(
googlesql.AI_SCORE,
args=(prompt_struct,),
kwargs={
"connection_id": connection_id,
"endpoint": endpoint,
"max_error_ratio": max_error_ratio,
},
)

return series_list[0]._apply_nary_op(operator, series_list[1:])


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def similarity(
Expand Down Expand Up @@ -1026,36 +1022,18 @@ def similarity(
bigframes.series.Series: A new series of FLOAT64 values representing the cosine similarity.
"""

operator = ai_ops.AISimilarity(
endpoint=endpoint,
model=model,
model_params=json.dumps(model_params) if model_params else None,
connection_id=connection_id,
return googlesql.apply_op(
googlesql.AI_SIMILARITY,
kwargs={
"content1": content1,
"content2": content2,
"endpoint": endpoint,
"model": model,
"model_params": json.dumps(model_params) if model_params else None,
"connection_id": connection_id,
},
)

# Find a unifying session for the subsequent operations.
bf_session = None
if isinstance(content1, series.Series):
bf_session = content1._session
elif isinstance(content2, series.Series):
bf_session = content2._session

if isinstance(content1, str) and isinstance(content2, str):
content1 = series.Series([content1], session=bf_session)
return content1._apply_binary_op(content2, operator)
elif isinstance(content1, str):
# content2 must be a series
content2 = convert.to_bf_series(
content2, default_index=None, session=bf_session
)
return content2._apply_binary_op(content1, operator)
else:
# content1 must be a series.
content1 = convert.to_bf_series(
content1, default_index=None, session=bf_session
)
return content1._apply_binary_op(content2, operator)


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def forecast(
Expand Down Expand Up @@ -1246,3 +1224,24 @@ def _upper_optional(value: str | None) -> str | None:
if value is None:
return None
return value.upper()


def _construct_prompt_struct(prompt: PROMPT_TYPE) -> series.Series:
    """Combine a mixed prompt into a single STRUCT-valued series.

    Splits ``prompt`` into literal context elements and series via
    ``_separate_context_and_series``, re-interleaves them in their original
    order, and packs everything into one struct column whose fields are
    named ``_field_1``, ``_field_2``, ... This struct is what the AI SQL
    operator wrappers in this module pass as the prompt argument.

    Args:
        prompt: A string, a series, or a sequence mixing both (PROMPT_TYPE).

    Returns:
        bigframes.series.Series: A struct-typed series combining all prompt
            parts in order.
    """
    prompt_context, series_list = _separate_context_and_series(prompt)
    # At least one series is required so there is a frame to evaluate
    # against. NOTE(review): `assert` is stripped under `python -O`; if this
    # is reachable from user input, consider raising ValueError instead.
    assert len(series_list) > 0

    # Re-interleave: `prompt_context` preserves the original element order,
    # using None as a placeholder for the next series in `series_list`.
    prompt_elements = []
    series_idx = 0
    for elem in prompt_context:
        if elem is None:
            prompt_elements.append(series_list[series_idx])
            series_idx += 1
        else:
            prompt_elements.append(elem)

    # Local import — presumably to avoid a circular import at module load
    # time; confirm before hoisting to the top-of-file import block.
    import bigframes.operations.generic_ops as generic_ops
    # Field names are positional (_field_1, _field_2, ...), matching the
    # number of interleaved prompt elements.
    struct_names = tuple(f"_field_{i+1}" for i in range(len(prompt_elements)))
    return googlesql.apply_op(
        generic_ops.StructOp(column_names=struct_names),
        args=prompt_elements,
    )
Loading
Loading