-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy patheval_hooks.py
More file actions
64 lines (47 loc) · 1.77 KB
/
eval_hooks.py
File metadata and controls
64 lines (47 loc) · 1.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
Evaluation hooks and progress reporting for the dev server.
Similar to the JavaScript implementation, this provides callbacks
for reporting progress during evaluation execution.
"""
import asyncio
import json
from collections.abc import Callable
from typing import Any
from ..parameters import ValidatedParameters
class EvalHooks:
    """Progress-reporting hooks handed to eval tasks.

    Holds an optional progress callback and the task's validated parameters.
    When ``parameters`` is falsy (e.g. ``None``), an empty dict is stored.
    """

    def __init__(
        self,
        report_progress: Callable[[dict[str, Any]], None] | None = None,
        parameters: ValidatedParameters | None = None,
    ):
        self._report_progress = report_progress
        self.parameters = parameters if parameters else {}

    def report_progress(self, event: dict[str, Any]) -> None:
        """Forward ``event`` to the progress callback; no-op if none was supplied."""
        callback = self._report_progress
        if not callback:
            return
        callback(event)
def serialize_sse_event(event: str, data: Any) -> str:
    """
    Serialize data into a single Server-Sent Events (SSE) frame.

    Dicts and lists are JSON-encoded; any other value is converted with
    ``str()``. The frame consists of an ``event:`` line, one or more
    ``data:`` lines, and a terminating blank line. This follows the same
    format as the SSEClient expects to parse.

    A payload containing line breaks (LF, CRLF or CR) is split across
    several ``data:`` lines, as the SSE specification requires. A compliant
    client rejoins them with a newline. Emitting the raw line breaks instead
    would make the parser treat every following line as a separate (bogus)
    field, silently dropping that part of the payload.

    Args:
        event: Name written to the ``event:`` field.
        data: Payload to send.

    Returns:
        The SSE-formatted frame, ending with a blank line.
    """
    if isinstance(data, (dict, list)):
        data_str = json.dumps(data)
    else:
        data_str = str(data)
    # SSE treats CRLF, CR and LF alike as line terminators; normalise them to
    # LF so each logical line becomes exactly one ``data:`` field.
    lines = data_str.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    data_fields = "".join(f"data: {line}\n" for line in lines)
    return f"event: {event}\n{data_fields}\n"
class SSEQueue:
"""Simple wrapper around asyncio.Queue for SSE events."""
def __init__(self):
self.queue: asyncio.Queue[str | None] = asyncio.Queue()
async def put_event(self, event: str, data: Any) -> None:
"""Add an SSE event to the queue."""
sse_data = serialize_sse_event(event, data)
await self.queue.put(sse_data)
async def close(self) -> None:
"""Signal end of stream."""
await self.queue.put(None)
async def get(self) -> str | None:
"""Get the next event from the queue."""
return await self.queue.get()