Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/grok_search/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ def retry_multiplier(self) -> float:
def retry_max_wait(self) -> int:
return int(os.getenv("GROK_RETRY_MAX_WAIT", "10"))

@property
def output_cleanup_enabled(self) -> bool:
    """Whether model-output cleanup is enabled.

    Reads GROK_OUTPUT_CLEANUP first; when unset, falls back to
    GROK_FILTER_THINK_TAGS (default "true") — presumably the older
    variable name, kept for backward compatibility (TODO confirm).

    Returns:
        True when the effective value is "true", "1", or "yes"
        (case-insensitive); False otherwise.
    """
    raw = os.getenv("GROK_OUTPUT_CLEANUP")
    if raw is None:
        raw = os.getenv("GROK_FILTER_THINK_TAGS", "true")
    # strip(): tolerate accidental surrounding whitespace in env values
    # ("true ", " 1"), which would otherwise silently disable cleanup.
    return raw.strip().lower() in ("true", "1", "yes")

@property
def grok_api_url(self) -> str:
url = os.getenv("GROK_API_URL")
Expand Down Expand Up @@ -184,6 +191,7 @@ def get_config_info(self) -> dict:
"GROK_API_KEY": api_key_masked,
"GROK_MODEL": self.grok_model,
"GROK_DEBUG": self.debug_enabled,
"GROK_OUTPUT_CLEANUP": self.output_cleanup_enabled,
"GROK_LOG_LEVEL": self.log_level,
"GROK_LOG_DIR": str(self.log_dir),
"TAVILY_API_URL": self.tavily_api_url,
Expand Down
1 change: 0 additions & 1 deletion src/grok_search/logger.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
from datetime import datetime
from pathlib import Path
from .config import config

logger = logging.getLogger("grok_search")
Expand Down
39 changes: 39 additions & 0 deletions src/grok_search/planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ class ExecutionOrderOutput(BaseModel):

_ACCUMULATIVE_LIST_PHASES = {"query_decomposition", "tool_selection"}
_MERGE_STRATEGY_PHASE = "search_strategy"
# Immediate predecessor required before each phase may start: process_phase
# rejects a non-revision request for a phase whose predecessor has not yet
# completed. "intent_analysis" has no entry because it is the first phase.
_PHASE_PREDECESSORS = {
    "complexity_assessment": "intent_analysis",
    "query_decomposition": "complexity_assessment",
    "search_strategy": "query_decomposition",
    "tool_selection": "search_strategy",
    "execution_order": "tool_selection",
}


def _split_csv(value: str) -> list[str]:
Expand Down Expand Up @@ -126,6 +133,9 @@ def __init__(self):
def get_session(self, session_id: str) -> PlanningSession | None:
    """Look up a tracked planning session by id; None when no such session exists."""
    return self._sessions.get(session_id, None)

def reset(self) -> None:
    """Discard all tracked planning sessions.

    Clears the mapping in place (rather than rebinding it) so any
    existing references to the sessions dict observe the reset too.
    """
    self._sessions.clear()

def process_phase(
self,
phase: str,
Expand All @@ -147,6 +157,35 @@ def process_phase(
if target not in PHASE_NAMES:
return {"error": f"Unknown phase: {target}. Valid: {', '.join(PHASE_NAMES)}"}

if not is_revision:
predecessor = _PHASE_PREDECESSORS.get(target)
if predecessor and predecessor not in session.phases:
return {
"error": f"Phase '{target}' requires '{predecessor}' to be completed first.",
"expected_phase_order": PHASE_NAMES,
"session_id": session.session_id,
"completed_phases": session.completed_phases,
"complexity_level": session.complexity_level,
}

if session.complexity_level == 1 and target in {"search_strategy", "tool_selection", "execution_order"}:
return {
"error": "Level 1 planning completes after query_decomposition.",
"expected_phase_order": PHASE_NAMES,
"session_id": session.session_id,
"completed_phases": session.completed_phases,
"complexity_level": session.complexity_level,
}

if session.complexity_level == 2 and target == "execution_order":
return {
"error": "Level 2 planning completes after tool_selection.",
"expected_phase_order": PHASE_NAMES,
"session_id": session.session_id,
"completed_phases": session.completed_phases,
"complexity_level": session.complexity_level,
}

if target in _ACCUMULATIVE_LIST_PHASES:
if is_revision:
session.phases[target] = PhaseRecord(
Expand Down
179 changes: 144 additions & 35 deletions src/grok_search/providers/grok.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import List, Optional
from tenacity import AsyncRetrying, retry_if_exception, stop_after_attempt, wait_random_exponential
from tenacity.wait import wait_base
from zoneinfo import ZoneInfo
from .base import BaseSearchProvider, SearchResult
from ..utils import search_prompt, fetch_prompt, url_describe_prompt, rank_sources_prompt
from ..logger import log_info
Expand Down Expand Up @@ -125,11 +124,16 @@ def __init__(self, api_url: str, api_key: str, model: str = "grok-4-fast"):
def get_provider_name(self) -> str:
    """Return the display name of this search provider."""
    provider_name = "Grok"
    return provider_name

async def search(self, query: str, platform: str = "", min_results: int = 3, max_results: int = 10, ctx=None) -> List[SearchResult]:
headers = {
def _build_api_headers(self) -> dict:
return {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
"User-Agent": "grok-search-mcp/0.1.0",
}

async def search(self, query: str, platform: str = "", min_results: int = 3, max_results: int = 10, ctx=None) -> List[SearchResult]:
headers = self._build_api_headers()
platform_prompt = ""

if platform:
Expand All @@ -146,18 +150,15 @@ async def search(self, query: str, platform: str = "", min_results: int = 3, max
},
{"role": "user", "content": time_context + query + platform_prompt},
],
"stream": True,
"stream": False,
}

await log_info(ctx, f"platform_prompt: { query + platform_prompt}", config.debug_enabled)

return await self._execute_stream_with_retry(headers, payload, ctx)
return await self._execute_completion_with_retry(headers, payload, ctx)

async def fetch(self, url: str, ctx=None) -> str:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
headers = self._build_api_headers()
payload = {
"model": self.model,
"messages": [
Expand All @@ -167,19 +168,67 @@ async def fetch(self, url: str, ctx=None) -> str:
},
{"role": "user", "content": url + "\n获取该网页内容并返回其结构化Markdown格式" },
],
"stream": True,
"stream": False,
}
return await self._execute_stream_with_retry(headers, payload, ctx)
return await self._execute_completion_with_retry(headers, payload, ctx)

def _extract_content_from_choice(self, choice: dict) -> str:
if not isinstance(choice, dict):
return ""

message = choice.get("message", {})
if isinstance(message, dict):
content = message.get("content", "")
if isinstance(content, str) and content:
return content

delta = choice.get("delta", {})
if isinstance(delta, dict):
content = delta.get("content", "")
if isinstance(content, str) and content:
return content

for key in ("text", "content"):
value = choice.get(key, "")
if isinstance(value, str) and value:
return value

return ""

def _is_empty_placeholder_payload(self, data: dict) -> bool:
if not isinstance(data, dict):
return False

if data.get("choices", object()) is not None:
return False

return all(not str(data.get(key, "")).strip() for key in ("id", "object", "model"))

def _build_placeholder_error(self, headers=None) -> ValueError:
request_id = ""
if headers:
request_id = (
headers.get("x-oneapi-request-id", "")
or headers.get("x-request-id", "")
or headers.get("request-id", "")
).strip()

message = "上游返回了空的占位 completion 帧(choices=null),疑似中转站对 Grok chat/completions 的实现异常"
if request_id:
message += f",request_id={request_id}"
return ValueError(message)

async def _parse_streaming_response(self, response, ctx=None) -> str:
content = ""
full_body_buffer = []

full_body_buffer = []
empty_placeholder_detected = False
response_headers = getattr(response, "headers", None)

async for line in response.aiter_lines():
line = line.strip()
if not line:
continue

full_body_buffer.append(line)

# 兼容 "data: {...}" 和 "data:{...}" 两种 SSE 格式
Expand All @@ -190,24 +239,70 @@ async def _parse_streaming_response(self, response, ctx=None) -> str:
# 去掉 "data:" 前缀,并去除可能的空格
json_str = line[5:].lstrip()
data = json.loads(json_str)
if self._is_empty_placeholder_payload(data):
empty_placeholder_detected = True
continue
choices = data.get("choices", [])
if choices and len(choices) > 0:
delta = choices[0].get("delta", {})
if "content" in delta:
content += delta["content"]
if isinstance(choices, list) and choices:
chunk = self._extract_content_from_choice(choices[0])
if chunk:
content += chunk
except (json.JSONDecodeError, IndexError):
continue

if not content and full_body_buffer:
try:
full_text = "".join(full_body_buffer)
data = json.loads(full_text)
if "choices" in data and len(data["choices"]) > 0:
message = data["choices"][0].get("message", {})
content = message.get("content", "")
if self._is_empty_placeholder_payload(data):
empty_placeholder_detected = True
choices = data.get("choices", [])
if isinstance(choices, list) and choices:
content = self._extract_content_from_choice(choices[0])
except json.JSONDecodeError:
pass


if not content and empty_placeholder_detected:
raise self._build_placeholder_error(response_headers)

await log_info(ctx, f"content: {content}", config.debug_enabled)

return content

async def _parse_completion_response(self, response: httpx.Response, ctx=None) -> str:
    """Parse a non-streaming chat/completions HTTP response into plain text.

    Handles three upstream body shapes, in order:
      1. A normal JSON completion body — content is extracted from choices[0].
      2. An SSE-formatted text body ("data: {...}" lines) that some gateways
         return even for stream=False — re-parsed via the streaming parser.
      3. Degenerate bodies: a choices=null placeholder frame or an HTML login
         page, both of which raise ValueError.

    Returns "" only when the response body is empty/whitespace.

    Raises:
        ValueError: placeholder frame, login page, or unparseable body.
    """
    content = ""
    body_text = response.text or ""

    try:
        data = response.json()
    except Exception:
        # Body is not JSON — fall through to the SSE / error heuristics below.
        data = None

    if isinstance(data, dict):
        if self._is_empty_placeholder_payload(data):
            raise self._build_placeholder_error(response.headers)
        choices = data.get("choices", [])
        if isinstance(choices, list) and choices:
            content = self._extract_content_from_choice(choices[0])

    if not content and any(line.lstrip().startswith("data:") for line in body_text.splitlines()):
        # Adapt the raw SSE text to the minimal interface the streaming
        # parser needs (aiter_lines + headers) and reuse it.
        class _LineResponse:
            def __init__(self, text: str, headers):
                self._lines = text.splitlines()
                self.headers = headers

            async def aiter_lines(self):
                for line in self._lines:
                    yield line

        content = await self._parse_streaming_response(_LineResponse(body_text, response.headers), ctx)

    if not content and body_text.strip():
        normalized = body_text.lower()
        # A proxy serving an HTML login page usually means expired auth.
        if "<html" in normalized and "login" in normalized:
            raise ValueError("API 代理返回了登录页面,请检查认证状态")
        raise ValueError("上游返回了无法解析的 completion 响应")

    await log_info(ctx, f"content: {content}", config.debug_enabled)

    return content
Expand All @@ -233,21 +328,38 @@ async def _execute_stream_with_retry(self, headers: dict, payload: dict, ctx=Non
response.raise_for_status()
return await self._parse_streaming_response(response, ctx)

async def _execute_completion_with_retry(self, headers: dict, payload: dict, ctx=None) -> str:
    """Execute a non-streaming HTTP request with retries; the response parser
    tolerates both JSON completion bodies and SSE-formatted text bodies.

    Args:
        headers: HTTP headers for the request (see _build_api_headers).
        payload: chat/completions JSON body (callers set "stream": False).
        ctx: optional logging context, forwarded to the response parser.

    Returns:
        The extracted completion text.

    Raises:
        httpx.HTTPStatusError: non-2xx final response (raise_for_status).
        ValueError: from _parse_completion_response on unparseable bodies.
    """
    # Long read timeout: completions can take a while. pool=None disables
    # the timeout for acquiring a connection from the pool.
    timeout = httpx.Timeout(connect=6.0, read=120.0, write=10.0, pool=None)

    async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
        async for attempt in AsyncRetrying(
            # +1: retry_max_attempts appears to count retries on top of the
            # initial attempt — NOTE(review): confirm against config docs.
            stop=stop_after_attempt(config.retry_max_attempts + 1),
            wait=_WaitWithRetryAfter(config.retry_multiplier, config.retry_max_wait),
            retry=retry_if_exception(_is_retryable_exception),
            reraise=True,
        ):
            with attempt:
                response = await client.post(
                    f"{self.api_url}/chat/completions",
                    headers=headers,
                    json=payload,
                )
                response.raise_for_status()
                return await self._parse_completion_response(response, ctx)

async def describe_url(self, url: str, ctx=None) -> dict:
"""让 Grok 阅读单个 URL 并返回 title + extracts"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
headers = self._build_api_headers()
payload = {
"model": self.model,
"messages": [
{"role": "system", "content": url_describe_prompt},
{"role": "user", "content": url},
],
"stream": True,
"stream": False,
}
result = await self._execute_stream_with_retry(headers, payload, ctx)
result = await self._execute_completion_with_retry(headers, payload, ctx)
title, extracts = url, ""
for line in result.strip().splitlines():
if line.startswith("Title:"):
Expand All @@ -258,19 +370,16 @@ async def describe_url(self, url: str, ctx=None) -> dict:

async def rank_sources(self, query: str, sources_text: str, total: int, ctx=None) -> list[int]:
"""让 Grok 按查询相关度对信源排序,返回排序后的序号列表"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
headers = self._build_api_headers()
payload = {
"model": self.model,
"messages": [
{"role": "system", "content": rank_sources_prompt},
{"role": "user", "content": f"Query: {query}\n\n{sources_text}"},
],
"stream": True,
"stream": False,
}
result = await self._execute_stream_with_retry(headers, payload, ctx)
result = await self._execute_completion_with_retry(headers, payload, ctx)
order: list[int] = []
seen: set[int] = set()
for token in result.strip().split():
Expand Down
Loading