Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions llmsql/inference/inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from llmsql.loggers.logging_config import log
from llmsql.utils.inference_utils import _maybe_download, _setup_seed
from llmsql.utils.utils import (
build_all_requests,
choose_prompt_builder,
load_jsonl,
overwrite_jsonl,
Expand Down Expand Up @@ -114,13 +115,10 @@ async def _inference_api_async(

async with aiohttp.ClientSession(headers=headers) as session:

async def process_question(q: dict[str, Any]) -> dict[str, str]:
tbl = tables[q["table_id"]]
example_row = tbl["rows"][0] if tbl["rows"] else []
prompt = prompt_builder(
q["question"], tbl["header"], tbl["types"], example_row
)
# Pre-build all prompts using the shared function
prompts = build_all_requests(questions, tables, prompt_builder)

async def process_question(q: dict[str, Any], prompt: str) -> dict[str, str]:
payload = {
"model": model_name,
"messages": [
Expand Down Expand Up @@ -152,7 +150,7 @@ async def process_question(q: dict[str, Any]) -> dict[str, str]:

return result

tasks = [process_question(q) for q in questions]
tasks = [process_question(q, p) for q, p in zip(questions, prompts)]
for coro in tqdm(
asyncio.as_completed(tasks),
total=len(tasks),
Expand Down
37 changes: 14 additions & 23 deletions llmsql/inference/inference_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from llmsql.loggers.logging_config import log
from llmsql.utils.inference_utils import _maybe_download, _setup_seed
from llmsql.utils.utils import (
build_all_requests,
choose_prompt_builder,
load_jsonl,
overwrite_jsonl,
Expand Down Expand Up @@ -201,37 +202,27 @@ def inference_vllm(

sampling_params = SamplingParams(**sampling_params_args)

# --- build all requests ---
prompts = build_all_requests(
questions,
tables,
prompt_builder,
tokenizer=tokenizer if use_chat_template else None,
use_chat_template=bool(use_chat_template),
)

# --- main inference loop ---
all_results: list[dict[str, str]] = []
total = len(questions)

for batch_start in tqdm(range(0, total, batch_size), desc="Generating"):
batch = questions[batch_start : batch_start + batch_size]

prompts = []
for q in batch:
tbl = tables[q["table_id"]]
example_row = tbl["rows"][0] if tbl["rows"] else []

raw_text = prompt_builder(
q["question"], tbl["header"], tbl["types"], example_row
)

if use_chat_template:
messages = [{"role": "user", "content": raw_text}]

final_prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
else:
final_prompt = raw_text

prompts.append(final_prompt)
batch_prompts = prompts[batch_start : batch_start + batch_size]
batch_questions = questions[batch_start : batch_start + batch_size]

outputs = llm.generate(prompts, sampling_params)
outputs = llm.generate(batch_prompts, sampling_params)

batch_results: list[dict[str, str]] = []
for q, out in zip(batch, outputs, strict=False):
for q, out in zip(batch_questions, outputs, strict=False):
text = out.outputs[0].text
batch_results.append(
{
Expand Down
42 changes: 42 additions & 0 deletions llmsql/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,45 @@ def choose_prompt_builder(
if shots == 5:
return build_prompt_5shot
raise ValueError("shots must be one of {0, 1, 5}")


def build_all_requests(
questions: list[dict],
tables: dict,
prompt_builder: Callable[[str, list[str], list[str], list[str | float | int]], str],
tokenizer=None,
use_chat_template: bool = True,
) -> list[str]:
"""
Build all prompts from questions and tables.

Args:
questions: List of question dicts with 'question' and 'table_id' keys.
tables: Dict mapping table_id to table metadata (with 'header', 'types', 'rows').
prompt_builder: Function to build raw prompt text.
tokenizer: Optional tokenizer with apply_chat_template method.
use_chat_template: Whether to apply chat template (if tokenizer provided).

Returns:
List of final prompts (with chat template applied if requested).
"""
prompts = []
for q in questions:
tbl = tables[q["table_id"]]
example_row = tbl["rows"][0] if tbl["rows"] else []

raw_text = prompt_builder(
q["question"], tbl["header"], tbl["types"], example_row
)

if tokenizer and use_chat_template:
messages = [{"role": "user", "content": raw_text}]
final_prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
else:
final_prompt = raw_text

prompts.append(final_prompt)

return prompts
Loading