Skip to content

Commit f28266d

Browse files
committed
fix(cli/google): fix streaming responses on google models
1 parent 2b6deba commit f28266d

2 files changed

Lines changed: 32 additions & 71 deletions

File tree

crates/rullm-core/src/providers/google.rs

Lines changed: 29 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::types::{
55
ChatCompletion, ChatMessage, ChatRequest, ChatResponse, ChatRole, ChatStreamEvent, LlmProvider,
66
StreamConfig, StreamResult, TokenUsage,
77
};
8+
use crate::utils::sse::sse_lines;
89
use futures::StreamExt;
910
use std::collections::HashMap;
1011

@@ -405,7 +406,7 @@ impl ChatCompletion for GoogleProvider {
405406

406407
// Build the URL with API key for streaming endpoint
407408
let url = format!(
408-
"{}/models/{}:streamGenerateContent?key={}",
409+
"{}/models/{}:streamGenerateContent?alt=sse&key={}",
409410
self.config.base_url(),
410411
model,
411412
self.config.api_key()
@@ -426,6 +427,11 @@ impl ChatCompletion for GoogleProvider {
426427
}
427428
}
428429

430+
header_map.insert(
431+
reqwest::header::ACCEPT,
432+
reqwest::header::HeaderValue::from_static("text/event-stream"),
433+
);
434+
429435
let response_future = client
430436
.post(&url)
431437
.headers(header_map)
@@ -471,96 +477,49 @@ impl ChatCompletion for GoogleProvider {
471477
}
472478
};
473479

474-
// Get the byte stream and parse newline-delimited JSON chunks
475-
let mut byte_stream = response.bytes_stream();
476-
let mut buffer = String::new();
477-
478-
while let Some(chunk_result) = byte_stream.next().await {
479-
match chunk_result {
480-
Ok(bytes) => {
481-
// Add new bytes to buffer
482-
match std::str::from_utf8(&bytes) {
483-
Ok(text) => {
484-
buffer.push_str(text);
485-
486-
// Process complete lines (Google uses newline-delimited JSON)
487-
while let Some(newline_pos) = buffer.find('\n') {
488-
let line = buffer[..newline_pos].trim().to_string();
489-
buffer.drain(..newline_pos + 1);
490-
491-
// Skip empty lines
492-
if line.is_empty() {
493-
continue;
494-
}
495-
496-
// Parse the JSON chunk
497-
match serde_json::from_str::<serde_json::Value>(&line) {
498-
Ok(chunk) => {
499-
// Extract content from candidates[0].content.parts[].text
500-
if let Some(candidates) = chunk["candidates"].as_array() {
501-
if let Some(first_candidate) = candidates.first() {
502-
if let Some(content) = first_candidate.get("content") {
503-
if let Some(parts) = content["parts"].as_array() {
504-
for part in parts {
505-
if let Some(text) = part["text"].as_str() {
506-
yield Ok(ChatStreamEvent::Token(text.to_string()));
507-
}
508-
}
509-
}
480+
// Google now returns Server-Sent Events (SSE) for streaming responses.
481+
// Leverage the shared `sse_lines` util to properly parse the stream.
482+
483+
let byte_stream = response.bytes_stream();
484+
let mut sse_stream = sse_lines(byte_stream);
485+
486+
while let Some(event_res) = sse_stream.next().await {
487+
match event_res {
488+
Ok(line) => {
489+
// Each SSE `data:` line is a JSON object
490+
match serde_json::from_str::<serde_json::Value>(&line) {
491+
Ok(chunk) => {
492+
if let Some(candidates) = chunk["candidates"].as_array() {
493+
if let Some(first_candidate) = candidates.first() {
494+
if let Some(content) = first_candidate.get("content") {
495+
if let Some(parts) = content["parts"].as_array() {
496+
for part in parts {
497+
if let Some(text) = part["text"].as_str() {
498+
yield Ok(ChatStreamEvent::Token(text.to_string()));
510499
}
511500
}
512501
}
513502
}
514-
Err(e) => {
515-
yield Err(LlmError::serialization(
516-
format!("Failed to parse chunk JSON: {e}"),
517-
Box::new(e),
518-
));
519-
return;
520-
}
521503
}
522504
}
523505
}
524506
Err(e) => {
525507
yield Err(LlmError::serialization(
526-
"Invalid UTF-8 in response stream",
508+
format!("Failed to parse chunk JSON: {e}"),
527509
Box::new(e),
528510
));
529511
return;
530512
}
531513
}
532514
}
533515
Err(e) => {
534-
yield Err(LlmError::network(format!("Stream error: {e}")));
516+
// Propagate errors from SSE parser/network
517+
yield Err(e);
535518
return;
536519
}
537520
}
538521
}
539522

540-
// Process any remaining content in buffer
541-
if !buffer.trim().is_empty() {
542-
match serde_json::from_str::<serde_json::Value>(buffer.trim()) {
543-
Ok(chunk) => {
544-
if let Some(candidates) = chunk["candidates"].as_array() {
545-
if let Some(first_candidate) = candidates.first() {
546-
if let Some(content) = first_candidate.get("content") {
547-
if let Some(parts) = content["parts"].as_array() {
548-
for part in parts {
549-
if let Some(text) = part["text"].as_str() {
550-
yield Ok(ChatStreamEvent::Token(text.to_string()));
551-
}
552-
}
553-
}
554-
}
555-
}
556-
}
557-
}
558-
Err(_) => {
559-
// Ignore parse errors for trailing content that might not be complete JSON
560-
}
561-
}
562-
}
563-
564523
// Emit Done event when streaming completes
565524
yield Ok(ChatStreamEvent::Done);
566525
})

crates/rullm-core/src/utils/sse.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ where
7575
// Add new bytes to buffer
7676
match std::str::from_utf8(&bytes) {
7777
Ok(text) => {
78-
self.buffer.push_str(text);
78+
// Normalize CRLF to LF to handle Windows-style/HTTP CRLF delimiters
79+
let normalized = text.replace("\r\n", "\n");
80+
self.buffer.push_str(&normalized);
7981
// Continue loop to try parsing again
8082
}
8183
Err(e) => {

0 commit comments

Comments
 (0)