diff --git a/.gitignore b/.gitignore index 9bc885a33..4eed25643 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ src/viser/client/build src/viser/client/.nodeenv **/.claude/settings.local.json +.venv diff --git a/docs/source/index.rst b/docs/source/index.rst index 834275808..17a6dc9a4 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -97,4 +97,4 @@ URL (default: ``http://localhost:8080``). :alt: Version icon :target: https://pypi.org/project/viser/ .. |nbsp| unicode:: 0xA0 - :trim: \ No newline at end of file + :trim: diff --git a/docs/source/scene_handles.rst b/docs/source/scene_handles.rst index 012a87022..c844daf34 100644 --- a/docs/source/scene_handles.rst +++ b/docs/source/scene_handles.rst @@ -64,4 +64,4 @@ methods like :func:`viser.ViserServer.add_frame()` or .. autoclass:: viser.RectAreaLightHandle -.. autoclass:: viser.SpotLightHandle \ No newline at end of file +.. autoclass:: viser.SpotLightHandle diff --git a/examples/assets/download_colmap_garden.sh b/examples/assets/download_colmap_garden.sh index 63f93046a..5f2da5bfd 100755 --- a/examples/assets/download_colmap_garden.sh +++ b/examples/assets/download_colmap_garden.sh @@ -8,4 +8,4 @@ gdown "https://drive.google.com/uc?id=1wYHdrgwXPHtREdCjItvt4gqRQGISMade" mkdir -p colmap_garden # shellcheck disable=SC2035 -unzip *.zip && rm *.zip \ No newline at end of file +unzip *.zip && rm *.zip diff --git a/live_plots.py b/live_plots.py new file mode 100644 index 000000000..4080e38bf --- /dev/null +++ b/live_plots.py @@ -0,0 +1,125 @@ +"""Live Plotly Plots in Viser + +Example of creating live-updating Plotly plots in Viser.""" + +import time + +import numpy as np +import plotly.graph_objects as go + +import viser + +# handle the modal plot DONE +# handle the main plot reanchoring +# handle multiple trajectories +# handles number of elements in history DONE +# handle boundary ylims, xlims +# rename functions + + +def create_wave_plot(t: float, wave_type: str = "sin") -> go.Figure: + """Create a wave plot starting at time t.""" + x_data = np.linspace(t, t + 0.1 * np.pi, 50) + if wave_type == "sin": + y_data = np.sin(60 * x_data) * 10 + title = "Sine Wave" + else: + y_data = np.cos(60 * x_data) * 10 + title = "Cosine Wave" + + fig = go.Figure() + fig.add_trace( + go.Scatter( + x=list(x_data), + y=list(10 + y_data), + mode="lines", + line=dict(color="red", width=2), # Thinner line + fill="tozeroy", + fillcolor="rgba(255, 0, 0, 0.2)", + name=wave_type, + ) + ) + fig.add_trace( + go.Scatter( + x=list(x_data), + y=list(y_data), + mode="lines", + line=dict(color="blue", width=2), # Thinner line + # fill="tozeroy", + # fillcolor="rgba(0, 0, 255, 0.2)", + name=wave_type + "_2", + ) + ) + + fig.update_layout( + title=title, + xaxis_title="x", + yaxis_title=f"{wave_type}(x)", + margin=dict(l=20, r=20, t=40, b=20), + showlegend=False, + # yaxis=dict(range=[-15, 15]), + xaxis=dict(autorange=False), + yaxis=dict(autorange=False), + ) + + return fig + + +def main() -> None: + server = viser.ViserServer() + + Nfull = 40 + Nupdate = 100000 + time_step = 0.1 + Nchunk = 1 + + # Create two plots + time_value = 0.0 + sin_plot_handle = server.gui.add_plotly( + figure=create_wave_plot(time_value, "sin"), aspect=0.75 + ) + cos_plot_handle = server.gui.add_plotly( + figure=create_wave_plot(time_value, "cos"), aspect=0.75 + ) + + # while True: + for i in range(Nfull): + print("i", i, "of", Nfull) + sin_plot_handle.figure = create_wave_plot(time_value, "sin") + cos_plot_handle.figure = create_wave_plot(time_value, "cos") + + time.sleep(time_step) + time_value += time_step + + for i in range(Nupdate): + t0 = time.time() + + x_data = time_value + time_step * np.arange(Nchunk) / Nchunk + x_data = np.tile(x_data, (2, 1)) + y_data = 10 * np.sin(5 * x_data) + np.array( + [5 * np.ones(Nchunk), np.zeros(Nchunk)] + ) + + server.gui.plotly_extend_traces( + plotly_element_uuids=[ + cos_plot_handle._impl.uuid, + sin_plot_handle._impl.uuid, + ], + x_data=x_data, + y_data=y_data, + history_length=10, + ) + + print("cos_plot_handle", cos_plot_handle._impl.uuid) + print("sin_plot_handle", sin_plot_handle._impl.uuid) + t1 = time.time() + elapsed = t1 - t0 + print("elapsed", elapsed) + time.sleep(time_step) + time_value += time_step + + input("Press Enter to continue...") + + +if __name__ == "__main__": + main() diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index e14e0770f..86c0a5a06 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -17,6 +17,7 @@ Sequence, Tuple, TypeVar, + Union, cast, overload, ) @@ -432,9 +433,9 @@ def configure_theme( if brand_color is not None: assert len(brand_color) in (3, 10) if len(brand_color) == 3: - assert all(map(lambda val: isinstance(val, int), brand_color)), ( - "All channels should be integers." - ) + assert all( + map(lambda val: isinstance(val, int), brand_color) + ), "All channels should be integers." # RGB => HLS. h, l, s = colorsys.rgb_to_hls( @@ -709,28 +710,11 @@ def add_image( handle.image = image return handle - def add_plotly( - self, - figure: go.Figure, - aspect: float = 1.0, - order: float | None = None, - visible: bool = True, - ) -> GuiPlotlyHandle: - """Add a Plotly figure to the GUI. Requires the `plotly` package to be - installed. - - Args: - figure: Plotly figure to display. - aspect: Aspect ratio of the plot in the control panel (width/height). - order: Optional ordering, smallest values will be displayed first. - visible: Whether the component is visible. - - Returns: - A handle that can be used to interact with the GUI element. + def setup_plotly_js(self) -> None: + """ + If plotly.min.js hasn't been sent to the client yet, the client won't be able + to render the plot. Send this large file now! (~3MB) """ - - # If plotly.min.js hasn't been sent to the client yet, the client won't be able - # to render the plot. Send this large file now! (~3MB) if not self._setup_plotly_js: # Check if plotly is installed. try: @@ -744,9 +728,9 @@ def add_plotly( plotly_path = ( Path(plotly.__file__).parent / "package_data" / "plotly.min.js" ) - assert plotly_path.exists(), ( - f"Could not find plotly.min.js at {plotly_path}." - ) + assert ( + plotly_path.exists() + ), f"Could not find plotly.min.js at {plotly_path}." # Send it over! plotly_js = plotly_path.read_text(encoding="utf-8") @@ -757,7 +741,29 @@ def add_plotly( # Update the flag so we don't send it again. self._setup_plotly_js = True + def add_plotly( + self, + figure: go.Figure, + aspect: float = 1.0, + order: float | None = None, + visible: bool = True, + ) -> GuiPlotlyHandle: + """Add a Plotly figure to the GUI. Requires the `plotly` package to be + installed. + + Args: + figure: Plotly figure to display. + aspect: Aspect ratio of the plot in the control panel (width/height). + order: Optional ordering, smallest values will be displayed first. + visible: Whether the component is visible. + + Returns: + A handle that can be used to interact with the GUI element. + """ + + self.setup_plotly_js() # After plotly.min.js has been sent, we can send the plotly figure. + # Empty string for `plotly_json_str` is a signal to the client to render nothing. message = _messages.GuiPlotlyMessage( uuid=_make_uuid(), @@ -785,8 +791,41 @@ def add_plotly( # Set the plotly handle properties. handle.figure = figure handle.aspect = aspect + return handle + def plotly_extend_traces( + self, + plotly_element_uuids: list[str], + x_data: list[float] | list[list[float]] | np.ndarray, + y_data: list[float] | list[list[float]] | np.ndarray, + history_length: int, + ) -> None: + """Extend traces in a plotly plot with new data. + + Args: + plotly_element_uuids: UUIDs of the plotly elements to update + x_data: X-axis data. Can be a 1D list/array for single trace or 2D list/array for multiple traces + y_data: Y-axis data. Can be a 1D list/array for single trace or 2D list/array for multiple traces + history_length: Number of points to keep in the history + """ + # Create a unique message for each update + message = _messages.GuiPlotlyExtendTracesMessage( + # uuid=_make_uuid(), + container_uuid=self._get_container_uuid(), + props=_messages.GuiPlotlyExtendTracesProps( + plotly_element_uuids=plotly_element_uuids, + x_data=self.to_list_of_lists(x_data), + y_data=self.to_list_of_lists(y_data), + history_length=history_length, + ), + ) + # Ensure the message is queued with a unique key + # message.redundancy_key = lambda: f"plotly-extend-{plotly_element_uuid}-{id(message)}" + print("redundancy_key", message.redundancy_key()) + self._websock_interface.queue_message(message) + # self._websock_interface.flush() + def add_button( self, label: str, @@ -1649,3 +1688,23 @@ def sync_other_clients( handle_state.sync_cb = sync_other_clients return handle_state + + def to_list_of_lists(self, x: Union[Sequence, np.ndarray]) -> list[list[float]]: + """Convert input to a list of list of floats.""" + if isinstance(x, np.ndarray): + arr = x.astype(float) + if arr.ndim == 1: + return [arr.tolist()] # Wrap 1D array + elif arr.ndim == 2: + return arr.tolist() + else: + raise ValueError(f"Unsupported ndarray with ndim={arr.ndim}") + elif isinstance(x, (list, tuple)): + if all(isinstance(el, (int, float)) for el in x): + return [[float(val) for val in x]] # 1D list + elif all(isinstance(el, (list, tuple)) for el in x): + return [[float(val) for val in row] for row in x] # 2D list + else: + raise ValueError("List must contain only numbers or lists of numbers.") + else: + raise TypeError(f"Unsupported input type: {type(x)}") diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index ce1843cfa..b3b640c43 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -779,6 +779,7 @@ def figure(self, figure: go.Figure) -> None: json_str = figure.to_json() assert isinstance(json_str, str) + # print(json_str) self._plotly_json_str = json_str diff --git a/src/viser/_messages.py b/src/viser/_messages.py index f85fff48a..626f8474e 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -1011,6 +1011,25 @@ class GuiPlotlyMessage(_CreateGuiComponentMessage): props: GuiPlotlyProps +@dataclasses.dataclass +class GuiPlotlyExtendTracesProps: + plotly_element_uuids: list[str] + """UUIDs of the plotly elements to update.""" + x_data: list[float] + """List of x-data points for each trace.""" + y_data: list[float] + """List of y-data points for each trace.""" + history_length: int + """History length for the plot.""" + + +@dataclasses.dataclass +class GuiPlotlyExtendTracesMessage(Message): + # uuid: str + container_uuid: str + props: GuiPlotlyExtendTracesProps + + @dataclasses.dataclass class GuiImageProps: order: float diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index 7b8e3d417..76b54442a 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -108,9 +108,9 @@ def _encode_image_binary( def cast_vector(vector: TVector | np.ndarray, length: int) -> TVector: if not isinstance(vector, tuple): - assert cast(np.ndarray, vector).shape == (length,), ( - f"Expected vector of shape {(length,)}, but got {vector.shape} instead" - ) + assert cast(np.ndarray, vector).shape == ( + length, + ), f"Expected vector of shape {(length,)}, but got {vector.shape} instead" return cast(TVector, tuple(map(float, vector))) @@ -1107,9 +1107,9 @@ def add_point_cloud( Handle for manipulating scene node. """ colors_cast = colors_to_uint8(np.asarray(colors)) - assert len(points.shape) == 2 and points.shape[-1] == 3, ( - "Shape of points should be (N, 3)." - ) + assert ( + len(points.shape) == 2 and points.shape[-1] == 3 + ), "Shape of points should be (N, 3)." assert colors_cast.shape in { points.shape, (3,), diff --git a/src/viser/client/public/logo.svg b/src/viser/client/public/logo.svg index 7a6bfe30e..f98c7bdd9 100644 --- a/src/viser/client/public/logo.svg +++ b/src/viser/client/public/logo.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/src/viser/client/src/ControlPanel/Generated.tsx b/src/viser/client/src/ControlPanel/Generated.tsx index 6ae486928..4459879c9 100644 --- a/src/viser/client/src/ControlPanel/Generated.tsx +++ b/src/viser/client/src/ControlPanel/Generated.tsx @@ -16,7 +16,7 @@ import RgbComponent from "../components/Rgb"; import RgbaComponent from "../components/Rgba"; import ButtonGroupComponent from "../components/ButtonGroup"; import MarkdownComponent from "../components/Markdown"; -import PlotlyComponent from "../components/PlotlyComponent"; +import PlotlyComponent, { PlotlyExtendTracesComponent } from "../components/PlotlyComponent"; import TabGroupComponent from "../components/TabGroup"; import FolderComponent from "../components/Folder"; import MultiSliderComponent from "../components/MultiSlider"; @@ -95,6 +95,12 @@ function GeneratedInput(props: { guiUuid: string }) { console.error("Tried to render non-existent component", props.guiUuid); return null; } + + // console.log("%c[GeneratedInput] Rendering component", "color: #2196F3; font-weight: bold", { + // type: conf.type, + // uuid: props.guiUuid + // }); + switch (conf.type) { case "GuiFolderMessage": return ; @@ -106,6 +112,8 @@ function GeneratedInput(props: { guiUuid: string }) { return ; case "GuiPlotlyMessage": return ; + case "GuiPlotlyExtendTracesMessage": + return ; case "GuiImageMessage": return ; case "GuiButtonMessage": diff --git a/src/viser/client/src/MessageHandler.tsx b/src/viser/client/src/MessageHandler.tsx index 7108a957b..fa4a877bb 100644 --- a/src/viser/client/src/MessageHandler.tsx +++ b/src/viser/client/src/MessageHandler.tsx @@ -416,6 +416,7 @@ function useMessageHandler(): (message: Message) => void { updateGuiProps(message.uuid, message.updates); return; } + // Remove a GUI input. case "GuiRemoveMessage": { removeGui(message.uuid); diff --git a/src/viser/client/src/Splatting/WasmSorter/Sorter.mjs b/src/viser/client/src/Splatting/WasmSorter/Sorter.mjs index a2a2dfcf4..72dcc7d3c 100644 --- a/src/viser/client/src/Splatting/WasmSorter/Sorter.mjs +++ b/src/viser/client/src/Splatting/WasmSorter/Sorter.mjs @@ -1,7 +1,7 @@ var Module = (() => { var _scriptName = import.meta.url; - + return ( async function(moduleArg = {}) { var moduleRtn; diff --git a/src/viser/client/src/WebsocketMessages.ts b/src/viser/client/src/WebsocketMessages.ts index 2735087d7..7c485a6d4 100644 --- a/src/viser/client/src/WebsocketMessages.ts +++ b/src/viser/client/src/WebsocketMessages.ts @@ -474,6 +474,22 @@ export interface GuiPlotlyMessage { visible: boolean; }; } +/** GuiPlotlyExtendTracesMessage(uuid: 'str', container_uuid: 'str', props: 'GuiPlotlyExtendTracesProps') + * + * (automatically generated) + */ +export interface GuiPlotlyExtendTracesMessage { + type: "GuiPlotlyExtendTracesMessage"; + // uuid: string; + container_uuid: string; + props: { + plotly_element_uuids: string[]; + x_data: number[][]; + y_data: number[][]; + history_length: number; + }; +} + /** GuiImageMessage(uuid: 'str', container_uuid: 'str', props: 'GuiImageProps') * * (automatically generated) @@ -1292,6 +1308,7 @@ export type Message = | GuiHtmlMessage | GuiProgressBarMessage | GuiPlotlyMessage + | GuiPlotlyExtendTracesMessage | GuiImageMessage | GuiTabGroupMessage | GuiButtonMessage @@ -1378,6 +1395,7 @@ export type GuiComponentMessage = | GuiHtmlMessage | GuiProgressBarMessage | GuiPlotlyMessage + | GuiPlotlyExtendTracesMessage | GuiImageMessage | GuiTabGroupMessage | GuiButtonMessage @@ -1430,6 +1448,7 @@ const typeSetGuiComponentMessage = new Set([ "GuiHtmlMessage", "GuiProgressBarMessage", "GuiPlotlyMessage", + "GuiPlotlyExtendTracesMessage", "GuiImageMessage", "GuiTabGroupMessage", "GuiButtonMessage", diff --git a/src/viser/client/src/components/PlotlyComponent.tsx b/src/viser/client/src/components/PlotlyComponent.tsx index f525e70d2..959a75432 100644 --- a/src/viser/client/src/components/PlotlyComponent.tsx +++ b/src/viser/client/src/components/PlotlyComponent.tsx @@ -1,5 +1,5 @@ import React from "react"; -import { GuiPlotlyMessage } from "../WebsocketMessages"; +import { GuiPlotlyMessage, GuiPlotlyExtendTracesMessage } from "../WebsocketMessages"; import { useDisclosure } from "@mantine/hooks"; import { Modal, Box, Paper, Tooltip } from "@mantine/core"; import { useElementSize } from "@mantine/hooks"; @@ -7,29 +7,33 @@ import { useElementSize } from "@mantine/hooks"; // When drawing border around the plot, it should be aligned with the folder's. import { folderWrapper } from "./Folder.css"; + const PlotWithAspect = React.memo(function PlotWithAspect({ - jsonStr, + plotJson, aspectRatio, staticPlot, + uuid, }: { - jsonStr: string; + plotJson: any; aspectRatio: number; staticPlot: boolean; + uuid: string; }) { - // Catch if the jsonStr is empty; if so, render an empty div. - if (jsonStr === "") return
; - - // Parse json string, to construct plotly object. - // Note that only the JSON string is kept as state, not the json object. - const plotJson = JSON.parse(jsonStr); + // Catch if the plotJson is empty; if so, render an empty div. + if (!plotJson) return
; // This keeps the zoom-in state, etc, see https://plotly.com/javascript/uirevision/. plotJson.layout.uirevision = "true"; // Box size change -> width value change -> plot rerender trigger. const { ref, width } = useElementSize(); - plotJson.layout.width = width; - plotJson.layout.height = width * aspectRatio; + const plotWidth = width || 1; // Fallback to 1 if width is 0, the main plot's elementSize is 0. + plotJson.layout.width = plotWidth; + plotJson.layout.height = plotWidth * aspectRatio; + // console.warn("plotWidth", plotWidth); + // console.warn("plotJson.layout.width", plotJson.layout.width); + // console.warn("plotJson.layout.height", plotJson.layout.height); + // console.warn("aspectRatio", aspectRatio); // Make the plot non-interactable, if specified. // Ideally, we would use `staticplot`, but this has a known bug with 3D plots: @@ -46,7 +50,13 @@ const PlotWithAspect = React.memo(function PlotWithAspect({ // Use React hooks to update the plotly object, when the plot data changes. // based on https://github.com/plotly/react-plotly.js/issues/242. const plotRef = React.useRef(null); + React.useEffect(() => { + // Set the ID of the plot element + if (plotRef.current) { + plotRef.current.id = uuid; + } + // @ts-ignore - Plotly.js is dynamically imported with an eval() call. Plotly.react( plotRef.current!, @@ -54,7 +64,7 @@ const PlotWithAspect = React.memo(function PlotWithAspect({ plotJson.layout, plotJson.config, ); - }, [plotJson]); + }, [plotJson, uuid, plotWidth]); // Re-render when plot data or width changes return ( -
+
{/* Add a div on top of the plot, to prevent interaction + cursor changes. */} {staticPlot ? (
JSON.parse(plotly_json_str)); + + // Make a copy of the plotJson for modal plot + const [modalPlotJson, setModalPlotJson] = React.useState(() => JSON.parse(plotly_json_str)); + + // Update plot data when new JSON string comes in + React.useEffect(() => { + setPlotJson(JSON.parse(plotly_json_str)); + setModalPlotJson(JSON.parse(plotly_json_str)); + }, [plotly_json_str]); + // Create a modal with the plot, and a button to open it. const [opened, { open, close }] = useDisclosure(false); - return ( + const ddd = ( - {/* Draw static plot in the controlpanel, which can be clicked. */} + uuid={uuid} + /> - {/* Modal contents. keepMounted makes state changes (eg zoom) to the plot - persistent. */} + uuid={`${uuid}-modal`} + /> ); + const t1 = performance.now(); + console.warn("PlotlyComponent time:", t1 - t0); + return ddd; +} + +// Component for handling plot updates +export function PlotlyExtendTracesComponent({ + props: { plotly_element_uuids, x_data, y_data, history_length }, +}: GuiPlotlyExtendTracesMessage) { + // Use React hooks to update the plotly object when new data arrives + React.useEffect(() => { + const t0 = performance.now(); + + // Batch all plot elements first + const t1 = performance.now(); + const plotElements = plotly_element_uuids.flatMap(uuid => { + const main = document.getElementById(uuid); + const modal = document.getElementById(`${uuid}-modal`); + return [main, modal].filter(Boolean); + }); + const t2 = performance.now(); + console.warn("DOM query time:", t2 - t1); + + if (plotElements.length === 0) { + console.warn("Could not find any plot elements with UUIDs:", plotly_element_uuids); + return; + } + + try { + // @ts-ignore - Plotly.js is dynamically imported + const Plotly = (window as any).Plotly; + const t3 = performance.now(); + + // Pre-compute trace indices once + const traceIndices = Array.from({ length: x_data.length }, (_, i) => i); + + // Update each plot with minimal data processing + plotElements.forEach((element, index) => { + const t4 = performance.now(); + // Use a more direct update approach + Plotly.extendTraces( + element, + { x: x_data, y: y_data }, + traceIndices, + history_length, + { mode: 'lines' } // Optimize for line plots + ); + const t5 = performance.now(); + console.warn(`Plot ${index} update time:`, t5 - t4); + }); + + const t6 = performance.now(); + console.warn("Total Plotly updates time:", t6 - t3); + } catch (error) { + console.error("Error updating plots:", error); + } + + const t7 = performance.now(); + console.warn("Total execution time:", t7 - t0); + + }, [plotly_element_uuids, x_data, y_data, history_length]); + + // This component doesn't render anything visible + return null; } diff --git a/src/viser/client/src/components/plotWorker.ts b/src/viser/client/src/components/plotWorker.ts new file mode 100644 index 000000000..58281c5a3 --- /dev/null +++ b/src/viser/client/src/components/plotWorker.ts @@ -0,0 +1,13 @@ +// Worker for processing plot data +self.onmessage = (e) => { + const { x_data, y_data, history_length } = e.data; + + // Process data in worker thread + const processedData = { + x: x_data, + y: y_data + }; + + // Send processed data back to main thread + self.postMessage({ processedData }); +}; diff --git a/src/viser/client/src/csm/CSM.d.ts b/src/viser/client/src/csm/CSM.d.ts index 2373bffc4..dae166ed7 100644 --- a/src/viser/client/src/csm/CSM.d.ts +++ b/src/viser/client/src/csm/CSM.d.ts @@ -32,12 +32,12 @@ export class CSM { customSplitsCallback?: (cascades: number, near: number, far: number, breaks: number[]) => void; fade: boolean; lights: DirectionalLight[]; - + constructor(data: CSMParameters); - + update(): void; updateFrustums(): void; remove(): void; dispose(): void; setupMaterial(material: Material): void; -} \ No newline at end of file +} diff --git a/src/viser/infra/_async_message_buffer.py b/src/viser/infra/_async_message_buffer.py index 6d4058858..67ad82142 100644 --- a/src/viser/infra/_async_message_buffer.py +++ b/src/viser/infra/_async_message_buffer.py @@ -155,3 +155,68 @@ async def window_generator( if flush_wait in done and not self.done: self.flush_event.clear() flush_wait = self.event_loop.create_task(self.flush_event.wait()) + + # async def window_generator( + # self, client_id: int + # ) -> AsyncGenerator[Sequence[Message], None]: + # """Async iterator over messages. Loops infinitely, and waits when no messages + # are available.""" + + # last_sent_id = -1 # ID of the last message we've sent + # flush_wait = self.event_loop.create_task(self.flush_event.wait()) + + # while not self.done: + # window: List[Message] = [] + + # with self.buffer_lock: + # # Get all message IDs > last_sent_id, sorted in order + # next_ids = sorted( + # msg_id for msg_id in self.message_from_id.keys() + # if msg_id > last_sent_id + # ) + + # for msg_id in next_ids: + # # Don't send anything while atomic block is active + # if self.atomic_counter > 0: + # break + + # with self.buffer_lock: + # message = ( + # self.message_from_id.get(msg_id) + # if self.persistent_messages + # else self.message_from_id.pop(msg_id, None) + # ) + + # if message is not None and not self.persistent_messages: + # # Clean up redundancy tracking if needed + # redundancy_key = message.redundancy_key() + # self.id_from_redundancy_key.pop(redundancy_key, None) + + # if message is not None and message.excluded_self_client != client_id: + # window.append(message) + # last_sent_id = msg_id + + # if len(window) >= self.max_window_size: + # break + + # if len(window) > 0: + # yield window + # else: + # # Wait for either new message or flush trigger + # await self.message_event.wait() + # self.message_event.clear() + + # # Optional delay logic (rate limiting) + # with self.buffer_lock: + # pending_ids = [ + # msg_id for msg_id in self.message_from_id.keys() + # if msg_id > last_sent_id + # ] + # if not pending_ids: + # done, pending = await asyncio.wait( + # [flush_wait], timeout=self.window_duration_sec + # ) + # del pending + # if flush_wait in done and not self.done: + # self.flush_event.clear() + # flush_wait = self.event_loop.create_task(self.flush_event.wait()) diff --git a/src/viser/infra/_infra.py b/src/viser/infra/_infra.py index b2a861ba7..39379e64a 100644 --- a/src/viser/infra/_infra.py +++ b/src/viser/infra/_infra.py @@ -68,9 +68,9 @@ def _insert_message(self, message: Message) -> None: def insert_sleep(self, duration: float) -> None: """Insert a sleep into the recorded file. This can be useful for dynamic 3D data.""" - assert self._handler._record_handle is not None, ( - "serialize() was already called!" - ) + assert ( + self._handler._record_handle is not None + ), "serialize() was already called!" self._time += duration def serialize(self) -> bytes: @@ -81,9 +81,9 @@ def serialize(self) -> bytes: Returns: The recording as bytes. """ - assert self._handler._record_handle is not None, ( - "serialize() was already called!" - ) + assert ( + self._handler._record_handle is not None + ), "serialize() was already called!" packed_bytes = msgspec.msgpack.encode( { @@ -135,9 +135,9 @@ def unregister_handler( callback: Callable[[ClientId, TMessage], None | Coroutine] | None = None, ): """Unregister a handler for a particular message type.""" - assert message_cls in self._incoming_handlers, ( - "Tried to unregister a handler that hasn't been registered." - ) + assert ( + message_cls in self._incoming_handlers + ), "Tried to unregister a handler that hasn't been registered." if callback is None: self._incoming_handlers.pop(message_cls) else: