From e5d525411ea4abf2d1980422c20b98072579a11e Mon Sep 17 00:00:00 2001 From: Chris Feijoo Date: Sat, 14 Mar 2026 20:33:08 +0100 Subject: [PATCH] =?UTF-8?q?FE-514:=20Add=20SymPy=20=E2=86=92=20JavaScript?= =?UTF-8?q?=20compilation=20pipeline=20via=20Pyodide?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace direct JS compilation with a SymPy intermediate representation: user TypeScript is compiled to SymPy Python, evaluated/simplified by real SymPy (via Pyodide WASM), then converted back to JavaScript for execution. - Add pyodide dependency and singleton manager for loading SymPy in worker - Create Python codegen module with custom JS printer for SymPy → JS - Create compile-via-sympy.ts orchestrating the full pipeline - Extend compile-to-sympy.ts with symbol tracking and Dynamics per-token body emission (no list comprehension for tokens.map) - Make buildSimulation async, using SymPy compilation pipeline - Add "compiling" progress messages and pyodideUrl support to worker - Update all simulation tests with vi.mock for async SymPy compilation Co-Authored-By: Claude Opus 4.6 --- libs/@hashintel/petrinaut/package.json | 1 + .../simulator/build-simulation.test.ts | 39 ++- .../simulation/simulator/build-simulation.ts | 43 ++- .../simulator/compile-to-sympy.test.ts | 145 +++++++++- .../simulation/simulator/compile-to-sympy.ts | 201 +++++++++++++- .../simulation/simulator/compile-via-sympy.ts | 262 ++++++++++++++++++ .../simulator/compute-next-frame.test.ts | 103 +++++-- .../simulation/simulator/pyodide-manager.ts | 58 ++++ .../src/simulation/simulator/sympy-codegen.ts | 163 +++++++++++ .../src/simulation/worker/messages.ts | 14 +- .../worker/simulation.worker.test.ts | 54 +++- .../simulation/worker/simulation.worker.ts | 91 +++--- .../worker/use-simulation-worker.ts | 12 + .../petrinaut/src/state/editor-provider.tsx | 2 +- libs/@hashintel/petrinaut/vite.config.ts | 1 + yarn.lock | 12 +- 16 files changed, 1073 insertions(+), 128 deletions(-) create mode 100644 libs/@hashintel/petrinaut/src/simulation/simulator/compile-via-sympy.ts create mode 100644 libs/@hashintel/petrinaut/src/simulation/simulator/pyodide-manager.ts create mode 100644 libs/@hashintel/petrinaut/src/simulation/simulator/sympy-codegen.ts diff --git a/libs/@hashintel/petrinaut/package.json b/libs/@hashintel/petrinaut/package.json index b0fc72b48e9..ee54a3501da 100644 --- a/libs/@hashintel/petrinaut/package.json +++ b/libs/@hashintel/petrinaut/package.json @@ -51,6 +51,7 @@ "d3-scale": "4.0.2", "elkjs": "0.11.0", "monaco-editor": "0.55.1", + "pyodide": "0.27.7", "react-icons": "5.5.0", "react-resizable-panels": "4.6.5", "typescript": "5.9.3", diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/build-simulation.test.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/build-simulation.test.ts index b06c6fbb6e3..a824c59bcf1 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/build-simulation.test.ts +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/build-simulation.test.ts @@ -1,10 +1,31 @@ -import { describe, expect, it } from "vitest"; +import { describe, expect, it, vi } from "vitest"; import { buildSimulation } from "./build-simulation"; import type { SimulationInput } from "./types"; +// Mock the SymPy compilation to return simple passthrough functions. +// This lets us test buildSimulation's validation and buffer logic +// without requiring a real Pyodide instance. +vi.mock("./compile-via-sympy", () => ({ + compileDifferentialEquationViaSymPy: vi.fn().mockResolvedValue( + // Returns a no-op dynamics function + () => [], + ), + compileLambdaViaSymPy: vi.fn().mockResolvedValue( + // Returns a constant rate + () => 1.0, + ), + compileTransitionKernelViaSymPy: vi.fn().mockResolvedValue( + // Returns empty kernel output + () => ({}), + ), +})); + +// Create a mock PyodideInterface +const mockPyodide = {} as Parameters[1]; + describe("buildSimulation", () => { - it("builds a simulation with a single place and initial tokens", () => { + it("builds a simulation with a single place and initial tokens", async () => { const input: SimulationInput = { sdcpn: { types: [ @@ -56,7 +77,7 @@ describe("buildSimulation", () => { maxTime: null, }; - const simulationInstance = buildSimulation(input); + const simulationInstance = await buildSimulation(input, mockPyodide); const frame = simulationInstance.frames[0]!; // Verify simulation instance properties @@ -88,7 +109,7 @@ describe("buildSimulation", () => { expect(simulationInstance.differentialEquationFns.has("p1")).toBe(true); }); - it("builds a simulation with multiple places, transitions, and proper buffer layout", () => { + it("builds a simulation with multiple places, transitions, and proper buffer layout", async () => { const input: SimulationInput = { sdcpn: { types: [ @@ -204,7 +225,7 @@ describe("buildSimulation", () => { maxTime: null, }; - const simulationInstance = buildSimulation(input); + const simulationInstance = await buildSimulation(input, mockPyodide); const frame = simulationInstance.frames[0]!; // Verify simulation instance properties @@ -266,7 +287,7 @@ describe("buildSimulation", () => { expect(typeof kernelFn).toBe("function"); }); - it("throws error when initialMarking references non-existent place", () => { + it("throws error when initialMarking references non-existent place", async () => { const input: SimulationInput = { sdcpn: { types: [ @@ -315,12 +336,12 @@ describe("buildSimulation", () => { maxTime: null, }; - expect(() => buildSimulation(input)).toThrow( + await expect(buildSimulation(input, mockPyodide)).rejects.toThrow( "Place with ID p_nonexistent in initialMarking does not exist in SDCPN", ); }); - it("throws error when token dimensions don't match place dimensions", () => { + it("throws error when token dimensions don't match place dimensions", async () => { const input: SimulationInput = { sdcpn: { types: [ @@ -372,7 +393,7 @@ describe("buildSimulation", () => { maxTime: null, }; - expect(() => buildSimulation(input)).toThrow( + await expect(buildSimulation(input, mockPyodide)).rejects.toThrow( "Token dimension mismatch for place p1. Expected 4 values (2 dimensions × 2 tokens), got 3", ); }); diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/build-simulation.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/build-simulation.ts index e6d9b7bd969..8f906d94e56 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/build-simulation.ts +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/build-simulation.ts @@ -1,13 +1,18 @@ +import type { PyodideInterface } from "pyodide"; + import { SDCPNItemError } from "../../core/errors"; import { deriveDefaultParameterValues, mergeParameterValues, } from "../../hooks/use-default-parameter-values"; -import { compileUserCode } from "./compile-user-code"; +import { + compileDifferentialEquationViaSymPy, + compileLambdaViaSymPy, + compileTransitionKernelViaSymPy, +} from "./compile-via-sympy"; import type { DifferentialEquationFn, LambdaFn, - ParameterValues, SimulationFrame, SimulationInput, SimulationInstance, @@ -49,12 +54,16 @@ function getPlaceDimensions( * - All places and transitions initialized with proper state * * @param input - The simulation input configuration + * @param pyodide - Initialized Pyodide instance with SymPy * @returns The initial simulation frame ready for execution * @throws {Error} if place IDs in initialMarking don't match places in SDCPN * @throws {Error} if token dimensions don't match place dimensions * @throws {Error} if user code fails to compile */ -export function buildSimulation(input: SimulationInput): SimulationInstance { +export async function buildSimulation( + input: SimulationInput, + pyodide: PyodideInterface, +): Promise { const { sdcpn, initialMarking, @@ -100,7 +109,7 @@ export function buildSimulation(input: SimulationInput): SimulationInstance { } } - // Compile all differential equation functions + // Compile all differential equation functions via SymPy const differentialEquationFns = new Map(); for (const place of sdcpn.places) { // Skip places without dynamics enabled or without differential equation code @@ -119,9 +128,11 @@ export function buildSimulation(input: SimulationInput): SimulationInstance { const { code } = differentialEquation; try { - const fn = compileUserCode<[Record[], ParameterValues]>( + const fn = await compileDifferentialEquationViaSymPy( code, - "Dynamics", + sdcpn, + place.colorId!, + pyodide, ); differentialEquationFns.set(place.id, fn as DifferentialEquationFn); } catch (error) { @@ -134,13 +145,16 @@ export function buildSimulation(input: SimulationInput): SimulationInstance { } } - // Compile all lambda functions + // Compile all lambda functions via SymPy const lambdaFns = new Map(); for (const transition of sdcpn.transitions) { try { - const fn = compileUserCode< - [Record[]>, ParameterValues] - >(transition.lambdaCode, "Lambda"); + const fn = await compileLambdaViaSymPy( + transition.lambdaCode, + sdcpn, + transition, + pyodide, + ); lambdaFns.set(transition.id, fn as LambdaFn); } catch (error) { throw new SDCPNItemError( @@ -173,9 +187,12 @@ export function buildSimulation(input: SimulationInput): SimulationInstance { } try { - const fn = compileUserCode< - [Record[]>, ParameterValues] - >(transition.transitionKernelCode, "TransitionKernel"); + const fn = await compileTransitionKernelViaSymPy( + transition.transitionKernelCode, + sdcpn, + transition, + pyodide, + ); transitionKernelFns.set(transition.id, fn as TransitionKernelFn); } catch (error) { throw new SDCPNItemError( diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts index 0ae9e4eef5d..b0b7c36c856 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts @@ -37,7 +37,7 @@ describe("compileToSymPy", () => { "export default Lambda(() => 1)", defaultContext, ); - expect(result).toEqual({ ok: true, sympyCode: "1" }); + expect(result).toEqual({ ok: true, sympyCode: "1", symbols: [] }); }); it("should compile a decimal literal", () => { @@ -45,7 +45,7 @@ describe("compileToSymPy", () => { "export default Lambda(() => 3.14)", defaultContext, ); - expect(result).toEqual({ ok: true, sympyCode: "3.14" }); + expect(result).toEqual({ ok: true, sympyCode: "3.14", symbols: [] }); }); it("should compile boolean true", () => { @@ -53,7 +53,7 @@ describe("compileToSymPy", () => { "export default Lambda(() => true)", defaultContext, ); - expect(result).toEqual({ ok: true, sympyCode: "True" }); + expect(result).toEqual({ ok: true, sympyCode: "True", symbols: [] }); }); it("should compile boolean false", () => { @@ -61,7 +61,7 @@ describe("compileToSymPy", () => { "export default Lambda(() => false)", defaultContext, ); - expect(result).toEqual({ ok: true, sympyCode: "False" }); + expect(result).toEqual({ ok: true, sympyCode: "False", symbols: [] }); }); it("should compile Infinity", () => { @@ -69,7 +69,7 @@ describe("compileToSymPy", () => { "export default Lambda(() => Infinity)", defaultContext, ); - expect(result).toEqual({ ok: true, sympyCode: "sp.oo" }); + expect(result).toEqual({ ok: true, sympyCode: "sp.oo", symbols: [] }); }); }); @@ -79,7 +79,11 @@ describe("compileToSymPy", () => { "export default Lambda((tokens, parameters) => parameters.infection_rate)", defaultContext, ); - expect(result).toEqual({ ok: true, sympyCode: "infection_rate" }); + expect(result).toEqual({ + ok: true, + sympyCode: "infection_rate", + symbols: ["infection_rate"], + }); }); it("should compile parameters in arithmetic", () => { @@ -90,6 +94,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "infection_rate * 2", + symbols: ["infection_rate"], }); }); }); @@ -100,7 +105,7 @@ describe("compileToSymPy", () => { "export default Lambda(() => 1 + 2)", defaultContext, ); - expect(result).toEqual({ ok: true, sympyCode: "1 + 2" }); + expect(result).toEqual({ ok: true, sympyCode: "1 + 2", symbols: [] }); }); it("should compile subtraction", () => { @@ -108,7 +113,7 @@ describe("compileToSymPy", () => { "export default Lambda(() => 5 - 3)", defaultContext, ); - expect(result).toEqual({ ok: true, sympyCode: "5 - 3" }); + expect(result).toEqual({ ok: true, sympyCode: "5 - 3", symbols: [] }); }); it("should compile multiplication", () => { @@ -116,7 +121,7 @@ describe("compileToSymPy", () => { "export default Lambda(() => 2 * 3)", defaultContext, ); - expect(result).toEqual({ ok: true, sympyCode: "2 * 3" }); + expect(result).toEqual({ ok: true, sympyCode: "2 * 3", symbols: [] }); }); it("should compile division", () => { @@ -124,7 +129,7 @@ describe("compileToSymPy", () => { "export default Lambda(() => 1 / 3)", defaultContext, ); - expect(result).toEqual({ ok: true, sympyCode: "1 / 3" }); + expect(result).toEqual({ ok: true, sympyCode: "1 / 3", symbols: [] }); }); it("should compile power operator", () => { @@ -135,6 +140,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "satellite_radius**2", + symbols: ["satellite_radius"], }); }); @@ -146,6 +152,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "sp.Mod(10, 3)", + symbols: [], }); }); }); @@ -159,6 +166,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "infection_rate < 5", + symbols: ["infection_rate"], }); }); @@ -170,6 +178,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "infection_rate >= 1", + symbols: ["infection_rate"], }); }); @@ -181,6 +190,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "sp.Eq(infection_rate, 3)", + symbols: ["infection_rate"], }); }); @@ -192,6 +202,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "sp.Ne(infection_rate, 0)", + symbols: ["infection_rate"], }); }); }); @@ -205,6 +216,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "sp.And(infection_rate > 0, recovery_rate > 0)", + symbols: expect.arrayContaining(["infection_rate", "recovery_rate"]), }); }); @@ -216,6 +228,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "sp.Or(sp.Eq(infection_rate, 0), sp.Eq(recovery_rate, 0))", + symbols: expect.arrayContaining(["infection_rate", "recovery_rate"]), }); }); }); @@ -229,6 +242,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "-(infection_rate)", + symbols: ["infection_rate"], }); }); @@ -240,6 +254,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "sp.Not(True)", + symbols: [], }); }); }); @@ -253,6 +268,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "sp.cos(infection_rate)", + symbols: ["infection_rate"], }); }); @@ -264,6 +280,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "sp.sin(infection_rate)", + symbols: ["infection_rate"], }); }); @@ -275,6 +292,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "sp.sqrt(infection_rate)", + symbols: ["infection_rate"], }); }); @@ -286,6 +304,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "sp.log(infection_rate)", + symbols: ["infection_rate"], }); }); @@ -297,6 +316,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "sp.exp(infection_rate)", + symbols: ["infection_rate"], }); }); @@ -308,6 +328,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "sp.Abs(infection_rate)", + symbols: ["infection_rate"], }); }); @@ -319,6 +340,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "(infection_rate)**(2)", + symbols: ["infection_rate"], }); }); @@ -330,6 +352,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "sp.sqrt((infection_rate)**2 + (recovery_rate)**2)", + symbols: expect.arrayContaining(["infection_rate", "recovery_rate"]), }); }); @@ -338,7 +361,7 @@ describe("compileToSymPy", () => { "export default Lambda(() => Math.PI)", defaultContext, ); - expect(result).toEqual({ ok: true, sympyCode: "sp.pi" }); + expect(result).toEqual({ ok: true, sympyCode: "sp.pi", symbols: [] }); }); it("should compile Math.E", () => { @@ -346,7 +369,7 @@ describe("compileToSymPy", () => { "export default Lambda(() => Math.E)", defaultContext, ); - expect(result).toEqual({ ok: true, sympyCode: "sp.E" }); + expect(result).toEqual({ ok: true, sympyCode: "sp.E", symbols: [] }); }); }); @@ -359,6 +382,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "Space_0_x", + symbols: ["Space_0_x"], }); }); @@ -370,6 +394,10 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "Space_0_velocity < crash_threshold", + symbols: expect.arrayContaining([ + "Space_0_velocity", + "crash_threshold", + ]), }); }); }); @@ -384,6 +412,7 @@ describe("compileToSymPy", () => { ok: true, sympyCode: "sp.Piecewise((infection_rate, infection_rate > 1), (0, True))", + symbols: ["infection_rate"], }); }); }); @@ -397,6 +426,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "sp.stats.Normal('X', 0, 1)", + symbols: [], }); }); @@ -408,6 +438,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "sp.stats.Uniform('X', 0, 1)", + symbols: [], }); }); @@ -419,8 +450,35 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "sp.stats.LogNormal('X', 0, 1)", + symbols: [], }); }); + + it("should compile Distribution.Gaussian.map as direct arithmetic", () => { + const result = compileToSymPy( + "export default Lambda(() => Distribution.Gaussian(0, 1).map(x => x * 2))", + defaultContext, + ); + expect(result.ok).toBe(true); + if (result.ok) { + // Should emit direct arithmetic: sp.stats.Normal('X', 0, 1) * 2 + expect(result.sympyCode).toBe("sp.stats.Normal('X', 0, 1) * 2"); + } + }); + + it("should compile Distribution.Gaussian.map with addition", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Distribution.Gaussian(0, parameters.infection_rate).map(x => x * 2 + 3))", + defaultContext, + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.sympyCode).toBe( + "sp.stats.Normal('X', 0, infection_rate) * 2 + 3", + ); + expect(result.symbols).toContain("infection_rate"); + } + }); }); describe("global built-in functions", () => { @@ -432,6 +490,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "sp.Ne(infection_rate, 0)", + symbols: ["infection_rate"], }); }); @@ -443,6 +502,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "sp.Ne(1 + 2, 0)", + symbols: [], }); }); @@ -454,6 +514,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "True", + symbols: [], }); }); @@ -484,6 +545,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "gravitational_constant * 2", + symbols: ["gravitational_constant"], }); }); @@ -499,6 +561,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "infection_rate + recovery_rate", + symbols: expect.arrayContaining(["infection_rate", "recovery_rate"]), }); }); }); @@ -512,6 +575,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "infection_rate", + symbols: ["infection_rate"], }); }); @@ -528,6 +592,9 @@ describe("compileToSymPy", () => { expect(result.sympyCode).toContain("sp.sqrt"); expect(result.sympyCode).toContain("<"); expect(result.sympyCode).toContain("earth_radius"); + expect(result.symbols).toContain("Space_0_x"); + expect(result.symbols).toContain("Space_0_y"); + expect(result.symbols).toContain("earth_radius"); } }); @@ -545,6 +612,8 @@ describe("compileToSymPy", () => { if (result.ok) { expect(result.sympyCode).toContain("gravitational_constant"); expect(result.sympyCode).toContain("Space_0_x"); + expect(result.symbols).toContain("gravitational_constant"); + expect(result.symbols).toContain("Space_0_x"); } }); @@ -603,7 +672,11 @@ describe("compileToSymPy", () => { "export default Lambda(() => [1, 2, 3])", defaultContext, ); - expect(result).toEqual({ ok: true, sympyCode: "[1, 2, 3]" }); + expect(result).toEqual({ + ok: true, + sympyCode: "[1, 2, 3]", + symbols: [], + }); }); }); @@ -628,10 +701,15 @@ describe("compileToSymPy", () => { ); expect(result.ok).toBe(true); if (result.ok) { - expect(result.sympyCode).toContain("for _iter in tokens"); + // Dynamics tokens.map() emits per-token body only (no list comprehension) + expect(result.sympyCode).not.toContain("for _iter in tokens"); expect(result.sympyCode).toContain("_iter_x"); expect(result.sympyCode).toContain("_iter_velocity"); expect(result.sympyCode).toContain("sp.cos(_iter_direction)"); + // _iter_* symbols should be tracked + expect(result.symbols).toContain("_iter_x"); + expect(result.symbols).toContain("_iter_velocity"); + expect(result.symbols).toContain("gravitational_constant"); } }); @@ -643,6 +721,7 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "[_iter + 1 for _iter in tokens]", + symbols: [], }); }); @@ -654,10 +733,48 @@ describe("compileToSymPy", () => { expect(result).toEqual({ ok: true, sympyCode: "[_iter_x * infection_rate for _iter in tokens]", + symbols: ["infection_rate"], }); }); }); + describe("symbol tracking", () => { + it("should track parameter symbols used in expression", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate * parameters.recovery_rate)", + defaultContext, + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.symbols).toContain("infection_rate"); + expect(result.symbols).toContain("recovery_rate"); + } + }); + + it("should track token field symbols", () => { + const result = compileToSymPy( + "export default Lambda((tokens) => tokens.Space[0].x + tokens.Space[0].y)", + defaultContext, + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.symbols).toContain("Space_0_x"); + expect(result.symbols).toContain("Space_0_y"); + } + }); + + it("should return empty symbols for literal-only expressions", () => { + const result = compileToSymPy( + "export default Lambda(() => 42)", + defaultContext, + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.symbols).toEqual([]); + } + }); + }); + describe("error handling", () => { it("should reject code without default export", () => { const result = compileToSymPy( diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts index 7d00e5a34ba..3e39f0316f6 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts @@ -72,6 +72,11 @@ export function buildContextForDifferentialEquation( } export type SymPyResult = + | { ok: true; sympyCode: string; symbols: string[] } + | { ok: false; error: string; start: number; length: number }; + +/** Internal result type without symbols (used by emitSymPy and helpers). */ +type InternalResult = | { ok: true; sympyCode: string } | { ok: false; error: string; start: number; length: number }; @@ -181,12 +186,26 @@ export function compileToSymPy( // Extract parameter names for the inner function const localBindings = new Map(); const innerParams = extractFunctionParams(arg, sourceFile); + const symbols = new Set(); // Compile the body const body = arg.body; if (ts.isBlock(body)) { - return compileBlock(body, context, localBindings, sourceFile); + const blockResult = compileBlock( + body, + context, + localBindings, + symbols, + sourceFile, + innerParams, + ); + if (!blockResult.ok) return blockResult; + return { + ok: true, + sympyCode: blockResult.sympyCode, + symbols: [...symbols], + }; } // Expression body — emit directly @@ -195,10 +214,11 @@ export function compileToSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); if (!result.ok) return result; - return { ok: true, sympyCode: result.sympyCode }; + return { ok: true, sympyCode: result.sympyCode, symbols: [...symbols] }; } function extractFunctionParams( @@ -212,8 +232,10 @@ function compileBlock( block: ts.Block, context: SymPyCompilationContext, localBindings: Map, + symbols: Set, sourceFile: ts.SourceFile, -): SymPyResult { + innerParams: string[] = [], +): InternalResult { const lines: string[] = []; for (const stmt of block.statements) { @@ -238,7 +260,8 @@ function compileBlock( decl.initializer, context, localBindings, - [], + innerParams, + symbols, sourceFile, ); if (!valueResult.ok) return valueResult; @@ -253,7 +276,8 @@ function compileBlock( stmt.expression, context, localBindings, - [], + innerParams, + symbols, sourceFile, ); if (!result.ok) return result; @@ -289,17 +313,103 @@ function compileBlock( * * Emits: `[ for _iter in ]` */ +/** + * Check if an expression is a Distribution call (Distribution.Gaussian/Uniform/Lognormal). + */ +function isDistributionCall(node: ts.Expression): boolean { + if (!ts.isCallExpression(node)) return false; + const callee = node.expression; + return ( + ts.isPropertyAccessExpression(callee) && + ts.isIdentifier(callee.expression) && + callee.expression.text === "Distribution" + ); +} + +/** + * Compiles `.map(callback)` calls. + * + * For Distribution calls: emits direct arithmetic on the random variable + * e.g. Distribution.Gaussian(0, 1).map(x => x * 2) → sp.stats.Normal('X', 0, 1) * 2 + * + * For array/token calls: emits Python list comprehension + * e.g. tokens.map(t => t + 1) → [_iter + 1 for _iter in tokens] + */ function compileMapCall( collection: ts.Expression, callback: ts.ArrowFunction | ts.FunctionExpression, context: SymPyCompilationContext, outerBindings: Map, innerParams: string[], + symbols: Set, sourceFile: ts.SourceFile, -): SymPyResult { +): InternalResult { + // Special case: Distribution.Type(...).map(fn) → direct arithmetic on random variable + if (isDistributionCall(collection)) { + // Compile the distribution expression (e.g., sp.stats.Normal('X', 0, 1)) + const distResult = emitSymPy( + collection, + context, + outerBindings, + innerParams, + symbols, + sourceFile, + ); + if (!distResult.ok) return distResult; + + // The callback parameter becomes the distribution expression itself + const mapBindings = new Map(outerBindings); + const param = callback.parameters[0]; + if (param) { + const paramName = param.name; + if (ts.isIdentifier(paramName)) { + // Simple identifier: (x) => x * 2 + // Bind 'x' to the distribution expression so arithmetic applies directly + mapBindings.set(paramName.getText(sourceFile), distResult.sympyCode); + } + } + + // Compile the body with the distribution substituted in + const body = callback.body; + let bodyResult: InternalResult; + if (ts.isBlock(body)) { + bodyResult = compileBlock( + body, + context, + mapBindings, + symbols, + sourceFile, + ); + } else { + bodyResult = emitSymPy( + body, + context, + mapBindings, + innerParams, + symbols, + sourceFile, + ); + } + if (!bodyResult.ok) return bodyResult; + + return { ok: true, sympyCode: bodyResult.sympyCode }; + } + + // Standard case: array/token .map() const iterVar = "_iter"; const mapBindings = new Map(outerBindings); + // Check if this is a Dynamics tokens.map() call. + // For Dynamics, tokens.map(...) should emit the per-token body only, + // not a list comprehension, because `tokens` is a runtime value that + // Python/SymPy cannot iterate over at compile time. The JS wrapper + // handles the .map() iteration at runtime. + const isTokensMap = + context.constructorFnName === "Dynamics" && + ts.isIdentifier(collection) && + innerParams.length > 0 && + collection.text === innerParams[0]; + const param = callback.parameters[0]; if (param) { const paramName = param.name; @@ -308,30 +418,52 @@ function compileMapCall( // Each field becomes a symbol like _iter_x, _iter_y for (const element of paramName.elements) { const fieldName = element.name.getText(sourceFile); - mapBindings.set(fieldName, `${iterVar}_${fieldName}`); + const symName = `${iterVar}_${fieldName}`; + mapBindings.set(fieldName, symName); + if (isTokensMap) { + symbols.add(symName); + } } } else { // Simple identifier: (token) => ... - mapBindings.set(paramName.getText(sourceFile), iterVar); + const symName = iterVar; + mapBindings.set(paramName.getText(sourceFile), symName); + if (isTokensMap) { + symbols.add(symName); + } } } // Compile the body const body = callback.body; - let bodyResult: SymPyResult; + let bodyResult: InternalResult; if (ts.isBlock(body)) { - bodyResult = compileBlock(body, context, mapBindings, sourceFile); + bodyResult = compileBlock(body, context, mapBindings, symbols, sourceFile); } else { - bodyResult = emitSymPy(body, context, mapBindings, innerParams, sourceFile); + bodyResult = emitSymPy( + body, + context, + mapBindings, + innerParams, + symbols, + sourceFile, + ); } if (!bodyResult.ok) return bodyResult; + // For tokens.map(), emit only the per-token body expression. + // The JS wrapper will handle iterating over tokens at runtime. + if (isTokensMap) { + return { ok: true, sympyCode: bodyResult.sympyCode }; + } + // Compile the collection expression const collectionResult = emitSymPy( collection, context, outerBindings, innerParams, + symbols, sourceFile, ); if (!collectionResult.ok) return collectionResult; @@ -372,8 +504,9 @@ function emitSymPy( context: SymPyCompilationContext, localBindings: Map, innerParams: string[], + symbols: Set, sourceFile: ts.SourceFile, -): SymPyResult { +): InternalResult { // Numeric literal if (ts.isNumericLiteral(node)) { return { ok: true, sympyCode: node.text }; @@ -404,6 +537,7 @@ function emitSymPy( return { ok: true, sympyCode: localBindings.get(name)! }; } if (context.parameterNames.has(name)) { + symbols.add(name); return { ok: true, sympyCode: name }; } // Could be a destructured token field or function param @@ -417,6 +551,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); if (!inner.ok) return inner; @@ -430,6 +565,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); if (!operand.ok) return operand; @@ -457,6 +593,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); if (!left.ok) return left; @@ -465,6 +602,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); if (!right.ok) return right; @@ -558,6 +696,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); if (!condition.ok) return condition; @@ -566,6 +705,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); if (!whenTrue.ok) return whenTrue; @@ -574,6 +714,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); if (!whenFalse.ok) return whenFalse; @@ -601,6 +742,7 @@ function emitSymPy( ts.isIdentifier(node.expression) && node.expression.text === "parameters" ) { + symbols.add(propName); return { ok: true, sympyCode: propName }; } @@ -618,9 +760,11 @@ function emitSymPy( const placeName = placePropAccess.name.text; const indexExpr = elemAccess.argumentExpression; const indexText = indexExpr.getText(sourceFile); + const symbolName = `${placeName}_${indexText}_${propName}`; + symbols.add(symbolName); return { ok: true, - sympyCode: `${placeName}_${indexText}_${propName}`, + sympyCode: symbolName, }; } } @@ -632,6 +776,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); if (!obj.ok) return obj; @@ -645,6 +790,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); if (!obj.ok) return obj; @@ -653,6 +799,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); if (!index.ok) return index; @@ -680,6 +827,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); if (!r.ok) return r; @@ -696,6 +844,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); if (!base.ok) return base; @@ -704,6 +853,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); if (!exp.ok) return exp; @@ -724,7 +874,14 @@ function emitSymPy( const args: string[] = []; for (const a of node.arguments) { - const r = emitSymPy(a, context, localBindings, innerParams, sourceFile); + const r = emitSymPy( + a, + context, + localBindings, + innerParams, + symbols, + sourceFile, + ); if (!r.ok) return r; args.push(r.sympyCode); } @@ -740,7 +897,14 @@ function emitSymPy( const distName = callee.name.text; const args: string[] = []; for (const a of node.arguments) { - const r = emitSymPy(a, context, localBindings, innerParams, sourceFile); + const r = emitSymPy( + a, + context, + localBindings, + innerParams, + symbols, + sourceFile, + ); if (!r.ok) return r; args.push(r.sympyCode); } @@ -778,6 +942,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); if (!arg.ok) return arg; @@ -790,6 +955,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); } @@ -809,6 +975,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); } @@ -830,6 +997,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); if (!result.ok) return result; @@ -855,6 +1023,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); if (!val.ok) return val; @@ -870,6 +1039,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); } @@ -881,6 +1051,7 @@ function emitSymPy( context, localBindings, innerParams, + symbols, sourceFile, ); } diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-via-sympy.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-via-sympy.ts new file mode 100644 index 00000000000..33ea279ea4d --- /dev/null +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-via-sympy.ts @@ -0,0 +1,262 @@ +/** + * Orchestrates the full TypeScript → SymPy → JavaScript compilation pipeline. + * + * 1. Compile user TypeScript to SymPy Python code (compile-to-sympy.ts) + * 2. Run SymPy via Pyodide to evaluate and simplify the expression + * 3. Convert back to JavaScript using custom printer (sympy-codegen.ts) + * 4. Wrap into an executable JS function with runtime argument unpacking + */ + +import type { PyodideInterface } from "pyodide"; + +import type { SDCPN } from "../../core/types/sdcpn"; +import { + buildContextForDifferentialEquation, + buildContextForTransition, + compileToSymPy, + type SymPyCompilationContext, +} from "./compile-to-sympy"; +import { distributionRuntimeCode } from "./distribution"; +import { runSymPyCodegen } from "./sympy-codegen"; + +/** + * Expression type determines how the SymPy result is structured: + * - "scalar": single value (Lambda functions) + * - "array-of-objects": array of {field: expr} objects (Dynamics) + * - "dict-of-lists": {PlaceName: [{field: expr}, ...]} (TransitionKernel) + */ +type ExprType = "scalar" | "array-of-objects" | "dict-of-lists"; + +/** + * Determine the expression type from the constructor function name. + */ +function exprTypeForConstructor(constructorFnName: string): ExprType { + switch (constructorFnName) { + case "Lambda": + return "scalar"; + case "Dynamics": + return "array-of-objects"; + case "TransitionKernel": + return "dict-of-lists"; + default: + return "scalar"; + } +} + +/** + * Build the JavaScript function preamble that unpacks runtime arguments + * into the symbol names used by the SymPy-generated JS expression. + * + * For Lambda/TransitionKernel (tokens by place): + * var Space_0_x = tokens["Space"][0]["x"]; + * var infection_rate = parameters["infection_rate"]; + * + * For Dynamics (flat token array): + * var x = tokens[0]["x"]; + * var infection_rate = parameters["infection_rate"]; + */ +function buildUnpackingPreamble( + context: SymPyCompilationContext, + usedSymbols: string[], + mode: "outer" | "per-token" = "outer", +): string { + const lines: string[] = []; + + for (const sym of usedSymbols) { + // Check if this symbol is a parameter + if (context.parameterNames.has(sym)) { + lines.push(`var ${sym} = parameters[${JSON.stringify(sym)}];`); + continue; + } + + // For Dynamics per-token mode: _iter_fieldName → __token__["fieldName"] + if (mode === "per-token") { + const iterMatch = sym.match(/^_iter_(.+)$/); + if (iterMatch) { + const fieldName = iterMatch[1]!; + lines.push(`var ${sym} = __token__[${JSON.stringify(fieldName)}];`); + continue; + } + // _iter (simple param, not destructured) → __token__ + if (sym === "_iter") { + lines.push(`var _iter = __token__;`); + continue; + } + } + + // Check if this is a token field symbol: PlaceName_index_fieldName + let matched = false; + for (const [placeName, fields] of context.placeTokenFields) { + for (const field of fields) { + // Match patterns like Space_0_x, Space_1_velocity + const pattern = new RegExp(`^${placeName}_(\\d+)_${field}$`); + const match = sym.match(pattern); + if (match) { + const index = match[1]; + if (context.constructorFnName === "Dynamics") { + // Dynamics: tokens is a flat array of token objects + lines.push( + `var ${sym} = tokens[${index}][${JSON.stringify(field)}];`, + ); + } else { + // Lambda/TransitionKernel: tokens is {PlaceName: [token, ...]} + lines.push( + `var ${sym} = tokens[${JSON.stringify(placeName)}][${index}][${JSON.stringify(field)}];`, + ); + } + matched = true; + break; + } + } + if (matched) break; + } + } + + return lines.join("\n"); +} + +/** + * Build the Distribution namespace object for runtime use in generated functions. + * Reuses the same runtime code from distribution.ts. + */ +function buildDistributionNamespace(): unknown { + // eslint-disable-next-line @typescript-eslint/no-implied-eval, no-new-func + return new Function(`${distributionRuntimeCode}\nreturn Distribution;`)(); +} + +/** + * Build a JS function string from SymPy-generated JS code and wrap it + * so it can be called with (tokens, parameters) arguments. + */ +function buildJsFunction( + jsCode: string, + context: SymPyCompilationContext, + usedSymbols: string[], + exprType: ExprType, +): (...args: unknown[]) => unknown { + const preamble = buildUnpackingPreamble(context, usedSymbols); + + let body: string; + + if (exprType === "scalar") { + // scalar: jsCode is a single expression + body = `${preamble}\nreturn ${jsCode};`; + } else if (exprType === "array-of-objects") { + // Dynamics: jsCode is a JSON string of {field: "expr"} for a SINGLE token. + // The SymPy code represents the per-token derivative expression. + // We wrap it in a tokens.map() to apply it to all tokens at runtime. + const parsed = JSON.parse(jsCode) as Record; + const perTokenPreamble = buildUnpackingPreamble( + context, + usedSymbols, + "per-token", + ); + const entries = Object.entries(parsed) + .map(([key, expr]) => `${JSON.stringify(key)}: ${expr}`) + .join(", "); + body = `${preamble}\nreturn tokens.map(function(__token__) {\n ${perTokenPreamble}\n return {${entries}};\n});`; + } else { + // dict-of-lists: jsCode is a JSON string of {PlaceName: [{field: "expr"}, ...]} + const parsed = JSON.parse(jsCode) as Record< + string, + Record[] + >; + const entries = Object.entries(parsed) + .map(([placeName, tokens]) => { + const tokenArray = tokens + .map((token) => { + const fields = Object.entries(token) + .map(([key, expr]) => `${JSON.stringify(key)}: ${expr}`) + .join(", "); + return `{${fields}}`; + }) + .join(", "); + return `${JSON.stringify(placeName)}: [${tokenArray}]`; + }) + .join(",\n "); + body = `${preamble}\nreturn {${entries}};`; + } + + // eslint-disable-next-line @typescript-eslint/no-implied-eval, no-new-func + return new Function("Distribution", "tokens", "parameters", body).bind( + null, + buildDistributionNamespace(), + ) as (...args: unknown[]) => unknown; +} + +/** + * Compile user TypeScript code via the SymPy pipeline. + * + * @param code - User TypeScript code (e.g., `export default Lambda((tokens, params) => ...)`) + * @param constructorFnName - "Lambda" | "Dynamics" | "TransitionKernel" + * @param context - Compilation context with parameter names and token fields + * @param pyodide - Initialized Pyodide instance with SymPy + * @returns Compiled JS function ready for execution + */ +export async function compileUserCodeViaSymPy( + code: string, + constructorFnName: string, + context: SymPyCompilationContext, + pyodide: PyodideInterface, +): Promise<(...args: T) => unknown> { + // Step 1: TypeScript → SymPy Python code + const sympyResult = compileToSymPy(code, context); + + if (!sympyResult.ok) { + throw new Error(sympyResult.error); + } + + const { sympyCode, symbols } = sympyResult; + const exprType = exprTypeForConstructor(constructorFnName); + + // Step 2: Run SymPy via Pyodide → JavaScript code string + const jsCode = await runSymPyCodegen(pyodide, sympyCode, symbols, exprType); + + // Step 3: Build executable JS function + return buildJsFunction(jsCode, context, symbols, exprType) as ( + ...args: T + ) => unknown; +} + +/** + * Build a SymPyCompilationContext and compile user code for a differential equation. + */ +export async function compileDifferentialEquationViaSymPy( + code: string, + sdcpn: SDCPN, + colorId: string, + pyodide: PyodideInterface, +): Promise<(...args: unknown[]) => unknown> { + const context = buildContextForDifferentialEquation(sdcpn, colorId); + return compileUserCodeViaSymPy(code, "Dynamics", context, pyodide); +} + +/** + * Build a SymPyCompilationContext and compile user code for a lambda function. + */ +export async function compileLambdaViaSymPy( + code: string, + sdcpn: SDCPN, + transition: SDCPN["transitions"][0], + pyodide: PyodideInterface, +): Promise<(...args: unknown[]) => unknown> { + const context = buildContextForTransition(sdcpn, transition, "Lambda"); + return compileUserCodeViaSymPy(code, "Lambda", context, pyodide); +} + +/** + * Build a SymPyCompilationContext and compile user code for a transition kernel. + */ +export async function compileTransitionKernelViaSymPy( + code: string, + sdcpn: SDCPN, + transition: SDCPN["transitions"][0], + pyodide: PyodideInterface, +): Promise<(...args: unknown[]) => unknown> { + const context = buildContextForTransition( + sdcpn, + transition, + "TransitionKernel", + ); + return compileUserCodeViaSymPy(code, "TransitionKernel", context, pyodide); +} diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compute-next-frame.test.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compute-next-frame.test.ts index 3780808e6c5..f5c12e65319 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/compute-next-frame.test.ts +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compute-next-frame.test.ts @@ -1,11 +1,49 @@ -import { describe, expect, it } from "vitest"; +import { describe, expect, it, vi } from "vitest"; import type { SDCPN } from "../../core/types/sdcpn"; import { buildSimulation } from "./build-simulation"; import { computeNextFrame } from "./compute-next-frame"; +// Mock SymPy compilation with realistic functions for these tests +vi.mock("./compile-via-sympy", () => ({ + compileDifferentialEquationViaSymPy: vi + .fn() + .mockImplementation((code: string) => { + // Parse simple dynamics patterns from test code + if (code.includes("({ x: 1, y: 1 })")) { + return Promise.resolve((tokens: Record[]) => + tokens.map(() => ({ x: 1, y: 1 })), + ); + } + if (code.includes("({ x: 1 })")) { + return Promise.resolve((tokens: Record[]) => + tokens.map(() => ({ x: 1 })), + ); + } + return Promise.resolve(() => []); + }), + compileLambdaViaSymPy: vi.fn().mockImplementation((code: string) => { + if (code.includes("0.0001")) { + return Promise.resolve(() => 0.0001); + } + return Promise.resolve(() => 1.0); + }), + compileTransitionKernelViaSymPy: vi + .fn() + .mockImplementation((code: string) => { + if (code.includes("Place1")) { + return Promise.resolve(() => ({ + Place1: [{ x: 100.0, y: 200.0 }], + })); + } + return Promise.resolve(() => ({})); + }), +})); + +const mockPyodide = {} as Parameters[1]; + describe("computeNextFrame", () => { - it("should compute next frame with dynamics and transitions", () => { + it("should compute next frame with dynamics and transitions", async () => { // GIVEN a simple SDCPN with one place and one transition const sdcpn: SDCPN = { types: [ @@ -62,14 +100,17 @@ describe("computeNextFrame", () => { ]); // Build the simulation - const simulation = buildSimulation({ - sdcpn, - initialMarking, - parameterValues: {}, - seed: 42, - dt: 0.1, - maxTime: null, - }); + const simulation = await buildSimulation( + { + sdcpn, + initialMarking, + parameterValues: {}, + seed: 42, + dt: 0.1, + maxTime: null, + }, + mockPyodide, + ); // WHEN computing the next frame const result = computeNextFrame(simulation); @@ -91,7 +132,7 @@ describe("computeNextFrame", () => { expect(nextFrame.buffer[1]).toBeCloseTo(20.1); }); - it("should skip dynamics for places without type", () => { + it("should skip dynamics for places without type", async () => { // GIVEN a place without a type const sdcpn: SDCPN = { types: [], @@ -115,14 +156,17 @@ describe("computeNextFrame", () => { ["p1", { values: new Float64Array([]), count: 0 }], ]); - const simulation = buildSimulation({ - sdcpn, - initialMarking, - parameterValues: {}, - seed: 42, - dt: 0.1, - maxTime: null, - }); + const simulation = await buildSimulation( + { + sdcpn, + initialMarking, + parameterValues: {}, + seed: 42, + dt: 0.1, + maxTime: null, + }, + mockPyodide, + ); // WHEN computing the next frame const result = computeNextFrame(simulation); @@ -132,7 +176,7 @@ describe("computeNextFrame", () => { expect(result.transitionFired).toBe(false); }); - it("should skip dynamics for places with dynamics disabled", () => { + it("should skip dynamics for places with dynamics disabled", async () => { // GIVEN a place with dynamics disabled const sdcpn: SDCPN = { types: [ @@ -171,14 +215,17 @@ describe("computeNextFrame", () => { ["p1", { values: new Float64Array([10.0]), count: 1 }], ]); - const simulation = buildSimulation({ - sdcpn, - initialMarking, - parameterValues: {}, - seed: 42, - dt: 0.1, - maxTime: null, - }); + const simulation = await buildSimulation( + { + sdcpn, + initialMarking, + parameterValues: {}, + seed: 42, + dt: 0.1, + maxTime: null, + }, + mockPyodide, + ); // WHEN computing the next frame const result = computeNextFrame(simulation); diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/pyodide-manager.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/pyodide-manager.ts new file mode 100644 index 00000000000..cfb810316fb --- /dev/null +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/pyodide-manager.ts @@ -0,0 +1,58 @@ +/** + * Singleton manager for loading and caching Pyodide + SymPy in the WebWorker. + * + * Pyodide (Python-in-WASM) is loaded once per worker lifetime and reused + * for all SymPy compilation calls. SymPy is installed via micropip on first load. + */ + +import type { PyodideInterface } from "pyodide"; + +/** + * Default CDN URL for Pyodide assets matching the installed version. + * + * Pyodide's auto-detection of its asset location (via stack trace parsing) + * breaks when the JS entry is pre-bundled by Vite/Storybook into a deps cache + * directory that doesn't contain the WASM/zip files. We always pass an explicit + * indexURL to avoid this. Consumers can override with a self-hosted URL via + * the `pyodideUrl` parameter. + */ +const PYODIDE_VERSION = "0.27.7"; +const PYODIDE_CDN_URL = `https://cdn.jsdelivr.net/pyodide/v${PYODIDE_VERSION}/full/`; + +let cachedPyodide: PyodideInterface | null = null; + +/** + * Load Pyodide and install SymPy. Caches the instance for the worker lifetime. + * + * @param pyodideUrl - Optional URL where Pyodide WASM assets are served. + * If not provided, uses the official Pyodide CDN. + * @returns The initialized Pyodide instance with SymPy available. + */ +export async function loadPyodideAndSymPy( + pyodideUrl?: string, +): Promise { + if (cachedPyodide) { + return cachedPyodide; + } + + const { loadPyodide } = await import("pyodide"); + + const pyodide = await loadPyodide({ + indexURL: pyodideUrl ?? PYODIDE_CDN_URL, + }); + + // Install sympy via micropip (Pyodide's package manager) + await pyodide.loadPackage("micropip"); + const micropip = pyodide.pyimport("micropip"); + await micropip.install("sympy"); + + cachedPyodide = pyodide; + return pyodide; +} + +/** + * Dispose the cached Pyodide instance, freeing WASM memory. + */ +export function disposePyodide(): void { + cachedPyodide = null; +} diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/sympy-codegen.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/sympy-codegen.ts new file mode 100644 index 00000000000..34ac2a91b5c --- /dev/null +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/sympy-codegen.ts @@ -0,0 +1,163 @@ +/** + * Python codegen module for converting SymPy expressions to JavaScript. + * + * The Python source is stored as a TS string constant and executed via Pyodide. + * It defines a custom JS printer that handles SymPy RandomSymbol → Distribution.Type() + * and a compile_expression() function for full expression evaluation. + */ + +import type { PyodideInterface } from "pyodide"; + +/** + * Python source code that defines the SymPy → JS compilation module. + * Loaded into Pyodide once, then called for each expression. + */ +export const SYMPY_CODEGEN_PYTHON = ` +import json +import sympy as sp +from sympy.stats import Normal, Uniform, LogNormal +from sympy.stats.crv_types import NormalDistribution, UniformDistribution, LogNormalDistribution +from sympy.stats.rv import RandomSymbol +from sympy.printing.jscode import JavascriptCodePrinter + +class PetrinautJSPrinter(JavascriptCodePrinter): + """Custom JS printer that handles SymPy RandomSymbol -> Distribution.Type()""" + + def _print_RandomSymbol(self, expr): + # Extract the underlying distribution from the probability space + dist = expr.pspace.distribution + if isinstance(dist, NormalDistribution): + mean_js = self._print(dist.mean) + std_js = self._print(dist.std) + return f"Distribution.Gaussian({mean_js}, {std_js})" + elif isinstance(dist, UniformDistribution): + left_js = self._print(dist.left) + right_js = self._print(dist.right) + return f"Distribution.Uniform({left_js}, {right_js})" + elif isinstance(dist, LogNormalDistribution): + mean_js = self._print(dist.mean) + std_js = self._print(dist.std) + return f"Distribution.Lognormal({mean_js}, {std_js})" + else: + raise ValueError(f"Unsupported distribution: {type(dist)}") + + def _print_Expr_with_random(self, expr): + """ + Handle expressions containing RandomSymbols that didn't simplify + to a known distribution. Decomposes into base distribution + .map(). + """ + # Find all RandomSymbol nodes in the expression tree + random_symbols = list(expr.atoms(RandomSymbol)) + if len(random_symbols) != 1: + raise ValueError( + f"Expression contains {len(random_symbols)} random variables; expected exactly 1" + ) + rs = random_symbols[0] + # Create a dummy symbol to substitute + x = sp.Symbol('__x__') + # Replace the random symbol with x + body = expr.subs(rs, x) + # Print the base distribution + dist_js = self._print_RandomSymbol(rs) + # Print the body as a JS arrow function + body_js = self._print(body).replace('__x__', 'x') + return f"{dist_js}.map((x) => {body_js})" + + +def _has_random_symbol(expr): + """Check if a SymPy expression contains any RandomSymbol.""" + if isinstance(expr, RandomSymbol): + return True + if hasattr(expr, 'atoms'): + return bool(expr.atoms(RandomSymbol)) + return False + + +def _print_expr(printer, expr): + """Print a single expression, handling RandomSymbol decomposition.""" + if isinstance(expr, RandomSymbol): + return printer._print_RandomSymbol(expr) + if _has_random_symbol(expr): + return printer._print_Expr_with_random(expr) + return printer.doprint(expr) + + +def compile_expression(sympy_code, symbols, expr_type): + """ + Evaluate a SymPy code string and convert the result to JavaScript. + + Args: + sympy_code: Python code string producing a SymPy expression + symbols: list of symbol names to declare + expr_type: "scalar" | "array-of-objects" | "dict-of-lists" + + Returns: + JS code string (or JSON for structured types) + """ + sym_dict = {name: sp.Symbol(name) for name in symbols} + namespace = { + "sp": sp, + **sym_dict, + "True": True, + "False": False, + } + + # Evaluate the SymPy code string + expr = eval(sympy_code, namespace) + printer = PetrinautJSPrinter() + + if expr_type == "scalar": + return _print_expr(printer, expr) + elif expr_type == "dict-of-lists": + # TransitionKernel: {'PlaceName': [{'field': expr}, ...]} + result = {} + for place_name, token_list in expr.items(): + result[place_name] = [ + {k: _print_expr(printer, v) for k, v in token.items()} + for token in token_list + ] + return json.dumps(result) + elif expr_type == "array-of-objects": + # Dynamics: {'field': expr} — a single per-token derivative expression. + # The JS wrapper handles iterating over tokens at runtime. + return json.dumps( + {k: _print_expr(printer, v) for k, v in expr.items()} + ) + else: + raise ValueError(f"Unknown expr_type: {expr_type}") +`; + +let codegenLoaded = false; + +/** + * Run the SymPy codegen pipeline: evaluate a SymPy expression and convert to JS code. + * + * @param pyodide - Initialized Pyodide instance with SymPy + * @param sympyCode - Python code string that evaluates to a SymPy expression + * @param symbols - Symbol names used in the expression + * @param exprType - Expression structure type + * @returns JavaScript code string + */ +export async function runSymPyCodegen( + pyodide: PyodideInterface, + sympyCode: string, + symbols: string[], + exprType: "scalar" | "array-of-objects" | "dict-of-lists", +): Promise { + // Load the codegen module once + if (!codegenLoaded) { + await pyodide.runPythonAsync(SYMPY_CODEGEN_PYTHON); + codegenLoaded = true; + } + + // Call compile_expression with the provided arguments + const result = await pyodide.runPythonAsync(` +compile_expression( + ${JSON.stringify(sympyCode)}, + ${JSON.stringify(symbols)}, + ${JSON.stringify(exprType)} +) +`); + + return result as string; +} diff --git a/libs/@hashintel/petrinaut/src/simulation/worker/messages.ts b/libs/@hashintel/petrinaut/src/simulation/worker/messages.ts index 40bf6873596..5a5f6cd7ca2 100644 --- a/libs/@hashintel/petrinaut/src/simulation/worker/messages.ts +++ b/libs/@hashintel/petrinaut/src/simulation/worker/messages.ts @@ -33,6 +33,8 @@ export type InitMessage = { maxFramesAhead?: number; /** Number of frames to compute in each batch before checking for messages */ batchSize?: number; + /** Optional URL where Pyodide WASM assets are served (for self-hosting) */ + pyodideUrl?: string; }; /** @@ -146,6 +148,15 @@ export type PausedMessage = { frameNumber: number; }; +/** + * Progress update during compilation phase (Pyodide + SymPy loading). + */ +export type CompilingMessage = { + type: "compiling"; + /** Current phase of the compilation pipeline */ + phase: "loading-sympy" | "compiling"; +}; + /** * Union of all messages that can be sent from worker to main thread. */ @@ -155,4 +166,5 @@ export type ToMainMessage = | FramesMessage | CompleteMessage | ErrorMessage - | PausedMessage; + | PausedMessage + | CompilingMessage; diff --git a/libs/@hashintel/petrinaut/src/simulation/worker/simulation.worker.test.ts b/libs/@hashintel/petrinaut/src/simulation/worker/simulation.worker.test.ts index 9b9c49f8840..e1a17ab7c80 100644 --- a/libs/@hashintel/petrinaut/src/simulation/worker/simulation.worker.test.ts +++ b/libs/@hashintel/petrinaut/src/simulation/worker/simulation.worker.test.ts @@ -10,6 +10,18 @@ import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; import type { SDCPN } from "../../core/types/sdcpn"; import type { ToMainMessage, ToWorkerMessage } from "./messages"; +// Mock pyodide-manager to avoid loading real Pyodide in tests +vi.mock("../simulator/pyodide-manager", () => ({ + loadPyodideAndSymPy: vi.fn().mockResolvedValue({}), +})); + +// Mock SymPy compilation to return simple passthrough functions +vi.mock("../simulator/compile-via-sympy", () => ({ + compileDifferentialEquationViaSymPy: vi.fn().mockResolvedValue(() => []), + compileLambdaViaSymPy: vi.fn().mockResolvedValue(() => 1.0), + compileTransitionKernelViaSymPy: vi.fn().mockResolvedValue(() => ({})), +})); + // Store messages posted by worker let postedMessages: ToMainMessage[] = []; @@ -108,7 +120,7 @@ describe("simulation.worker", () => { expect(readyMessages[0]?.initialFrameCount).toBe(0); }); - it("initializes simulation with valid SDCPN", () => { + it("initializes simulation with valid SDCPN", async () => { clearMessages(); const sdcpn = createMinimalSDCPN(); @@ -122,6 +134,11 @@ describe("simulation.worker", () => { maxTime: null, }); + // Wait for async init to complete + await vi.waitFor(() => { + expect(getMessages("ready")).toHaveLength(1); + }); + // Should send initial frame and ready message const frameMessages = getMessages("frame"); expect(frameMessages).toHaveLength(1); @@ -132,7 +149,7 @@ describe("simulation.worker", () => { expect(readyMessages[0]?.initialFrameCount).toBe(1); }); - it("posts error message for invalid SDCPN", () => { + it("posts error message for invalid SDCPN", async () => { clearMessages(); // SDCPN with invalid initial marking (place doesn't exist) @@ -149,6 +166,11 @@ describe("simulation.worker", () => { maxTime: null, }); + // Wait for async init to complete + await vi.waitFor(() => { + expect(getMessages("error")).toHaveLength(1); + }); + const errorMessages = getMessages("error"); expect(errorMessages).toHaveLength(1); expect(errorMessages[0]?.message).toContain("nonexistent"); @@ -166,7 +188,7 @@ describe("simulation.worker", () => { expect(errorMessages[0]?.message).toContain("not initialized"); }); - it("posts paused message when pausing", () => { + it("posts paused message when pausing", async () => { clearMessages(); // Initialize first @@ -180,6 +202,11 @@ describe("simulation.worker", () => { dt: 0.1, maxTime: null, }); + + // Wait for async init to complete + await vi.waitFor(() => { + expect(getMessages("ready")).toHaveLength(1); + }); clearMessages(); // Pause @@ -190,7 +217,7 @@ describe("simulation.worker", () => { expect(pausedMessages[0]?.frameNumber).toBe(0); }); - it("clears state on stop", () => { + it("clears state on stop", async () => { clearMessages(); // Initialize @@ -204,6 +231,11 @@ describe("simulation.worker", () => { dt: 0.1, maxTime: null, }); + + // Wait for async init to complete + await vi.waitFor(() => { + expect(getMessages("ready")).toHaveLength(1); + }); clearMessages(); // Stop @@ -219,7 +251,7 @@ describe("simulation.worker", () => { }); describe("backpressure", () => { - it("accepts setBackpressure message", () => { + it("accepts setBackpressure message", async () => { clearMessages(); // Initialize @@ -233,6 +265,11 @@ describe("simulation.worker", () => { dt: 0.1, maxTime: null, }); + + // Wait for async init to complete + await vi.waitFor(() => { + expect(getMessages("ready")).toHaveLength(1); + }); clearMessages(); // Set backpressure config - should not error @@ -248,7 +285,7 @@ describe("simulation.worker", () => { }); describe("ack (backpressure)", () => { - it("accepts ack message", () => { + it("accepts ack message", async () => { clearMessages(); // Initialize @@ -262,6 +299,11 @@ describe("simulation.worker", () => { dt: 0.1, maxTime: null, }); + + // Wait for async init to complete + await vi.waitFor(() => { + expect(getMessages("ready")).toHaveLength(1); + }); clearMessages(); // Send ack - should not error diff --git a/libs/@hashintel/petrinaut/src/simulation/worker/simulation.worker.ts b/libs/@hashintel/petrinaut/src/simulation/worker/simulation.worker.ts index 0c30c45b1ad..a3882a13644 100644 --- a/libs/@hashintel/petrinaut/src/simulation/worker/simulation.worker.ts +++ b/libs/@hashintel/petrinaut/src/simulation/worker/simulation.worker.ts @@ -12,6 +12,7 @@ import { SDCPNItemError } from "../../core/errors"; import { buildSimulation } from "../simulator/build-simulation"; import { computeNextFrame } from "../simulator/compute-next-frame"; +import { loadPyodideAndSymPy } from "../simulator/pyodide-manager"; import type { SimulationInstance } from "../simulator/types"; import type { ToMainMessage, ToWorkerMessage } from "./messages"; @@ -143,50 +144,60 @@ self.onmessage = (event: MessageEvent) => { switch (message.type) { case "init": { - try { - // Convert serialized initialMarking back to Map - const initialMarking = new Map(message.initialMarking); - - // Build simulation (compiles user code) - simulation = buildSimulation({ - sdcpn: message.sdcpn, - initialMarking, - parameterValues: message.parameterValues, - seed: message.seed, - dt: message.dt, - maxTime: message.maxTime, - }); + void (async () => { + try { + // Load Pyodide + SymPy (cached after first load) + postTypedMessage({ type: "compiling", phase: "loading-sympy" }); + const pyodide = await loadPyodideAndSymPy(message.pyodideUrl); - // Configure backpressure from init message or use defaults - maxFramesAhead = message.maxFramesAhead ?? DEFAULT_MAX_FRAMES_AHEAD; - batchSize = message.batchSize ?? DEFAULT_BATCH_SIZE; + // Convert serialized initialMarking back to Map + const initialMarking = new Map(message.initialMarking); - // Reset to -1: blocks computation until first ack - lastAckedFrame = -1; - isRunning = false; - simulationStatus = "ready"; + // Build simulation (compiles user code via SymPy) + postTypedMessage({ type: "compiling", phase: "compiling" }); + simulation = await buildSimulation( + { + sdcpn: message.sdcpn, + initialMarking, + parameterValues: message.parameterValues, + seed: message.seed, + dt: message.dt, + maxTime: message.maxTime, + }, + pyodide, + ); - // Send initial frame - const initialFrame = simulation.frames[0]; - if (initialFrame) { - postTypedMessage({ type: "frame", frame: initialFrame }); - } + // Configure backpressure from init message or use defaults + maxFramesAhead = message.maxFramesAhead ?? DEFAULT_MAX_FRAMES_AHEAD; + batchSize = message.batchSize ?? DEFAULT_BATCH_SIZE; - postTypedMessage({ - type: "ready", - initialFrameCount: simulation.frames.length, - }); - } catch (error) { - simulationStatus = "error"; - postTypedMessage({ - type: "error", - message: - error instanceof Error - ? error.message - : "Failed to initialize simulation", - itemId: error instanceof SDCPNItemError ? error.itemId : null, - }); - } + // Reset to -1: blocks computation until first ack + lastAckedFrame = -1; + isRunning = false; + simulationStatus = "ready"; + + // Send initial frame + const initialFrame = simulation.frames[0]; + if (initialFrame) { + postTypedMessage({ type: "frame", frame: initialFrame }); + } + + postTypedMessage({ + type: "ready", + initialFrameCount: simulation.frames.length, + }); + } catch (error) { + simulationStatus = "error"; + postTypedMessage({ + type: "error", + message: + error instanceof Error + ? error.message + : "Failed to initialize simulation", + itemId: error instanceof SDCPNItemError ? error.itemId : null, + }); + } + })(); break; } diff --git a/libs/@hashintel/petrinaut/src/simulation/worker/use-simulation-worker.ts b/libs/@hashintel/petrinaut/src/simulation/worker/use-simulation-worker.ts index 85c83e6303d..df967a6321b 100644 --- a/libs/@hashintel/petrinaut/src/simulation/worker/use-simulation-worker.ts +++ b/libs/@hashintel/petrinaut/src/simulation/worker/use-simulation-worker.ts @@ -23,6 +23,7 @@ import type { ToMainMessage, ToWorkerMessage } from "./messages"; export type WorkerStatus = | "idle" | "initializing" + | "compiling" | "ready" | "running" | "paused" @@ -54,6 +55,8 @@ export type InitializeParams = { maxFramesAhead?: number; /** Number of frames to compute in each batch before checking for messages */ batchSize?: number; + /** Optional URL where Pyodide WASM assets are served (for self-hosting) */ + pyodideUrl?: string; }; /** @@ -182,6 +185,13 @@ export function useSimulationWorker(): { })); break; + case "compiling": + setState((prev) => ({ + ...prev, + status: "compiling", + })); + break; + case "error": setState((prev) => ({ ...prev, @@ -237,6 +247,7 @@ export function useSimulationWorker(): { maxTime, maxFramesAhead, batchSize, + pyodideUrl, }) => { // Cancel any pending initialization if (pendingInitRef.current) { @@ -269,6 +280,7 @@ export function useSimulationWorker(): { maxTime, maxFramesAhead, batchSize, + pyodideUrl, }); return promise; diff --git a/libs/@hashintel/petrinaut/src/state/editor-provider.tsx b/libs/@hashintel/petrinaut/src/state/editor-provider.tsx index e847881ada6..2e556410658 100644 --- a/libs/@hashintel/petrinaut/src/state/editor-provider.tsx +++ b/libs/@hashintel/petrinaut/src/state/editor-provider.tsx @@ -73,7 +73,7 @@ export const EditorProvider: React.FC = ({ children }) => { }, setLeftSidebarWidth: (width) => setState((prev) => ({ ...prev, leftSidebarWidth: width })), - setPropertiesPanelWidth: (width) => + setPropertiesPanelWidth: (width: number) => setState((prev) => ({ ...prev, propertiesPanelWidth: width })), setBottomPanelOpen: (isOpen) => { triggerPanelAnimation(); diff --git a/libs/@hashintel/petrinaut/vite.config.ts b/libs/@hashintel/petrinaut/vite.config.ts index 57623a868df..85bded2b588 100644 --- a/libs/@hashintel/petrinaut/vite.config.ts +++ b/libs/@hashintel/petrinaut/vite.config.ts @@ -21,6 +21,7 @@ export default defineConfig(({ command }) => ({ "react-dom", "@xyflow/react", "@babel/standalone", + "pyodide", ], output: { globals: { diff --git a/yarn.lock b/yarn.lock index 49fd0ac25d6..83186d3b2f6 100644 --- a/yarn.lock +++ b/yarn.lock @@ -8033,6 +8033,7 @@ __metadata: jsdom: "npm:24.1.3" monaco-editor: "npm:0.55.1" oxlint: "npm:1.55.0" + pyodide: "npm:0.27.7" react: "npm:19.2.3" react-dom: "npm:19.2.3" react-icons: "npm:5.5.0" @@ -39210,6 +39211,15 @@ __metadata: languageName: node linkType: hard +"pyodide@npm:0.27.7": + version: 0.27.7 + resolution: "pyodide@npm:0.27.7" + dependencies: + ws: "npm:^8.5.0" + checksum: 10c0/9ea914db3f75dd89e494d68e9809c64d96b29e0bdbc32dbb790889974f6379aac621b61712891508b4df0c0bfc6e7c370ff963f64c306c95e3fa1a190095dcec + languageName: node + linkType: hard + "qs@npm:6.14.2": version: 6.14.2 resolution: "qs@npm:6.14.2" @@ -46553,7 +46563,7 @@ __metadata: languageName: node linkType: hard -"ws@npm:^8.15.1, ws@npm:^8.17.1, ws@npm:^8.18.0, ws@npm:^8.18.2, ws@npm:^8.18.3": +"ws@npm:^8.15.1, ws@npm:^8.17.1, ws@npm:^8.18.0, ws@npm:^8.18.2, ws@npm:^8.18.3, ws@npm:^8.5.0": version: 8.19.0 resolution: "ws@npm:8.19.0" peerDependencies: