From c54838e7e862a7c4115d3087b20433dc8e03fb16 Mon Sep 17 00:00:00 2001 From: Chris Feijoo Date: Fri, 13 Mar 2026 02:32:40 +0100 Subject: [PATCH 01/14] FE-514: Add SymPy compilation diagnostics and export option Integrate compileToSymPy into the LSP checker so SymPy-incompatible expressions surface as warnings in the Diagnostics tab with accurate source positions. Add "JSON with SymPy expressions" export menu entry. Co-Authored-By: Claude Opus 4.6 --- .../petrinaut/src/lsp/lib/checker.ts | 116 +++ .../simulator/compile-to-sympy.test.ts | 624 ++++++++++++++ .../simulation/simulator/compile-to-sympy.ts | 767 ++++++++++++++++++ .../src/views/Editor/editor-view.tsx | 10 + .../src/views/Editor/lib/export-sympy.ts | 93 +++ 5 files changed, 1610 insertions(+) create mode 100644 libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts create mode 100644 libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts create mode 100644 libs/@hashintel/petrinaut/src/views/Editor/lib/export-sympy.ts diff --git a/libs/@hashintel/petrinaut/src/lsp/lib/checker.ts b/libs/@hashintel/petrinaut/src/lsp/lib/checker.ts index 3c5029c5fe3..375a488b7e2 100644 --- a/libs/@hashintel/petrinaut/src/lsp/lib/checker.ts +++ b/libs/@hashintel/petrinaut/src/lsp/lib/checker.ts @@ -1,6 +1,12 @@ import type ts from "typescript"; import type { SDCPN } from "../../core/types/sdcpn"; +import { + buildContextForDifferentialEquation, + buildContextForTransition, + compileToSymPy, + type SymPyResult, +} from "../../simulation/simulator/compile-to-sympy"; import type { SDCPNLanguageServer } from "./create-sdcpn-language-service"; import { getItemFilePath } from "./file-paths"; @@ -27,6 +33,113 @@ export type SDCPNCheckResult = { itemDiagnostics: SDCPNDiagnostic[]; }; +/** + * Creates a synthetic ts.Diagnostic from a SymPy compilation error result. + * Uses category 0 (Warning) since SymPy compilation failures are informational + * — the TypeScript code may still be valid, just not convertible to SymPy. + */ +function makeSymPyDiagnostic( + result: SymPyResult & { ok: false }, +): ts.Diagnostic { + return { + category: 0, // Warning + code: 99000, // Custom code for SymPy diagnostics + messageText: `SymPy: ${result.error}`, + file: undefined, + start: result.start, + length: result.length, + }; +} + +/** + * Appends a SymPy diagnostic to the item diagnostics list, merging with + * any existing entry for the same item. + */ +function appendSymPyDiagnostic( + itemDiagnostics: SDCPNDiagnostic[], + itemId: string, + itemType: ItemType, + filePath: string, + result: SymPyResult & { ok: false }, +): void { + const diag = makeSymPyDiagnostic(result); + const existing = itemDiagnostics.find( + (di) => di.itemId === itemId && di.itemType === itemType, + ); + if (existing) { + existing.diagnostics.push(diag); + } else { + itemDiagnostics.push({ itemId, itemType, filePath, diagnostics: [diag] }); + } +} + +/** + * Runs SymPy compilation on all SDCPN code expressions and appends + * any errors as warning diagnostics. + */ +function checkSymPyCompilation( + sdcpn: SDCPN, + itemDiagnostics: SDCPNDiagnostic[], +): void { + // Check differential equations + for (const de of sdcpn.differentialEquations) { + const ctx = buildContextForDifferentialEquation(sdcpn, de.colorId); + const result = compileToSymPy(de.code, ctx); + if (!result.ok) { + const filePath = getItemFilePath("differential-equation-code", { + id: de.id, + }); + appendSymPyDiagnostic( + itemDiagnostics, + de.id, + "differential-equation", + filePath, + result, + ); + } + } + + // Check transition lambdas and kernels + for (const transition of sdcpn.transitions) { + const lambdaCtx = buildContextForTransition(sdcpn, transition, "Lambda"); + const lambdaResult = compileToSymPy(transition.lambdaCode, lambdaCtx); + if (!lambdaResult.ok) { + const filePath = getItemFilePath("transition-lambda-code", { + transitionId: transition.id, + }); + appendSymPyDiagnostic( + itemDiagnostics, + transition.id, + "transition-lambda", + filePath, + lambdaResult, + ); + } + + const kernelCtx = buildContextForTransition( + sdcpn, + transition, + "TransitionKernel", + ); + const kernelResult = compileToSymPy( + transition.transitionKernelCode, + kernelCtx, + ); + if (!kernelResult.ok) { + const filePath = getItemFilePath("transition-kernel-code", { + transitionId: transition.id, + }); + appendSymPyDiagnostic( + itemDiagnostics, + transition.id, + "transition-kernel", + filePath, + kernelResult, + ); + } + } +} + /** * Checks the validity of an SDCPN by running TypeScript validation * on all user-provided code (transitions and differential equations). @@ -111,6 +224,9 @@ export function checkSDCPN( } } + // Run SymPy compilation checks on all code expressions + checkSymPyCompilation(sdcpn, itemDiagnostics); + return { isValid: itemDiagnostics.length === 0, itemDiagnostics, diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts new file mode 100644 index 00000000000..889adeec036 --- /dev/null +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts @@ -0,0 +1,624 @@ +import { describe, expect, it } from "vitest"; + +import { + compileToSymPy, + type SymPyCompilationContext, +} from "./compile-to-sympy"; + +const defaultContext: SymPyCompilationContext = { + parameterNames: new Set([ + "infection_rate", + "recovery_rate", + "gravitational_constant", + "earth_radius", + "satellite_radius", + "crash_threshold", + ]), + placeTokenFields: new Map([ + ["Space", ["x", "y", "direction", "velocity"]], + ["Susceptible", []], + ["Infected", []], + ]), + constructorFnName: "Lambda", +}; + +function dynamicsContext(): SymPyCompilationContext { + return { ...defaultContext, constructorFnName: "Dynamics" }; +} + +function kernelContext(): SymPyCompilationContext { + return { ...defaultContext, constructorFnName: "TransitionKernel" }; +} + +describe("compileToSymPy", () => { + describe("basic expressions", () => { + it("should compile a numeric literal", () => { + const result = compileToSymPy( + "export default Lambda(() => 1)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "1" }); + }); + + it("should compile a decimal literal", () => { + const result = compileToSymPy( + "export default Lambda(() => 3.14)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "3.14" }); + }); + + it("should compile boolean true", () => { + const result = compileToSymPy( + "export default Lambda(() => true)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "True" }); + }); + + it("should compile boolean false", () => { + const result = compileToSymPy( + "export default Lambda(() => false)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "False" }); + }); + + it("should compile Infinity", () => { + const result = compileToSymPy( + "export default Lambda(() => Infinity)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "sp.oo" }); + }); + }); + + describe("parameter access", () => { + it("should compile parameters.x to symbol x", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "infection_rate" }); + }); + + it("should compile parameters in arithmetic", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate * 2)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "infection_rate * 2", + }); + }); + }); + + describe("binary arithmetic", () => { + it("should compile addition", () => { + const result = compileToSymPy( + "export default Lambda(() => 1 + 2)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "1 + 2" }); + }); + + it("should compile subtraction", () => { + const result = compileToSymPy( + "export default Lambda(() => 5 - 3)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "5 - 3" }); + }); + + it("should compile multiplication", () => { + const result = compileToSymPy( + "export default Lambda(() => 2 * 3)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "2 * 3" }); + }); + + it("should compile division", () => { + const result = compileToSymPy( + "export default Lambda(() => 1 / 3)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "1 / 3" }); + }); + + it("should compile power operator", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.satellite_radius ** 2)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "satellite_radius**2", + }); + }); + + it("should compile modulo", () => { + const result = compileToSymPy( + "export default Lambda(() => 10 % 3)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.Mod(10, 3)", + }); + }); + }); + + describe("comparison operators", () => { + it("should compile less than", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate < 5)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "infection_rate < 5", + }); + }); + + it("should compile greater than or equal", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate >= 1)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "infection_rate >= 1", + }); + }); + + it("should compile strict equality to Eq", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate === 3)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.Eq(infection_rate, 3)", + }); + }); + + it("should compile inequality to Ne", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate !== 0)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.Ne(infection_rate, 0)", + }); + }); + }); + + describe("logical operators", () => { + it("should compile && to sp.And", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate > 0 && parameters.recovery_rate > 0)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.And(infection_rate > 0, recovery_rate > 0)", + }); + }); + + it("should compile || to sp.Or", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate === 0 || parameters.recovery_rate === 0)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.Or(sp.Eq(infection_rate, 0), sp.Eq(recovery_rate, 0))", + }); + }); + }); + + describe("prefix unary operators", () => { + it("should compile negation", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => -parameters.infection_rate)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "-(infection_rate)", + }); + }); + + it("should compile logical not", () => { + const result = compileToSymPy( + "export default Lambda(() => !true)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.Not(True)", + }); + }); + }); + + describe("Math functions", () => { + it("should compile Math.cos", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Math.cos(parameters.infection_rate))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.cos(infection_rate)", + }); + }); + + it("should compile Math.sin", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Math.sin(parameters.infection_rate))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.sin(infection_rate)", + }); + }); + + it("should compile Math.sqrt", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Math.sqrt(parameters.infection_rate))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.sqrt(infection_rate)", + }); + }); + + it("should compile Math.log", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Math.log(parameters.infection_rate))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.log(infection_rate)", + }); + }); + + it("should compile Math.exp", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Math.exp(parameters.infection_rate))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.exp(infection_rate)", + }); + }); + + it("should compile Math.abs", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Math.abs(parameters.infection_rate))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.Abs(infection_rate)", + }); + }); + + it("should compile Math.pow to exponentiation", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Math.pow(parameters.infection_rate, 2))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "(infection_rate)**(2)", + }); + }); + + it("should compile Math.hypot to sqrt of sum of squares", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Math.hypot(parameters.infection_rate, parameters.recovery_rate))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.sqrt((infection_rate)**2 + (recovery_rate)**2)", + }); + }); + + it("should compile Math.PI", () => { + const result = compileToSymPy( + "export default Lambda(() => Math.PI)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "sp.pi" }); + }); + + it("should compile Math.E", () => { + const result = compileToSymPy( + "export default Lambda(() => Math.E)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "sp.E" }); + }); + }); + + describe("token access", () => { + it("should compile tokens.Place[0].field to symbol", () => { + const result = compileToSymPy( + "export default Lambda((tokens) => tokens.Space[0].x)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "Space_0_x", + }); + }); + + it("should compile token field comparison", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => tokens.Space[0].velocity < parameters.crash_threshold)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "Space_0_velocity < crash_threshold", + }); + }); + }); + + describe("conditional (ternary) expression", () => { + it("should compile to Piecewise", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate > 1 ? parameters.infection_rate : 0)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: + "sp.Piecewise((infection_rate, infection_rate > 1), (0, True))", + }); + }); + }); + + describe("Distribution calls", () => { + it("should compile Distribution.Gaussian", () => { + const result = compileToSymPy( + "export default Lambda(() => Distribution.Gaussian(0, 1))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.stats.Normal('X', 0, 1)", + }); + }); + + it("should compile Distribution.Uniform", () => { + const result = compileToSymPy( + "export default Lambda(() => Distribution.Uniform(0, 1))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.stats.Uniform('X', 0, 1)", + }); + }); + + it("should compile Distribution.Lognormal", () => { + const result = compileToSymPy( + "export default Lambda(() => Distribution.Lognormal(0, 1))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.stats.LogNormal('X', 0, 1)", + }); + }); + }); + + describe("block body with const and return", () => { + it("should compile block body with const bindings", () => { + const result = compileToSymPy( + `export default Dynamics((tokens, parameters) => { + const mu = parameters.gravitational_constant; + return mu * 2; + })`, + dynamicsContext(), + ); + expect(result).toEqual({ + ok: true, + sympyCode: "gravitational_constant * 2", + }); + }); + + it("should compile block body with multiple const bindings", () => { + const result = compileToSymPy( + `export default Dynamics((tokens, parameters) => { + const a = parameters.infection_rate; + const b = parameters.recovery_rate; + return a + b; + })`, + dynamicsContext(), + ); + expect(result).toEqual({ + ok: true, + sympyCode: "infection_rate + recovery_rate", + }); + }); + }); + + describe("real-world expressions", () => { + it("should compile SIR infection rate lambda", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "infection_rate", + }); + }); + + it("should compile satellite crash predicate lambda", () => { + const result = compileToSymPy( + `export default Lambda((tokens, parameters) => { + const distance = Math.hypot(tokens.Space[0].x, tokens.Space[0].y); + return distance < parameters.earth_radius + parameters.crash_threshold + parameters.satellite_radius; + })`, + defaultContext, + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.sympyCode).toContain("sp.sqrt"); + expect(result.sympyCode).toContain("<"); + expect(result.sympyCode).toContain("earth_radius"); + } + }); + + it("should compile orbital dynamics expression", () => { + const result = compileToSymPy( + `export default Dynamics((tokens, parameters) => { + const mu = parameters.gravitational_constant; + const r = Math.hypot(tokens.Space[0].x, tokens.Space[0].y); + const ax = (-mu * tokens.Space[0].x) / (r * r * r); + return ax; + })`, + dynamicsContext(), + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.sympyCode).toContain("gravitational_constant"); + expect(result.sympyCode).toContain("Space_0_x"); + } + }); + + it("should compile transition kernel with object literal", () => { + const result = compileToSymPy( + `export default TransitionKernel((tokens) => { + return { + x: tokens.Space[0].x, + y: tokens.Space[0].y, + velocity: 0, + direction: 0 + }; + })`, + kernelContext(), + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.sympyCode).toContain("'x': Space_0_x"); + expect(result.sympyCode).toContain("'y': Space_0_y"); + expect(result.sympyCode).toContain("'velocity': 0"); + } + }); + }); + + describe("error handling", () => { + it("should reject code without default export", () => { + const result = compileToSymPy( + "const x = Lambda(() => 1);", + defaultContext, + ); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("No default export"); + expect(result.start).toBe(0); + expect(result.length).toBe(0); + } + }); + + it("should reject wrong constructor function name with position", () => { + const code = "export default WrongName(() => 1)"; + const result = compileToSymPy(code, defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("Expected Lambda(...)"); + // "WrongName" starts at position 15 + expect(result.start).toBe(code.indexOf("WrongName")); + expect(result.length).toBe("WrongName".length); + } + }); + + it("should reject unsupported Math function with position", () => { + const code = "export default Lambda(() => Math.random())"; + const result = compileToSymPy(code, defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("Unsupported Math function"); + // Points to the "Math.random" callee + expect(result.start).toBe(code.indexOf("Math.random")); + expect(result.length).toBe("Math.random".length); + } + }); + + it("should reject if statements in block body with position", () => { + const code = `export default Lambda((tokens, parameters) => { + if (parameters.infection_rate > 1) { + return parameters.infection_rate; + } + return 0; + })`; + const result = compileToSymPy(code, defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("Unsupported statement"); + expect(result.start).toBe(code.indexOf("if")); + expect(result.length).toBeGreaterThan(0); + } + }); + + it("should reject let declarations with position", () => { + const code = `export default Lambda(() => { + let x = 1; + return x; + })`; + const result = compileToSymPy(code, defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("let"); + expect(result.start).toBe(code.indexOf("let x = 1;")); + expect(result.length).toBe("let x = 1;".length); + } + }); + + it("should reject string literals with position", () => { + const code = `export default Lambda(() => "hello")`; + const result = compileToSymPy(code, defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("String literals"); + expect(result.start).toBe(code.indexOf('"hello"')); + expect(result.length).toBe('"hello"'.length); + } + }); + + it("should reject unsupported function calls with position", () => { + const code = `export default Lambda(() => console.log(1))`; + const result = compileToSymPy(code, defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("Unsupported function call"); + expect(result.start).toBe(code.indexOf("console.log(1)")); + expect(result.length).toBe("console.log(1)".length); + } + }); + + it("should reject unsupported binary operator with position", () => { + const code = `export default Lambda(() => 1 << 2)`; + const result = compileToSymPy(code, defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("Unsupported binary operator"); + expect(result.start).toBe(code.indexOf("<<")); + expect(result.length).toBe("<<".length); + } + }); + }); +}); diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts new file mode 100644 index 00000000000..c91274ed817 --- /dev/null +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts @@ -0,0 +1,767 @@ +import ts from "typescript"; + +import type { SDCPN, Transition } from "../../core/types/sdcpn"; + +/** + * Context for SymPy compilation, derived from the SDCPN model. + * Tells the compiler which identifiers are parameters vs. token fields. + */ +export type SymPyCompilationContext = { + parameterNames: Set; + /** Maps place name to its token field names */ + placeTokenFields: Map; + constructorFnName: string; +}; + +/** + * Builds a SymPyCompilationContext from an SDCPN model for a given transition. + */ +export function buildContextForTransition( + sdcpn: SDCPN, + transition: Transition, + constructorFnName: string, +): SymPyCompilationContext { + const parameterNames = new Set( + sdcpn.parameters.map((param) => param.variableName), + ); + const placeTokenFields = new Map(); + + const placeById = new Map(sdcpn.places.map((pl) => [pl.id, pl])); + const colorById = new Map(sdcpn.types.map((ct) => [ct.id, ct])); + + for (const arc of transition.inputArcs) { + const place = placeById.get(arc.placeId); + if (!place?.colorId) { + continue; + } + const color = colorById.get(place.colorId); + if (!color) { + continue; + } + placeTokenFields.set( + place.name, + color.elements.map((el) => el.name), + ); + } + + return { parameterNames, placeTokenFields, constructorFnName }; +} + +/** + * Builds a SymPyCompilationContext from an SDCPN model for a differential equation. + */ +export function buildContextForDifferentialEquation( + sdcpn: SDCPN, + colorId: string, +): SymPyCompilationContext { + const parameterNames = new Set( + sdcpn.parameters.map((param) => param.variableName), + ); + const placeTokenFields = new Map(); + + const color = sdcpn.types.find((ct) => ct.id === colorId); + if (color) { + // DE operates on tokens of its color type + placeTokenFields.set( + color.name, + color.elements.map((el) => el.name), + ); + } + + return { parameterNames, placeTokenFields, constructorFnName: "Dynamics" }; +} + +export type SymPyResult = + | { ok: true; sympyCode: string } + | { ok: false; error: string; start: number; length: number }; + +/** Shorthand for building an error result with position from a TS AST node. */ +function err( + error: string, + node: ts.Node, + sourceFile: ts.SourceFile, +): SymPyResult & { ok: false } { + return { + ok: false, + error, + start: node.getStart(sourceFile), + length: node.getWidth(sourceFile), + }; +} + +/** Error result for cases where no specific node is available. */ +function errNoPos(error: string): SymPyResult & { ok: false } { + return { ok: false, error, start: 0, length: 0 }; +} + +/** + * Compiles a Petrinaut TypeScript expression to SymPy Python code. + * + * Expects code following the pattern: + * `export default ConstructorFn((params...) => expression)` + * + * Only a restricted subset of TypeScript is supported — pure expressions + * with arithmetic, Math functions, parameter/token access, and distributions. + * Anything outside this subset is rejected with a diagnostic. + * + * @param code - The TypeScript expression code string + * @param context - Compilation context with parameter names and token fields + * @returns Either `{ ok: true, sympyCode }` or `{ ok: false, error }` + */ +export function compileToSymPy( + code: string, + context: SymPyCompilationContext, +): SymPyResult { + const sourceFile = ts.createSourceFile( + "input.ts", + code, + ts.ScriptTarget.ES2015, + true, + ); + + // Find the default export + const exportAssignment = sourceFile.statements.find( + (stmt): stmt is ts.ExportAssignment => + ts.isExportAssignment(stmt) && !stmt.isExportEquals, + ); + + if (!exportAssignment) { + // Try export default as ExpressionStatement pattern + const exportDefault = sourceFile.statements.find((stmt) => { + if (ts.isExportAssignment(stmt)) { + return true; + } + // Handle "export default X(...)" which parses as ExportAssignment + return false; + }); + if (!exportDefault) { + return errNoPos("No default export found"); + } + } + + const exportExpr = exportAssignment!.expression; + + // Expect ConstructorFn(...) + if (!ts.isCallExpression(exportExpr)) { + return err( + `Expected ${context.constructorFnName}(...), got ${ts.SyntaxKind[exportExpr.kind]}`, + exportExpr, + sourceFile, + ); + } + + const callee = exportExpr.expression; + if (!ts.isIdentifier(callee) || callee.text !== context.constructorFnName) { + return err( + `Expected ${context.constructorFnName}(...), got ${callee.getText(sourceFile)}(...)`, + callee, + sourceFile, + ); + } + + if (exportExpr.arguments.length !== 1) { + return err( + `${context.constructorFnName} expects exactly one argument`, + exportExpr, + sourceFile, + ); + } + + const arg = exportExpr.arguments[0]!; + + // The argument should be an arrow function or function expression + if (!ts.isArrowFunction(arg) && !ts.isFunctionExpression(arg)) { + return err( + `Expected a function argument, got ${ts.SyntaxKind[arg.kind]}`, + arg, + sourceFile, + ); + } + + // Extract parameter names for the inner function + const localBindings = new Map(); + const innerParams = extractFunctionParams(arg, sourceFile); + + // Compile the body + const body = arg.body; + + if (ts.isBlock(body)) { + return compileBlock(body, context, localBindings, sourceFile); + } + + // Expression body — emit directly + const result = emitSymPy( + body, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!result.ok) return result; + return { ok: true, sympyCode: result.sympyCode }; +} + +function extractFunctionParams( + fn: ts.ArrowFunction | ts.FunctionExpression, + sourceFile: ts.SourceFile, +): string[] { + return fn.parameters.map((p) => p.name.getText(sourceFile)); +} + +function compileBlock( + block: ts.Block, + context: SymPyCompilationContext, + localBindings: Map, + sourceFile: ts.SourceFile, +): SymPyResult { + const lines: string[] = []; + + for (const stmt of block.statements) { + if (ts.isVariableStatement(stmt)) { + for (const decl of stmt.declarationList.declarations) { + if (!decl.initializer) { + return err( + "Variable declaration without initializer", + decl, + sourceFile, + ); + } + if (stmt.declarationList.flags & ts.NodeFlags.Let) { + return err( + "'let' declarations are not supported, use 'const'", + stmt, + sourceFile, + ); + } + const name = decl.name.getText(sourceFile); + const valueResult = emitSymPy( + decl.initializer, + context, + localBindings, + [], + sourceFile, + ); + if (!valueResult.ok) return valueResult; + localBindings.set(name, valueResult.sympyCode); + lines.push(`${name} = ${valueResult.sympyCode}`); + } + } else if (ts.isReturnStatement(stmt)) { + if (!stmt.expression) { + return err("Empty return statement", stmt, sourceFile); + } + const result = emitSymPy( + stmt.expression, + context, + localBindings, + [], + sourceFile, + ); + if (!result.ok) return result; + lines.push(result.sympyCode); + } else if (ts.isExpressionStatement(stmt)) { + // Allow comments parsed as expression statements, skip them + continue; + } else { + return err( + `Unsupported statement: ${ts.SyntaxKind[stmt.kind]}`, + stmt, + sourceFile, + ); + } + } + + if (lines.length === 0) { + return err("Empty function body", block, sourceFile); + } + + return { ok: true, sympyCode: lines[lines.length - 1]! }; +} + +const MATH_FUNCTION_MAP: Record = { + cos: "sp.cos", + sin: "sp.sin", + tan: "sp.tan", + acos: "sp.acos", + asin: "sp.asin", + atan: "sp.atan", + atan2: "sp.atan2", + sqrt: "sp.sqrt", + log: "sp.log", + exp: "sp.exp", + abs: "sp.Abs", + floor: "sp.floor", + ceil: "sp.ceiling", + pow: "sp.Pow", + min: "sp.Min", + max: "sp.Max", +}; + +const MATH_CONSTANT_MAP: Record = { + PI: "sp.pi", + E: "sp.E", + Infinity: "sp.oo", +}; + +function emitSymPy( + node: ts.Node, + context: SymPyCompilationContext, + localBindings: Map, + innerParams: string[], + sourceFile: ts.SourceFile, +): SymPyResult { + // Numeric literal + if (ts.isNumericLiteral(node)) { + return { ok: true, sympyCode: node.text }; + } + + // String literal — not supported in symbolic math + if (ts.isStringLiteral(node)) { + return err( + "String literals are not supported in symbolic expressions", + node, + sourceFile, + ); + } + + // Boolean literals + if (node.kind === ts.SyntaxKind.TrueKeyword) { + return { ok: true, sympyCode: "True" }; + } + if (node.kind === ts.SyntaxKind.FalseKeyword) { + return { ok: true, sympyCode: "False" }; + } + + // Identifier + if (ts.isIdentifier(node)) { + const name = node.text; + if (name === "Infinity") return { ok: true, sympyCode: "sp.oo" }; + if (localBindings.has(name)) { + return { ok: true, sympyCode: localBindings.get(name)! }; + } + if (context.parameterNames.has(name)) { + return { ok: true, sympyCode: name }; + } + // Could be a destructured token field or function param + return { ok: true, sympyCode: name }; + } + + // Parenthesized expression + if (ts.isParenthesizedExpression(node)) { + const inner = emitSymPy( + node.expression, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!inner.ok) return inner; + return { ok: true, sympyCode: `(${inner.sympyCode})` }; + } + + // Prefix unary expression (-x, !x) + if (ts.isPrefixUnaryExpression(node)) { + const operand = emitSymPy( + node.operand, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!operand.ok) return operand; + + switch (node.operator) { + case ts.SyntaxKind.MinusToken: + return { ok: true, sympyCode: `-(${operand.sympyCode})` }; + case ts.SyntaxKind.ExclamationToken: + return { ok: true, sympyCode: `sp.Not(${operand.sympyCode})` }; + case ts.SyntaxKind.PlusToken: + return operand; + default: + return err( + `Unsupported prefix operator: ${ts.SyntaxKind[node.operator]}`, + node, + sourceFile, + ); + } + } + + // Binary expression + if (ts.isBinaryExpression(node)) { + const left = emitSymPy( + node.left, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!left.ok) return left; + const right = emitSymPy( + node.right, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!right.ok) return right; + + switch (node.operatorToken.kind) { + case ts.SyntaxKind.PlusToken: + return { + ok: true, + sympyCode: `${left.sympyCode} + ${right.sympyCode}`, + }; + case ts.SyntaxKind.MinusToken: + return { + ok: true, + sympyCode: `${left.sympyCode} - ${right.sympyCode}`, + }; + case ts.SyntaxKind.AsteriskToken: + return { + ok: true, + sympyCode: `${left.sympyCode} * ${right.sympyCode}`, + }; + case ts.SyntaxKind.SlashToken: + return { + ok: true, + sympyCode: `${left.sympyCode} / ${right.sympyCode}`, + }; + case ts.SyntaxKind.AsteriskAsteriskToken: + return { + ok: true, + sympyCode: `${left.sympyCode}**${right.sympyCode}`, + }; + case ts.SyntaxKind.PercentToken: + return { + ok: true, + sympyCode: `sp.Mod(${left.sympyCode}, ${right.sympyCode})`, + }; + case ts.SyntaxKind.LessThanToken: + return { + ok: true, + sympyCode: `${left.sympyCode} < ${right.sympyCode}`, + }; + case ts.SyntaxKind.LessThanEqualsToken: + return { + ok: true, + sympyCode: `${left.sympyCode} <= ${right.sympyCode}`, + }; + case ts.SyntaxKind.GreaterThanToken: + return { + ok: true, + sympyCode: `${left.sympyCode} > ${right.sympyCode}`, + }; + case ts.SyntaxKind.GreaterThanEqualsToken: + return { + ok: true, + sympyCode: `${left.sympyCode} >= ${right.sympyCode}`, + }; + case ts.SyntaxKind.EqualsEqualsToken: + case ts.SyntaxKind.EqualsEqualsEqualsToken: + return { + ok: true, + sympyCode: `sp.Eq(${left.sympyCode}, ${right.sympyCode})`, + }; + case ts.SyntaxKind.ExclamationEqualsToken: + case ts.SyntaxKind.ExclamationEqualsEqualsToken: + return { + ok: true, + sympyCode: `sp.Ne(${left.sympyCode}, ${right.sympyCode})`, + }; + case ts.SyntaxKind.AmpersandAmpersandToken: + return { + ok: true, + sympyCode: `sp.And(${left.sympyCode}, ${right.sympyCode})`, + }; + case ts.SyntaxKind.BarBarToken: + return { + ok: true, + sympyCode: `sp.Or(${left.sympyCode}, ${right.sympyCode})`, + }; + default: + return err( + `Unsupported binary operator: ${node.operatorToken.getText(sourceFile)}`, + node.operatorToken, + sourceFile, + ); + } + } + + // Conditional (ternary) expression + if (ts.isConditionalExpression(node)) { + const condition = emitSymPy( + node.condition, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!condition.ok) return condition; + const whenTrue = emitSymPy( + node.whenTrue, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!whenTrue.ok) return whenTrue; + const whenFalse = emitSymPy( + node.whenFalse, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!whenFalse.ok) return whenFalse; + return { + ok: true, + sympyCode: `sp.Piecewise((${whenTrue.sympyCode}, ${condition.sympyCode}), (${whenFalse.sympyCode}, True))`, + }; + } + + // Property access: parameters.x, tokens.Place[0].field, Math.PI + if (ts.isPropertyAccessExpression(node)) { + const propName = node.name.text; + + // Math constants: Math.PI, Math.E + if (ts.isIdentifier(node.expression) && node.expression.text === "Math") { + const constant = MATH_CONSTANT_MAP[propName]; + if (constant) return { ok: true, sympyCode: constant }; + // Math.method will be handled as part of a CallExpression + // Return a placeholder that the call expression handler will use + return { ok: true, sympyCode: `Math.${propName}` }; + } + + // parameters.x + if ( + ts.isIdentifier(node.expression) && + node.expression.text === "parameters" + ) { + return { ok: true, sympyCode: propName }; + } + + // tokens.Place[0].field — handle the chain + // First check: something.field where something is an element access + if (ts.isElementAccessExpression(node.expression)) { + // e.g., tokens.Space[0].x + const elemAccess = node.expression; + if (ts.isPropertyAccessExpression(elemAccess.expression)) { + const placePropAccess = elemAccess.expression; + if ( + ts.isIdentifier(placePropAccess.expression) && + placePropAccess.expression.text === "tokens" + ) { + const placeName = placePropAccess.name.text; + const indexExpr = elemAccess.argumentExpression; + const indexText = indexExpr.getText(sourceFile); + return { + ok: true, + sympyCode: `${placeName}_${indexText}_${propName}`, + }; + } + } + } + + // Generic property access — emit as dot access + const obj = emitSymPy( + node.expression, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!obj.ok) return obj; + return { ok: true, sympyCode: `${obj.sympyCode}_${propName}` }; + } + + // Element access: tokens.Place[0], arr[i] + if (ts.isElementAccessExpression(node)) { + const obj = emitSymPy( + node.expression, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!obj.ok) return obj; + const index = emitSymPy( + node.argumentExpression, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!index.ok) return index; + return { ok: true, sympyCode: `${obj.sympyCode}_${index.sympyCode}` }; + } + + // Call expression: Math.cos(x), Math.hypot(a, b), Distribution.Gaussian(m, s) + if (ts.isCallExpression(node)) { + const callee = node.expression; + + // Math.fn(...) + if ( + ts.isPropertyAccessExpression(callee) && + ts.isIdentifier(callee.expression) && + callee.expression.text === "Math" + ) { + const fnName = callee.name.text; + + // Special case: Math.hypot(a, b) -> sp.sqrt(a**2 + b**2) + if (fnName === "hypot") { + const args: string[] = []; + for (const a of node.arguments) { + const r = emitSymPy( + a, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!r.ok) return r; + args.push(r.sympyCode); + } + const sumOfSquares = args.map((a) => `(${a})**2`).join(" + "); + return { ok: true, sympyCode: `sp.sqrt(${sumOfSquares})` }; + } + + // Special case: Math.pow(a, b) -> a**b + if (fnName === "pow" && node.arguments.length === 2) { + const base = emitSymPy( + node.arguments[0]!, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!base.ok) return base; + const exp = emitSymPy( + node.arguments[1]!, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!exp.ok) return exp; + return { + ok: true, + sympyCode: `(${base.sympyCode})**(${exp.sympyCode})`, + }; + } + + const sympyFn = MATH_FUNCTION_MAP[fnName]; + if (!sympyFn) { + return err( + `Unsupported Math function: Math.${fnName}`, + callee, + sourceFile, + ); + } + + const args: string[] = []; + for (const a of node.arguments) { + const r = emitSymPy(a, context, localBindings, innerParams, sourceFile); + if (!r.ok) return r; + args.push(r.sympyCode); + } + return { ok: true, sympyCode: `${sympyFn}(${args.join(", ")})` }; + } + + // Distribution.Gaussian(m, s), Distribution.Uniform(a, b), Distribution.Lognormal(mu, sigma) + if ( + ts.isPropertyAccessExpression(callee) && + ts.isIdentifier(callee.expression) && + callee.expression.text === "Distribution" + ) { + const distName = callee.name.text; + const args: string[] = []; + for (const a of node.arguments) { + const r = emitSymPy(a, context, localBindings, innerParams, sourceFile); + if (!r.ok) return r; + args.push(r.sympyCode); + } + + switch (distName) { + case "Gaussian": + return { + ok: true, + sympyCode: `sp.stats.Normal('X', ${args.join(", ")})`, + }; + case "Uniform": + return { + ok: true, + sympyCode: `sp.stats.Uniform('X', ${args.join(", ")})`, + }; + case "Lognormal": + return { + ok: true, + sympyCode: `sp.stats.LogNormal('X', ${args.join(", ")})`, + }; + default: + return err( + `Unsupported distribution: Distribution.${distName}`, + callee, + sourceFile, + ); + } + } + + return err( + `Unsupported function call: ${callee.getText(sourceFile)}`, + node, + sourceFile, + ); + } + + // Object literal expression { field: expr, ... } + if (ts.isObjectLiteralExpression(node)) { + const entries: string[] = []; + for (const prop of node.properties) { + if (!ts.isPropertyAssignment(prop)) { + return err( + `Unsupported object property kind: ${ts.SyntaxKind[prop.kind]}`, + prop, + sourceFile, + ); + } + const key = prop.name.getText(sourceFile); + const val = emitSymPy( + prop.initializer, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!val.ok) return val; + entries.push(`'${key}': ${val.sympyCode}`); + } + return { ok: true, sympyCode: `{${entries.join(", ")}}` }; + } + + // Non-null assertion (x!) — just unwrap + if (ts.isNonNullExpression(node)) { + return emitSymPy( + node.expression, + context, + localBindings, + innerParams, + sourceFile, + ); + } + + // Type assertion (x as T) — just unwrap + if (ts.isAsExpression(node)) { + return emitSymPy( + node.expression, + context, + localBindings, + innerParams, + sourceFile, + ); + } + + return err( + `Unsupported syntax: ${ts.SyntaxKind[node.kind]}`, + node, + sourceFile, + ); +} diff --git a/libs/@hashintel/petrinaut/src/views/Editor/editor-view.tsx b/libs/@hashintel/petrinaut/src/views/Editor/editor-view.tsx index e1c2a62916a..c99f58a3a92 100644 --- a/libs/@hashintel/petrinaut/src/views/Editor/editor-view.tsx +++ b/libs/@hashintel/petrinaut/src/views/Editor/editor-view.tsx @@ -27,6 +27,7 @@ import { import { BottomBar } from "./components/BottomBar/bottom-bar"; import { ImportErrorDialog } from "./components/import-error-dialog"; import { TopBar } from "./components/TopBar/top-bar"; +import { exportWithSymPy } from "./lib/export-sympy"; import { exportTikZ } from "./lib/export-tikz"; import { BottomPanel } from "./panels/BottomPanel/panel"; import { LeftSideBar } from "./panels/LeftSideBar/panel"; @@ -156,6 +157,10 @@ export const EditorView = ({ exportTikZ({ petriNetDefinition, title }); } + function handleExportWithSymPy() { + exportWithSymPy({ petriNetDefinition, title }); + } + async function handleImport() { const result = await importSDCPN(); if (!result) { @@ -244,6 +249,11 @@ export const EditorView = ({ label: "TikZ", onClick: handleExportTikZ, }, + { + id: "export-sympy", + label: "JSON with SymPy expressions", + onClick: handleExportWithSymPy, + }, ], }, { diff --git a/libs/@hashintel/petrinaut/src/views/Editor/lib/export-sympy.ts b/libs/@hashintel/petrinaut/src/views/Editor/lib/export-sympy.ts new file mode 100644 index 00000000000..281549292d3 --- /dev/null +++ b/libs/@hashintel/petrinaut/src/views/Editor/lib/export-sympy.ts @@ -0,0 +1,93 @@ +import type { SDCPN } from "../../../core/types/sdcpn"; +import { + buildContextForDifferentialEquation, + buildContextForTransition, + compileToSymPy, +} from "../../../simulation/simulator/compile-to-sympy"; + +type SymPyExpression = { + name: string; + type: string; + sympyCode: string | null; + error: string | null; +}; + +/** + * Converts all expressions in an SDCPN model to SymPy and produces a JSON + * export containing both the original model and the SymPy representations. + */ +export function exportWithSymPy({ + petriNetDefinition, + title, +}: { + petriNetDefinition: SDCPN; + title: string; +}): void { + const expressions: SymPyExpression[] = []; + + // Convert differential equation expressions + for (const de of petriNetDefinition.differentialEquations) { + const ctx = buildContextForDifferentialEquation( + petriNetDefinition, + de.colorId, + ); + const result = compileToSymPy(de.code, ctx); + expressions.push({ + name: de.name, + type: "differential-equation", + sympyCode: result.ok ? result.sympyCode : null, + error: result.ok ? null : result.error, + }); + } + + // Convert transition lambda and kernel expressions + for (const transition of petriNetDefinition.transitions) { + const lambdaCtx = buildContextForTransition( + petriNetDefinition, + transition, + "Lambda", + ); + const lambdaResult = compileToSymPy(transition.lambdaCode, lambdaCtx); + expressions.push({ + name: `${transition.name} (lambda)`, + type: "transition-lambda", + sympyCode: lambdaResult.ok ? lambdaResult.sympyCode : null, + error: lambdaResult.ok ? null : lambdaResult.error, + }); + + const kernelCtx = buildContextForTransition( + petriNetDefinition, + transition, + "TransitionKernel", + ); + const kernelResult = compileToSymPy( + transition.transitionKernelCode, + kernelCtx, + ); + expressions.push({ + name: `${transition.name} (kernel)`, + type: "transition-kernel", + sympyCode: kernelResult.ok ? kernelResult.sympyCode : null, + error: kernelResult.ok ? null : kernelResult.error, + }); + } + + const exportData = { + title, + sympy_expressions: expressions, + ...petriNetDefinition, + }; + + const jsonString = JSON.stringify(exportData, null, 2); + const blob = new Blob([jsonString], { type: "application/json" }); + const url = URL.createObjectURL(blob); + const link = document.createElement("a"); + link.href = url; + link.download = `${title.replace(/[^a-z0-9]/gi, "_").toLowerCase()}_sympy_${new Date().toISOString().replace(/:/g, "-")}.json`; + + document.body.appendChild(link); + link.click(); + + document.body.removeChild(link); + URL.revokeObjectURL(url); +} From f9c62710c7653c9eddf945b6e85b7d1bc03875c6 Mon Sep 17 00:00:00 2001 From: Chris Feijoo Date: Fri, 13 Mar 2026 02:47:45 +0100 Subject: [PATCH 02/14] FE-514: Support Boolean() and Number() global calls in SymPy compiler Boolean(expr) maps to sp.Ne(expr, 0) matching JS truthiness semantics. Number(expr) passes through as identity in symbolic math context. Co-Authored-By: Claude Opus 4.6 --- .../simulator/compile-to-sympy.test.ts | 49 +++++++++++++++++++ .../simulation/simulator/compile-to-sympy.ts | 25 ++++++++++ 2 files changed, 74 insertions(+) diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts index 889adeec036..5de3e91ded4 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts @@ -423,6 +423,55 @@ describe("compileToSymPy", () => { }); }); + describe("global built-in functions", () => { + it("should compile Boolean(expr) to sp.Ne(expr, 0)", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Boolean(parameters.infection_rate))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.Ne(infection_rate, 0)", + }); + }); + + it("should compile Boolean with arithmetic expression", () => { + const result = compileToSymPy( + "export default Lambda(() => Boolean(1 + 2))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.Ne(1 + 2, 0)", + }); + }); + + it("should compile Number(expr) as identity", () => { + const result = compileToSymPy( + "export default Lambda(() => Number(true))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "True", + }); + }); + + it("should compile Boolean in block body with return", () => { + const result = compileToSymPy( + `export default Lambda((tokens, parameters) => { + const sum = parameters.infection_rate + parameters.recovery_rate; + return Boolean(sum); + })`, + defaultContext, + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.sympyCode).toContain("sp.Ne"); + } + }); + }); + describe("block body with const and return", () => { it("should compile block body with const bindings", () => { const result = compileToSymPy( diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts index c91274ed817..8038280a8fc 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts @@ -705,6 +705,31 @@ function emitSymPy( } } + // Global built-in functions: Boolean(expr), Number(expr) + if (ts.isIdentifier(callee)) { + if (callee.text === "Boolean" && node.arguments.length === 1) { + const arg = emitSymPy( + node.arguments[0]!, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!arg.ok) return arg; + return { ok: true, sympyCode: `sp.Ne(${arg.sympyCode}, 0)` }; + } + + if (callee.text === "Number" && node.arguments.length === 1) { + return emitSymPy( + node.arguments[0]!, + context, + localBindings, + innerParams, + sourceFile, + ); + } + } + return err( `Unsupported function call: ${callee.getText(sourceFile)}`, node, From 1470a080b83a1f70de2cd126c85e1b0e3a4481cb Mon Sep 17 00:00:00 2001 From: Chris Feijoo Date: Fri, 13 Mar 2026 02:51:24 +0100 Subject: [PATCH 03/14] FE-514: Support array literal expressions in SymPy compiler Maps TypeScript array literals to Python lists, enabling transition kernels that return arrays of token objects to compile to SymPy. Co-Authored-By: Claude Opus 4.6 --- .../simulator/compile-to-sympy.test.ts | 38 +++++++++++++++++++ .../simulation/simulator/compile-to-sympy.ts | 17 +++++++++ 2 files changed, 55 insertions(+) diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts index 5de3e91ded4..706cdbd6780 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts @@ -567,6 +567,44 @@ describe("compileToSymPy", () => { expect(result.sympyCode).toContain("'velocity': 0"); } }); + + it("should compile transition kernel with array of objects", () => { + const result = compileToSymPy( + `export default TransitionKernel((tokens) => { + return { + Debris: [ + { + x: tokens.Space[0].x, + y: tokens.Space[0].y, + velocity: 0, + direction: 0 + }, + { + x: tokens.Space[1].x, + y: tokens.Space[1].y, + velocity: 0, + direction: 0 + }, + ] + }; + })`, + kernelContext(), + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.sympyCode).toContain("'Debris': ["); + expect(result.sympyCode).toContain("'x': Space_0_x"); + expect(result.sympyCode).toContain("'x': Space_1_x"); + } + }); + + it("should compile simple array literal", () => { + const result = compileToSymPy( + "export default Lambda(() => [1, 2, 3])", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "[1, 2, 3]" }); + }); }); describe("error handling", () => { diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts index 8038280a8fc..1b3a6897fcb 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts @@ -737,6 +737,23 @@ function emitSymPy( ); } + // Array literal expression [a, b, c] + if (ts.isArrayLiteralExpression(node)) { + const elements: string[] = []; + for (const elem of node.elements) { + const result = emitSymPy( + elem, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!result.ok) return result; + elements.push(result.sympyCode); + } + return { ok: true, sympyCode: `[${elements.join(", ")}]` }; + } + // Object literal expression { field: expr, ... } if (ts.isObjectLiteralExpression(node)) { const entries: string[] = []; From 82bcfbd4690afbe301e6b321323d7dde2aaacaa7 Mon Sep 17 00:00:00 2001 From: Chris Feijoo Date: Fri, 13 Mar 2026 03:01:52 +0100 Subject: [PATCH 04/14] FE-514: Support .map() as Python list comprehension in SymPy compiler Compiles tokens.map(callback) to [body for _iter in collection], handling both destructured ({ x, y }) and simple (token) parameters. Enables dynamics expressions that iterate over token arrays. Co-Authored-By: Claude Opus 4.6 --- .../simulator/compile-to-sympy.test.ts | 51 ++++++++++++ .../simulation/simulator/compile-to-sympy.ts | 81 +++++++++++++++++++ 2 files changed, 132 insertions(+) diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts index 706cdbd6780..6a653dbd4ca 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts @@ -607,6 +607,57 @@ describe("compileToSymPy", () => { }); }); + describe(".map() list comprehension", () => { + it("should compile tokens.map with destructured params", () => { + const result = compileToSymPy( + `export default Dynamics((tokens, parameters) => { + const mu = parameters.gravitational_constant; + return tokens.map(({ x, y, direction, velocity }) => { + const r = Math.hypot(x, y); + const ax = (-mu * x) / (r * r * r); + const ay = (-mu * y) / (r * r * r); + return { + x: velocity * Math.cos(direction), + y: velocity * Math.sin(direction), + direction: (-ax * Math.sin(direction) + ay * Math.cos(direction)) / velocity, + velocity: ax * Math.cos(direction) + ay * Math.sin(direction), + }; + }); + })`, + dynamicsContext(), + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.sympyCode).toContain("for _iter in tokens"); + expect(result.sympyCode).toContain("_iter_x"); + expect(result.sympyCode).toContain("_iter_velocity"); + expect(result.sympyCode).toContain("sp.cos(_iter_direction)"); + } + }); + + it("should compile simple .map with identifier param", () => { + const result = compileToSymPy( + `export default Lambda((tokens) => tokens.map((token) => token + 1))`, + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "[_iter + 1 for _iter in tokens]", + }); + }); + + it("should compile .map with expression body", () => { + const result = compileToSymPy( + `export default Lambda((tokens, parameters) => tokens.map(({ x }) => x * parameters.infection_rate))`, + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "[_iter_x * infection_rate for _iter in tokens]", + }); + }); + }); + describe("error handling", () => { it("should reject code without default export", () => { const result = compileToSymPy( diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts index 1b3a6897fcb..f877e0e6b01 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts @@ -277,6 +277,68 @@ function compileBlock( return { ok: true, sympyCode: lines[lines.length - 1]! }; } +/** + * Compiles `collection.map(callback)` to a Python list comprehension. + * + * Handles two callback parameter styles: + * - Destructured: `({ x, y }) => ...` → binds each field as `_iter_x`, `_iter_y` + * - Simple identifier: `(token) => ...` → binds as-is + * + * Emits: `[ for _iter in ]` + */ +function compileMapCall( + collection: ts.Expression, + callback: ts.ArrowFunction | ts.FunctionExpression, + context: SymPyCompilationContext, + outerBindings: Map, + innerParams: string[], + sourceFile: ts.SourceFile, +): SymPyResult { + const iterVar = "_iter"; + const mapBindings = new Map(outerBindings); + + const param = callback.parameters[0]; + if (param) { + const paramName = param.name; + if (ts.isObjectBindingPattern(paramName)) { + // Destructured: ({ x, y, ... }) => ... + // Each field becomes a symbol like _iter_x, _iter_y + for (const element of paramName.elements) { + const fieldName = element.name.getText(sourceFile); + mapBindings.set(fieldName, `${iterVar}_${fieldName}`); + } + } else { + // Simple identifier: (token) => ... + mapBindings.set(paramName.getText(sourceFile), iterVar); + } + } + + // Compile the body + const body = callback.body; + let bodyResult: SymPyResult; + if (ts.isBlock(body)) { + bodyResult = compileBlock(body, context, mapBindings, sourceFile); + } else { + bodyResult = emitSymPy(body, context, mapBindings, innerParams, sourceFile); + } + if (!bodyResult.ok) return bodyResult; + + // Compile the collection expression + const collectionResult = emitSymPy( + collection, + context, + outerBindings, + innerParams, + sourceFile, + ); + if (!collectionResult.ok) return collectionResult; + + return { + ok: true, + sympyCode: `[${bodyResult.sympyCode} for ${iterVar} in ${collectionResult.sympyCode}]`, + }; +} + const MATH_FUNCTION_MAP: Record = { cos: "sp.cos", sin: "sp.sin", @@ -730,6 +792,25 @@ function emitSymPy( } } + // .map(callback) on arrays/tokens — emit as Python list comprehension + if ( + ts.isPropertyAccessExpression(callee) && + callee.name.text === "map" && + node.arguments.length === 1 + ) { + const callback = node.arguments[0]!; + if (ts.isArrowFunction(callback) || ts.isFunctionExpression(callback)) { + return compileMapCall( + callee.expression, + callback, + context, + localBindings, + innerParams, + sourceFile, + ); + } + } + return err( `Unsupported function call: ${callee.getText(sourceFile)}`, node, From 1b64b08ac573e8d5bb84fe6e767a93011ec53a5b Mon Sep 17 00:00:00 2001 From: Chris Feijoo Date: Fri, 13 Mar 2026 03:09:26 +0100 Subject: [PATCH 05/14] FE-514: Reject var declarations and standalone expressions in SymPy compiler MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Only const is allowed — both let and var are now rejected. Standalone expression statements (not assigned or returned) produce a diagnostic. Co-Authored-By: Claude Opus 4.6 --- .../simulator/compile-to-sympy.test.ts | 30 +++++++++++++++++++ .../simulation/simulator/compile-to-sympy.ts | 11 ++++--- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts index 6a653dbd4ca..0ae9e4eef5d 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts @@ -726,6 +726,19 @@ describe("compileToSymPy", () => { } }); + it("should reject var declarations", () => { + const code = `export default Lambda(() => { + var x = 1; + return x; + })`; + const result = compileToSymPy(code, defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("var"); + expect(result.error).toContain("use 'const'"); + } + }); + it("should reject string literals with position", () => { const code = `export default Lambda(() => "hello")`; const result = compileToSymPy(code, defaultContext); @@ -748,6 +761,23 @@ describe("compileToSymPy", () => { } }); + it("should reject standalone expression statements", () => { + const code = `export default Lambda((tokensByPlace, parameters) => { + const a = Boolean(1 + 2); + Boolean(1 + 2); + return Boolean(1 + 2); + })`; + const result = compileToSymPy(code, defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("Standalone expression has no effect"); + // The standalone expression is the second line in the block + const standalonePos = code.indexOf("\n Boolean(1 + 2);") + 11; + expect(result.start).toBe(standalonePos); + expect(result.length).toBe("Boolean(1 + 2);".length); + } + }); + it("should reject unsupported binary operator with position", () => { const code = `export default Lambda(() => 1 << 2)`; const result = compileToSymPy(code, defaultContext); diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts index f877e0e6b01..7d00e5a34ba 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts @@ -226,9 +226,9 @@ function compileBlock( sourceFile, ); } - if (stmt.declarationList.flags & ts.NodeFlags.Let) { + if (!(stmt.declarationList.flags & ts.NodeFlags.Const)) { return err( - "'let' declarations are not supported, use 'const'", + "'let' and 'var' declarations are not supported, use 'const'", stmt, sourceFile, ); @@ -259,8 +259,11 @@ function compileBlock( if (!result.ok) return result; lines.push(result.sympyCode); } else if (ts.isExpressionStatement(stmt)) { - // Allow comments parsed as expression statements, skip them - continue; + return err( + "Standalone expression has no effect — assign to a const or return it", + stmt, + sourceFile, + ); } else { return err( `Unsupported statement: ${ts.SyntaxKind[stmt.kind]}`, From 06bcd93af2ae180beffa7c878773377a5e4a676a Mon Sep 17 00:00:00 2001 From: Chris Feijoo Date: Wed, 18 Mar 2026 01:11:55 +0100 Subject: [PATCH 06/14] FE-514: Introduce JSON IR between TypeScript and SymPy compilation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refactors the compiler into two layers: TypeScript → Expression IR (JSON) → SymPy. The IR captures bindings, expressions, distributions, and derived distributions as a typed JSON AST, enabling future backends beyond SymPy. Key changes: - expression-ir.ts: 18 IR node types (number, symbol, parameter, tokenAccess, binary, unary, call, distribution, derivedDistribution, piecewise, etc.) - compile-to-ir.ts: TS→IR compiler with scope tracking for distribution bindings and support for Distribution.map() as derived distributions - ir-to-sympy.ts: IR→SymPy converter with let-binding inlining - compile-to-sympy.ts: thin wrapper composing the two layers - checker.ts: separate SymPy warnings from TS errors so isValid only reflects TypeScript validity; skip kernel checks for uncoloured outputs - Storybook playground for live TS→IR visualization Co-Authored-By: Claude Opus 4.6 (1M context) --- .../petrinaut/src/lsp/lib/checker.ts | 64 +- .../src/lsp/worker/language-server.worker.ts | 30 +- .../simulator/compile-to-ir.stories.tsx | 150 +++ .../simulator/compile-to-ir.test.ts | 545 +++++++++++ .../src/simulation/simulator/compile-to-ir.ts | 851 +++++++++++++++++ .../simulator/compile-to-sympy.test.ts | 29 + .../simulation/simulator/compile-to-sympy.ts | 900 +----------------- .../src/simulation/simulator/expression-ir.ts | 169 ++++ .../simulation/simulator/ir-to-sympy.test.ts | 333 +++++++ .../src/simulation/simulator/ir-to-sympy.ts | 207 ++++ 10 files changed, 2362 insertions(+), 916 deletions(-) create mode 100644 libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.stories.tsx create mode 100644 libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.test.ts create mode 100644 libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.ts create mode 100644 libs/@hashintel/petrinaut/src/simulation/simulator/expression-ir.ts create mode 100644 libs/@hashintel/petrinaut/src/simulation/simulator/ir-to-sympy.test.ts create mode 100644 libs/@hashintel/petrinaut/src/simulation/simulator/ir-to-sympy.ts diff --git a/libs/@hashintel/petrinaut/src/lsp/lib/checker.ts b/libs/@hashintel/petrinaut/src/lsp/lib/checker.ts index 375a488b7e2..c456d8e8576 100644 --- a/libs/@hashintel/petrinaut/src/lsp/lib/checker.ts +++ b/libs/@hashintel/petrinaut/src/lsp/lib/checker.ts @@ -27,10 +27,12 @@ export type SDCPNDiagnostic = { }; export type SDCPNCheckResult = { - /** Whether the SDCPN is valid (no errors) */ + /** Whether the SDCPN is valid (no TypeScript errors) */ isValid: boolean; - /** All diagnostics grouped by item */ + /** TypeScript error diagnostics grouped by item */ itemDiagnostics: SDCPNDiagnostic[]; + /** SymPy compilation warning diagnostics (informational, do not affect validity) */ + sympyDiagnostics: SDCPNDiagnostic[]; }; /** @@ -77,10 +79,8 @@ function appendSymPyDiagnostic( * Runs SymPy compilation on all SDCPN code expressions and appends * any errors as warning diagnostics. */ -function checkSymPyCompilation( - sdcpn: SDCPN, - itemDiagnostics: SDCPNDiagnostic[], -): void { +function checkSymPyCompilation(sdcpn: SDCPN): SDCPNDiagnostic[] { + const itemDiagnostics: SDCPNDiagnostic[] = []; // Check differential equations for (const de of sdcpn.differentialEquations) { const ctx = buildContextForDifferentialEquation(sdcpn, de.colorId); @@ -116,28 +116,39 @@ function checkSymPyCompilation( ); } - const kernelCtx = buildContextForTransition( - sdcpn, - transition, - "TransitionKernel", - ); - const kernelResult = compileToSymPy( - transition.transitionKernelCode, - kernelCtx, - ); - if (!kernelResult.ok) { - const filePath = getItemFilePath("transition-kernel-code", { - transitionId: transition.id, - }); - appendSymPyDiagnostic( - itemDiagnostics, - transition.id, - "transition-kernel", - filePath, - kernelResult, + // Only check TransitionKernel if there are coloured output places, + // matching the TypeScript checker's behavior + const hasColouredOutputPlaces = transition.outputArcs.some((arc) => { + const place = sdcpn.places.find((pl) => pl.id === arc.placeId); + return place?.colorId != null; + }); + + if (hasColouredOutputPlaces) { + const kernelCtx = buildContextForTransition( + sdcpn, + transition, + "TransitionKernel", + ); + const kernelResult = compileToSymPy( + transition.transitionKernelCode, + kernelCtx, ); + if (!kernelResult.ok) { + const filePath = getItemFilePath("transition-kernel-code", { + transitionId: transition.id, + }); + appendSymPyDiagnostic( + itemDiagnostics, + transition.id, + "transition-kernel", + filePath, + kernelResult, + ); + } } } + + return itemDiagnostics; } /** @@ -225,10 +236,11 @@ export function checkSDCPN( } // Run SymPy compilation checks on all code expressions - checkSymPyCompilation(sdcpn, itemDiagnostics); + const sympyDiagnostics = checkSymPyCompilation(sdcpn); return { isValid: itemDiagnostics.length === 0, itemDiagnostics, + sympyDiagnostics, }; } diff --git a/libs/@hashintel/petrinaut/src/lsp/worker/language-server.worker.ts b/libs/@hashintel/petrinaut/src/lsp/worker/language-server.worker.ts index b955e7c45ed..1636a286ed0 100644 --- a/libs/@hashintel/petrinaut/src/lsp/worker/language-server.worker.ts +++ b/libs/@hashintel/petrinaut/src/lsp/worker/language-server.worker.ts @@ -60,20 +60,22 @@ function publishAllDiagnostics(sdcpn: SDCPN): void { } const result = checkSDCPN(sdcpn, server); - const params: PublishDiagnosticsParams[] = result.itemDiagnostics.map( - (item) => { - const uri = filePathToUri(item.filePath); - // Use user content (without prefix) because diagnostic offsets have - // already been adjusted to be relative to user content by adjustDiagnostics. - const userContent = server!.getUserContent(item.filePath) ?? ""; - return { - uri: uri ?? item.filePath, - diagnostics: item.diagnostics.map((diag) => - serializeDiagnostic(diag, userContent), - ), - }; - }, - ); + const allDiagnostics = [ + ...result.itemDiagnostics, + ...result.sympyDiagnostics, + ]; + const params: PublishDiagnosticsParams[] = allDiagnostics.map((item) => { + const uri = filePathToUri(item.filePath); + // Use user content (without prefix) because diagnostic offsets have + // already been adjusted to be relative to user content by adjustDiagnostics. + const userContent = server!.getUserContent(item.filePath) ?? ""; + return { + uri: uri ?? item.filePath, + diagnostics: item.diagnostics.map((diag) => + serializeDiagnostic(diag, userContent), + ), + }; + }); self.postMessage({ jsonrpc: "2.0", diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.stories.tsx b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.stories.tsx new file mode 100644 index 00000000000..90119c042e1 --- /dev/null +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.stories.tsx @@ -0,0 +1,150 @@ +import type { Meta, StoryObj } from "@storybook/react-vite"; +import Editor, { loader, type Monaco } from "@monaco-editor/react"; +import * as monaco from "monaco-editor"; +import { useCallback, useRef, useState } from "react"; + +import type { CompilationContext } from "./compile-to-ir"; +import { compileToIR } from "./compile-to-ir"; + +type IRResult = ReturnType; + +// Use bundled Monaco directly, no workers needed for basic editing +loader.config({ monaco }); + +const DEFAULT_CODE = `export default Lambda((tokens, parameters) => { + const rate = parameters.infection_rate; + return rate * tokens.Space[0].x; +})`; + +const DEFAULT_CONTEXT: CompilationContext = { + parameterNames: new Set([ + "infection_rate", + "recovery_rate", + "gravitational_constant", + ]), + placeTokenFields: new Map([["Space", ["x", "y", "direction", "velocity"]]]), + constructorFnName: "Lambda", +}; + +function formatResult(result: IRResult): string { + if (result.ok) { + return JSON.stringify(result.ir, null, 2); + } + return JSON.stringify( + { error: result.error, start: result.start, length: result.length }, + null, + 2, + ); +} + +function compile(code: string): string { + return formatResult(compileToIR(code, DEFAULT_CONTEXT)); +} + +function IRPlayground() { + const [output, setOutput] = useState(() => compile(DEFAULT_CODE)); + const timerRef = useRef>(undefined); + + const onChange = useCallback((value: string | undefined) => { + clearTimeout(timerRef.current); + timerRef.current = setTimeout(() => { + setOutput(compile(value ?? "")); + }, 300); + }, []); + + /** Disable all TypeScript diagnostics in the editor */ + const onMount = useCallback((_editor: unknown, instance: Monaco) => { + instance.languages.typescript.typescriptDefaults.setDiagnosticsOptions({ + noSemanticValidation: true, + noSyntaxValidation: true, + }); + }, []); + + return ( +
+
+
+ TypeScript +
+
+ +
+
+
+
+ Expression IR (JSON) +
+
+ +
+
+
+ ); +} + +const meta = { + title: "Compiler / TypeScript to IR", + parameters: { + layout: "fullscreen", + }, +} satisfies Meta; + +export default meta; + +type Story = StoryObj; + +export const Playground: Story = { + render: () => ( +
+ +
+ ), +}; diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.test.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.test.ts new file mode 100644 index 00000000000..6a2956dab66 --- /dev/null +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.test.ts @@ -0,0 +1,545 @@ +import { describe, expect, it } from "vitest"; + +import type { CompilationContext } from "./compile-to-ir"; +import { compileToIR } from "./compile-to-ir"; +import type { ExpressionIR } from "./expression-ir"; + +const defaultContext: CompilationContext = { + parameterNames: new Set([ + "infection_rate", + "recovery_rate", + "gravitational_constant", + ]), + placeTokenFields: new Map([ + ["Space", ["x", "y", "direction", "velocity"]], + ["Susceptible", []], + ]), + constructorFnName: "Lambda", +}; + +function dynamicsContext(): CompilationContext { + return { ...defaultContext, constructorFnName: "Dynamics" }; +} + +function expectIR(code: string, expected: ExpressionIR, ctx = defaultContext) { + const result = compileToIR(code, ctx); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.ir).toEqual(expected); + } +} + +describe("compileToIR", () => { + describe("literals", () => { + it("should compile numeric literal", () => { + expectIR("export default Lambda(() => 42)", { + type: "number", + value: "42", + }); + }); + + it("should compile decimal literal", () => { + expectIR("export default Lambda(() => 3.14)", { + type: "number", + value: "3.14", + }); + }); + + it("should compile boolean true", () => { + expectIR("export default Lambda(() => true)", { + type: "boolean", + value: true, + }); + }); + + it("should compile boolean false", () => { + expectIR("export default Lambda(() => false)", { + type: "boolean", + value: false, + }); + }); + + it("should compile Infinity", () => { + expectIR("export default Lambda(() => Infinity)", { + type: "infinity", + }); + }); + }); + + describe("parameters", () => { + it("should compile parameter access to parameter node", () => { + expectIR( + "export default Lambda((tokens, parameters) => parameters.infection_rate)", + { type: "parameter", name: "infection_rate" }, + ); + }); + }); + + describe("token access", () => { + it("should compile tokens.Place[0].field to tokenAccess", () => { + expectIR("export default Lambda((tokens) => tokens.Space[0].x)", { + type: "tokenAccess", + place: "Space", + index: { type: "number", value: "0" }, + field: "x", + }); + }); + }); + + describe("binary operations", () => { + it("should compile addition", () => { + expectIR("export default Lambda(() => 1 + 2)", { + type: "binary", + op: "+", + left: { type: "number", value: "1" }, + right: { type: "number", value: "2" }, + }); + }); + + it("should compile strict equality to ==", () => { + expectIR( + "export default Lambda((tokens, parameters) => parameters.infection_rate === 3)", + { + type: "binary", + op: "==", + left: { type: "parameter", name: "infection_rate" }, + right: { type: "number", value: "3" }, + }, + ); + }); + + it("should compile !== to !=", () => { + expectIR( + "export default Lambda((tokens, parameters) => parameters.infection_rate !== 0)", + { + type: "binary", + op: "!=", + left: { type: "parameter", name: "infection_rate" }, + right: { type: "number", value: "0" }, + }, + ); + }); + + it("should compile && to logical and", () => { + expectIR( + "export default Lambda((tokens, parameters) => parameters.infection_rate > 0 && parameters.recovery_rate > 0)", + { + type: "binary", + op: "&&", + left: { + type: "binary", + op: ">", + left: { type: "parameter", name: "infection_rate" }, + right: { type: "number", value: "0" }, + }, + right: { + type: "binary", + op: ">", + left: { type: "parameter", name: "recovery_rate" }, + right: { type: "number", value: "0" }, + }, + }, + ); + }); + }); + + describe("unary operations", () => { + it("should compile negation", () => { + expectIR( + "export default Lambda((tokens, parameters) => -parameters.infection_rate)", + { + type: "unary", + op: "-", + operand: { type: "parameter", name: "infection_rate" }, + }, + ); + }); + + it("should compile logical not", () => { + expectIR("export default Lambda(() => !true)", { + type: "unary", + op: "!", + operand: { type: "boolean", value: true }, + }); + }); + + it("should unwrap unary plus", () => { + expectIR( + "export default Lambda((tokens, parameters) => +parameters.infection_rate)", + { type: "parameter", name: "infection_rate" }, + ); + }); + }); + + describe("math functions", () => { + it("should compile Math.cos to call node", () => { + expectIR( + "export default Lambda((tokens, parameters) => Math.cos(parameters.infection_rate))", + { + type: "call", + fn: "cos", + args: [{ type: "parameter", name: "infection_rate" }], + }, + ); + }); + + it("should compile Math.hypot to call node", () => { + expectIR( + "export default Lambda((tokens, parameters) => Math.hypot(parameters.infection_rate, parameters.recovery_rate))", + { + type: "call", + fn: "hypot", + args: [ + { type: "parameter", name: "infection_rate" }, + { type: "parameter", name: "recovery_rate" }, + ], + }, + ); + }); + + it("should compile Math.PI to symbol", () => { + expectIR("export default Lambda(() => Math.PI)", { + type: "symbol", + name: "PI", + }); + }); + + it("should compile Math.E to symbol", () => { + expectIR("export default Lambda(() => Math.E)", { + type: "symbol", + name: "E", + }); + }); + }); + + describe("distributions", () => { + it("should compile Distribution.Gaussian", () => { + expectIR("export default Lambda(() => Distribution.Gaussian(0, 1))", { + type: "distribution", + distribution: "Gaussian", + args: [ + { type: "number", value: "0" }, + { type: "number", value: "1" }, + ], + }); + }); + + it("should compile Distribution.Uniform", () => { + expectIR("export default Lambda(() => Distribution.Uniform(0, 1))", { + type: "distribution", + distribution: "Uniform", + args: [ + { type: "number", value: "0" }, + { type: "number", value: "1" }, + ], + }); + }); + }); + + describe("conditional", () => { + it("should compile ternary to piecewise", () => { + expectIR( + "export default Lambda((tokens, parameters) => parameters.infection_rate > 1 ? parameters.infection_rate : 0)", + { + type: "piecewise", + condition: { + type: "binary", + op: ">", + left: { type: "parameter", name: "infection_rate" }, + right: { type: "number", value: "1" }, + }, + whenTrue: { type: "parameter", name: "infection_rate" }, + whenFalse: { type: "number", value: "0" }, + }, + ); + }); + }); + + describe("global functions", () => { + it("should compile Boolean(expr) to != 0", () => { + expectIR( + "export default Lambda((tokens, parameters) => Boolean(parameters.infection_rate))", + { + type: "binary", + op: "!=", + left: { type: "parameter", name: "infection_rate" }, + right: { type: "number", value: "0" }, + }, + ); + }); + + it("should compile Number(expr) as identity", () => { + expectIR("export default Lambda(() => Number(true))", { + type: "boolean", + value: true, + }); + }); + }); + + describe("collections", () => { + it("should compile array literal", () => { + expectIR("export default Lambda(() => [1, 2, 3])", { + type: "array", + elements: [ + { type: "number", value: "1" }, + { type: "number", value: "2" }, + { type: "number", value: "3" }, + ], + }); + }); + + it("should compile object literal", () => { + expectIR("export default Lambda(() => ({ x: 1, y: 2 }))", { + type: "object", + entries: [ + { key: "x", value: { type: "number", value: "1" } }, + { key: "y", value: { type: "number", value: "2" } }, + ], + }); + }); + }); + + describe("let bindings", () => { + it("should compile block body with const to let node", () => { + const result = compileToIR( + `export default Dynamics((tokens, parameters) => { + const mu = parameters.gravitational_constant; + return mu; + })`, + dynamicsContext(), + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.ir).toEqual({ + type: "let", + bindings: [ + { + name: "mu", + value: { type: "parameter", name: "gravitational_constant" }, + }, + ], + body: { type: "symbol", name: "mu" }, + }); + } + }); + + it("should compile multiple const bindings", () => { + const result = compileToIR( + `export default Dynamics((tokens, parameters) => { + const a = parameters.infection_rate; + const b = parameters.recovery_rate; + return a; + })`, + dynamicsContext(), + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.ir.type).toBe("let"); + if (result.ir.type === "let") { + expect(result.ir.bindings).toHaveLength(2); + expect(result.ir.bindings[0]!.name).toBe("a"); + expect(result.ir.bindings[1]!.name).toBe("b"); + } + } + }); + + it("should not emit let node for block without bindings", () => { + const result = compileToIR( + `export default Dynamics((tokens, parameters) => { + return parameters.infection_rate; + })`, + dynamicsContext(), + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.ir).toEqual({ + type: "parameter", + name: "infection_rate", + }); + } + }); + }); + + describe(".map() list comprehension", () => { + it("should compile .map with destructured params", () => { + const result = compileToIR( + `export default Lambda((tokens, parameters) => tokens.map(({ x }) => x * parameters.infection_rate))`, + defaultContext, + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.ir).toEqual({ + type: "listComprehension", + variable: "_iter", + collection: { type: "symbol", name: "tokens" }, + body: { + type: "binary", + op: "*", + left: { type: "symbol", name: "_iter_x" }, + right: { type: "parameter", name: "infection_rate" }, + }, + }); + } + }); + + it("should compile .map with simple identifier param", () => { + const result = compileToIR( + `export default Lambda((tokens) => tokens.map((token) => token))`, + defaultContext, + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.ir).toEqual({ + type: "listComprehension", + variable: "_iter", + collection: { type: "symbol", name: "tokens" }, + body: { type: "symbol", name: "_iter" }, + }); + } + }); + }); + + describe("derived distributions", () => { + it("should compile dist.map(arrow) to derivedDistribution", () => { + const result = compileToIR( + `export default Lambda(() => Distribution.Gaussian(0, 10).map((x) => x * 2))`, + defaultContext, + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.ir).toEqual({ + type: "derivedDistribution", + distribution: { + type: "distribution", + distribution: "Gaussian", + args: [ + { type: "number", value: "0" }, + { type: "number", value: "10" }, + ], + }, + variable: "_x", + body: { + type: "binary", + op: "*", + left: { type: "symbol", name: "_x" }, + right: { type: "number", value: "2" }, + }, + }); + } + }); + + it("should compile dist.map(Math.cos) with function reference", () => { + const result = compileToIR( + `export default Lambda(() => Distribution.Gaussian(0, 10).map(Math.cos))`, + defaultContext, + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.ir).toEqual({ + type: "derivedDistribution", + distribution: { + type: "distribution", + distribution: "Gaussian", + args: [ + { type: "number", value: "0" }, + { type: "number", value: "10" }, + ], + }, + variable: "_x", + body: { + type: "call", + fn: "cos", + args: [{ type: "symbol", name: "_x" }], + }, + }); + } + }); + + it("should detect distribution through const binding", () => { + const result = compileToIR( + `export default Lambda(() => { + const angle = Distribution.Gaussian(0, 10); + return angle.map(Math.cos); + })`, + defaultContext, + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.ir.type).toBe("let"); + if (result.ir.type === "let") { + expect(result.ir.body).toEqual({ + type: "derivedDistribution", + distribution: { type: "symbol", name: "angle" }, + variable: "_x", + body: { + type: "call", + fn: "cos", + args: [{ type: "symbol", name: "_x" }], + }, + }); + } + } + }); + + it("should chain derived distributions", () => { + const result = compileToIR( + `export default Lambda((tokens, parameters) => { + const angle = Distribution.Gaussian(0, 10); + const cosAngle = angle.map(Math.cos); + return cosAngle.map((x) => x * parameters.infection_rate); + })`, + defaultContext, + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.ir.type).toBe("let"); + if (result.ir.type === "let") { + expect(result.ir.body.type).toBe("derivedDistribution"); + } + } + }); + }); + + describe("error handling", () => { + it("should reject code without default export", () => { + const result = compileToIR("const x = Lambda(() => 1);", defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("No default export"); + } + }); + + it("should reject string literals", () => { + const result = compileToIR( + `export default Lambda(() => "hello")`, + defaultContext, + ); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("String literals"); + } + }); + + it("should reject unsupported Math function", () => { + const result = compileToIR( + "export default Lambda(() => Math.random())", + defaultContext, + ); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("Unsupported Math function"); + } + }); + + it("should reject let declarations", () => { + const result = compileToIR( + `export default Lambda(() => { let x = 1; return x; })`, + defaultContext, + ); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("let"); + } + }); + }); +}); diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.ts new file mode 100644 index 00000000000..bc59ba1dc17 --- /dev/null +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.ts @@ -0,0 +1,851 @@ +import ts from "typescript"; + +import type { SDCPN, Transition } from "../../core/types/sdcpn"; +import type { BinaryOp, ExpressionIR, UnaryOp } from "./expression-ir"; + +/** + * Context for compilation, derived from the SDCPN model. + * Tells the compiler which identifiers are parameters vs. token fields. + */ +export type CompilationContext = { + parameterNames: Set; + /** Maps place name to its token field names */ + placeTokenFields: Map; + constructorFnName: string; +}; + +/** + * Builds a CompilationContext from an SDCPN model for a given transition. + */ +export function buildContextForTransition( + sdcpn: SDCPN, + transition: Transition, + constructorFnName: string, +): CompilationContext { + const parameterNames = new Set( + sdcpn.parameters.map((param) => param.variableName), + ); + const placeTokenFields = new Map(); + + const placeById = new Map(sdcpn.places.map((pl) => [pl.id, pl])); + const colorById = new Map(sdcpn.types.map((ct) => [ct.id, ct])); + + for (const arc of transition.inputArcs) { + const place = placeById.get(arc.placeId); + if (!place?.colorId) { + continue; + } + const color = colorById.get(place.colorId); + if (!color) { + continue; + } + placeTokenFields.set( + place.name, + color.elements.map((el) => el.name), + ); + } + + return { parameterNames, placeTokenFields, constructorFnName }; +} + +/** + * Builds a CompilationContext from an SDCPN model for a differential equation. + */ +export function buildContextForDifferentialEquation( + sdcpn: SDCPN, + colorId: string, +): CompilationContext { + const parameterNames = new Set( + sdcpn.parameters.map((param) => param.variableName), + ); + const placeTokenFields = new Map(); + + const color = sdcpn.types.find((ct) => ct.id === colorId); + if (color) { + placeTokenFields.set( + color.name, + color.elements.map((el) => el.name), + ); + } + + return { parameterNames, placeTokenFields, constructorFnName: "Dynamics" }; +} + +export type IRResult = + | { ok: true; ir: ExpressionIR } + | { ok: false; error: string; start: number; length: number }; + +/** Shorthand for building an error result with position from a TS AST node. */ +function err( + error: string, + node: ts.Node, + sourceFile: ts.SourceFile, +): IRResult & { ok: false } { + return { + ok: false, + error, + start: node.getStart(sourceFile), + length: node.getWidth(sourceFile), + }; +} + +/** Error result for cases where no specific node is available. */ +function errNoPos(error: string): IRResult & { ok: false } { + return { ok: false, error, start: 0, length: 0 }; +} + +/** + * Scope tracking for the IR compiler. + * Tracks local binding names and symbol overrides (from .map() destructuring). + */ +type Scope = { + /** Names defined by `const` in the current scope */ + localBindingNames: Set; + /** Rewritten names for .map() destructured parameters (e.g., x → _iter_x) */ + symbolOverrides: Map; + /** Names bound to distribution expressions (for .map() detection) */ + distributionBindings: Set; +}; + +function emptyScope(): Scope { + return { + localBindingNames: new Set(), + symbolOverrides: new Map(), + distributionBindings: new Set(), + }; +} + +/** + * Compiles a Petrinaut TypeScript expression to the expression IR. + * + * Expects code following the pattern: + * `export default ConstructorFn((params...) => expression)` + * + * Only a restricted subset of TypeScript is supported — pure expressions + * with arithmetic, Math functions, parameter/token access, and distributions. + * Anything outside this subset is rejected with a diagnostic. + * + * @param code - The TypeScript expression code string + * @param context - Compilation context with parameter names and token fields + * @returns Either `{ ok: true, ir }` or `{ ok: false, error }` + */ +export function compileToIR( + code: string, + context: CompilationContext, +): IRResult { + const sourceFile = ts.createSourceFile( + "input.ts", + code, + ts.ScriptTarget.ES2015, + true, + ); + + // Find the default export + const exportAssignment = sourceFile.statements.find( + (stmt): stmt is ts.ExportAssignment => + ts.isExportAssignment(stmt) && !stmt.isExportEquals, + ); + + if (!exportAssignment) { + return errNoPos("No default export found"); + } + + const exportExpr = exportAssignment.expression; + + // Expect ConstructorFn(...) + if (!ts.isCallExpression(exportExpr)) { + return err( + `Expected ${context.constructorFnName}(...), got ${ts.SyntaxKind[exportExpr.kind]}`, + exportExpr, + sourceFile, + ); + } + + const callee = exportExpr.expression; + if (!ts.isIdentifier(callee) || callee.text !== context.constructorFnName) { + return err( + `Expected ${context.constructorFnName}(...), got ${callee.getText(sourceFile)}(...)`, + callee, + sourceFile, + ); + } + + if (exportExpr.arguments.length !== 1) { + return err( + `${context.constructorFnName} expects exactly one argument`, + exportExpr, + sourceFile, + ); + } + + const arg = exportExpr.arguments[0]!; + + // The argument should be an arrow function or function expression + if (!ts.isArrowFunction(arg) && !ts.isFunctionExpression(arg)) { + return err( + `Expected a function argument, got ${ts.SyntaxKind[arg.kind]}`, + arg, + sourceFile, + ); + } + + const scope = emptyScope(); + + // Compile the body + const body = arg.body; + + if (ts.isBlock(body)) { + return compileBlockToIR(body, context, scope, sourceFile); + } + + // Expression body — emit directly + return emitIR(body, context, scope, sourceFile); +} + +function compileBlockToIR( + block: ts.Block, + context: CompilationContext, + outerScope: Scope, + sourceFile: ts.SourceFile, +): IRResult { + const bindings: { name: string; value: ExpressionIR }[] = []; + const scope: Scope = { + localBindingNames: new Set(outerScope.localBindingNames), + symbolOverrides: new Map(outerScope.symbolOverrides), + distributionBindings: new Set(outerScope.distributionBindings), + }; + let bodyIR: ExpressionIR | undefined; + + for (const stmt of block.statements) { + if (ts.isVariableStatement(stmt)) { + for (const decl of stmt.declarationList.declarations) { + if (!decl.initializer) { + return err( + "Variable declaration without initializer", + decl, + sourceFile, + ); + } + if (!(stmt.declarationList.flags & ts.NodeFlags.Const)) { + return err( + "'let' and 'var' declarations are not supported, use 'const'", + stmt, + sourceFile, + ); + } + const name = decl.name.getText(sourceFile); + const valueResult = emitIR( + decl.initializer, + context, + scope, + sourceFile, + ); + if (!valueResult.ok) return valueResult; + bindings.push({ name, value: valueResult.ir }); + scope.localBindingNames.add(name); + if ( + valueResult.ir.type === "distribution" || + valueResult.ir.type === "derivedDistribution" + ) { + scope.distributionBindings.add(name); + } + } + } else if (ts.isReturnStatement(stmt)) { + if (!stmt.expression) { + return err("Empty return statement", stmt, sourceFile); + } + const result = emitIR(stmt.expression, context, scope, sourceFile); + if (!result.ok) return result; + bodyIR = result.ir; + } else if (ts.isExpressionStatement(stmt)) { + return err( + "Standalone expression has no effect — assign to a const or return it", + stmt, + sourceFile, + ); + } else { + return err( + `Unsupported statement: ${ts.SyntaxKind[stmt.kind]}`, + stmt, + sourceFile, + ); + } + } + + if (!bodyIR) { + return err("Empty function body", block, sourceFile); + } + + if (bindings.length > 0) { + return { ok: true, ir: { type: "let", bindings, body: bodyIR } }; + } + return { ok: true, ir: bodyIR }; +} + +/** + * Compiles `collection.map(callback)` to a list comprehension IR node. + */ +function compileMapCallToIR( + collection: ts.Expression, + callback: ts.ArrowFunction | ts.FunctionExpression, + context: CompilationContext, + outerScope: Scope, + sourceFile: ts.SourceFile, +): IRResult { + const iterVar = "_iter"; + const mapScope: Scope = { + localBindingNames: new Set(outerScope.localBindingNames), + symbolOverrides: new Map(outerScope.symbolOverrides), + distributionBindings: new Set(outerScope.distributionBindings), + }; + + const param = callback.parameters[0]; + if (param) { + const paramName = param.name; + if (ts.isObjectBindingPattern(paramName)) { + for (const element of paramName.elements) { + const fieldName = element.name.getText(sourceFile); + mapScope.symbolOverrides.set(fieldName, `${iterVar}_${fieldName}`); + } + } else { + mapScope.symbolOverrides.set(paramName.getText(sourceFile), iterVar); + } + } + + // Compile the body + const body = callback.body; + let bodyResult: IRResult; + if (ts.isBlock(body)) { + bodyResult = compileBlockToIR(body, context, mapScope, sourceFile); + } else { + bodyResult = emitIR(body, context, mapScope, sourceFile); + } + if (!bodyResult.ok) return bodyResult; + + // Compile the collection expression + const collectionResult = emitIR(collection, context, outerScope, sourceFile); + if (!collectionResult.ok) return collectionResult; + + return { + ok: true, + ir: { + type: "listComprehension", + body: bodyResult.ir, + variable: iterVar, + collection: collectionResult.ir, + }, + }; +} + +/** + * Checks whether a TS expression will produce a distribution IR node. + * + * Handles two cases: + * - Direct: `Distribution.Gaussian(...)` call expressions + * - Via binding: identifiers bound to a distribution in a const declaration + */ +function isDistributionExpression(node: ts.Expression, scope: Scope): boolean { + // Direct: Distribution.Fn(...) + if ( + ts.isCallExpression(node) && + ts.isPropertyAccessExpression(node.expression) && + ts.isIdentifier(node.expression.expression) && + node.expression.expression.text === "Distribution" + ) { + return true; + } + + // Via binding: const angle = Distribution.Gaussian(0, 10); angle.map(...) + if (ts.isIdentifier(node) && scope.distributionBindings.has(node.text)) { + return true; + } + + return false; +} + +/** + * Compiles `distribution.map(transform)` to a derived distribution IR node. + * + * Handles two callback forms: + * - Arrow/function expression: `dist.map((x) => Math.cos(x))` + * - Function reference: `dist.map(Math.cos)` — expanded to `(_x) => Math.cos(_x)` + */ +function compileDerivedDistribution( + collection: ts.Expression, + callback: ts.Expression, + context: CompilationContext, + scope: Scope, + sourceFile: ts.SourceFile, +): IRResult { + const variable = "_x"; + + // Compile the base distribution + const distResult = emitIR(collection, context, scope, sourceFile); + if (!distResult.ok) return distResult; + + let bodyIR: ExpressionIR; + + if (ts.isArrowFunction(callback) || ts.isFunctionExpression(callback)) { + // Arrow function: (x) => expr + // Bind the parameter name to our variable + const mapScope: Scope = { + localBindingNames: new Set(scope.localBindingNames), + symbolOverrides: new Map(scope.symbolOverrides), + distributionBindings: new Set(scope.distributionBindings), + }; + + const param = callback.parameters[0]; + if (param) { + mapScope.symbolOverrides.set(param.name.getText(sourceFile), variable); + } + + const body = callback.body; + let bodyResult: IRResult; + if (ts.isBlock(body)) { + bodyResult = compileBlockToIR(body, context, mapScope, sourceFile); + } else { + bodyResult = emitIR(body, context, mapScope, sourceFile); + } + if (!bodyResult.ok) return bodyResult; + bodyIR = bodyResult.ir; + } else { + // Function reference: Math.cos, Math.sin, etc. + // Expand to: (_x) => fn(_x) + const fnResult = emitIR(callback, context, scope, sourceFile); + if (!fnResult.ok) return fnResult; + + // The fn should be a call-like reference (e.g. Math.cos). + // We synthesize a call node: fn(_x) + if (fnResult.ir.type === "symbol" && fnResult.ir.name.startsWith("Math.")) { + // Math.cos → { type: "call", fn: "cos", args: [symbol("_x")] } + const fnName = fnResult.ir.name.slice("Math.".length); + bodyIR = { + type: "call", + fn: fnName, + args: [{ type: "symbol", name: variable }], + }; + } else { + return err( + "Distribution .map() callback must be a function expression or a Math function reference", + callback, + sourceFile, + ); + } + } + + return { + ok: true, + ir: { + type: "derivedDistribution", + distribution: distResult.ir, + variable, + body: bodyIR, + }, + }; +} + +const SUPPORTED_MATH_FUNCTIONS = new Set([ + "cos", + "sin", + "tan", + "acos", + "asin", + "atan", + "atan2", + "sqrt", + "log", + "exp", + "abs", + "floor", + "ceil", + "pow", + "min", + "max", + "hypot", +]); + +const MATH_CONSTANTS: Record = { + PI: { type: "symbol", name: "PI" }, + E: { type: "symbol", name: "E" }, +}; + +const TS_BINARY_OP_MAP: Partial> = { + [ts.SyntaxKind.PlusToken]: "+", + [ts.SyntaxKind.MinusToken]: "-", + [ts.SyntaxKind.AsteriskToken]: "*", + [ts.SyntaxKind.SlashToken]: "/", + [ts.SyntaxKind.AsteriskAsteriskToken]: "**", + [ts.SyntaxKind.PercentToken]: "%", + [ts.SyntaxKind.LessThanToken]: "<", + [ts.SyntaxKind.LessThanEqualsToken]: "<=", + [ts.SyntaxKind.GreaterThanToken]: ">", + [ts.SyntaxKind.GreaterThanEqualsToken]: ">=", + [ts.SyntaxKind.EqualsEqualsToken]: "==", + [ts.SyntaxKind.EqualsEqualsEqualsToken]: "==", + [ts.SyntaxKind.ExclamationEqualsToken]: "!=", + [ts.SyntaxKind.ExclamationEqualsEqualsToken]: "!=", + [ts.SyntaxKind.AmpersandAmpersandToken]: "&&", + [ts.SyntaxKind.BarBarToken]: "||", +}; + +const TS_UNARY_OP_MAP: Partial> = { + [ts.SyntaxKind.MinusToken]: "-", + [ts.SyntaxKind.ExclamationToken]: "!", + [ts.SyntaxKind.PlusToken]: "+", +}; + +function emitIR( + node: ts.Node, + context: CompilationContext, + scope: Scope, + sourceFile: ts.SourceFile, +): IRResult { + // Numeric literal + if (ts.isNumericLiteral(node)) { + return { ok: true, ir: { type: "number", value: node.text } }; + } + + // String literal — not supported in symbolic math + if (ts.isStringLiteral(node)) { + return err( + "String literals are not supported in symbolic expressions", + node, + sourceFile, + ); + } + + // Boolean literals + if (node.kind === ts.SyntaxKind.TrueKeyword) { + return { ok: true, ir: { type: "boolean", value: true } }; + } + if (node.kind === ts.SyntaxKind.FalseKeyword) { + return { ok: true, ir: { type: "boolean", value: false } }; + } + + // Identifier + if (ts.isIdentifier(node)) { + const name = node.text; + if (name === "Infinity") { + return { ok: true, ir: { type: "infinity" } }; + } + if (scope.symbolOverrides.has(name)) { + return { + ok: true, + ir: { type: "symbol", name: scope.symbolOverrides.get(name)! }, + }; + } + if (scope.localBindingNames.has(name)) { + return { ok: true, ir: { type: "symbol", name } }; + } + if (context.parameterNames.has(name)) { + return { ok: true, ir: { type: "parameter", name } }; + } + // Could be a destructured token field or function param + return { ok: true, ir: { type: "symbol", name } }; + } + + // Parenthesized expression + if (ts.isParenthesizedExpression(node)) { + return emitIR(node.expression, context, scope, sourceFile); + } + + // Prefix unary expression (-x, !x) + if (ts.isPrefixUnaryExpression(node)) { + const operand = emitIR(node.operand, context, scope, sourceFile); + if (!operand.ok) return operand; + + const op = TS_UNARY_OP_MAP[node.operator]; + if (!op) { + return err( + `Unsupported prefix operator: ${ts.SyntaxKind[node.operator]}`, + node, + sourceFile, + ); + } + + if (op === "+") { + return operand; + } + + return { ok: true, ir: { type: "unary", op, operand: operand.ir } }; + } + + // Binary expression + if (ts.isBinaryExpression(node)) { + const left = emitIR(node.left, context, scope, sourceFile); + if (!left.ok) return left; + const right = emitIR(node.right, context, scope, sourceFile); + if (!right.ok) return right; + + const op = TS_BINARY_OP_MAP[node.operatorToken.kind]; + if (!op) { + return err( + `Unsupported binary operator: ${node.operatorToken.getText(sourceFile)}`, + node.operatorToken, + sourceFile, + ); + } + + return { + ok: true, + ir: { type: "binary", op, left: left.ir, right: right.ir }, + }; + } + + // Conditional (ternary) expression + if (ts.isConditionalExpression(node)) { + const condition = emitIR(node.condition, context, scope, sourceFile); + if (!condition.ok) return condition; + const whenTrue = emitIR(node.whenTrue, context, scope, sourceFile); + if (!whenTrue.ok) return whenTrue; + const whenFalse = emitIR(node.whenFalse, context, scope, sourceFile); + if (!whenFalse.ok) return whenFalse; + return { + ok: true, + ir: { + type: "piecewise", + condition: condition.ir, + whenTrue: whenTrue.ir, + whenFalse: whenFalse.ir, + }, + }; + } + + // Property access: parameters.x, tokens.Place[0].field, Math.PI + if (ts.isPropertyAccessExpression(node)) { + const propName = node.name.text; + + // Math constants: Math.PI, Math.E + if (ts.isIdentifier(node.expression) && node.expression.text === "Math") { + const constant = MATH_CONSTANTS[propName]; + if (constant) return { ok: true, ir: constant }; + // Math.method — return a placeholder for the call expression handler + return { ok: true, ir: { type: "symbol", name: `Math.${propName}` } }; + } + + // parameters.x + if ( + ts.isIdentifier(node.expression) && + node.expression.text === "parameters" + ) { + return { ok: true, ir: { type: "parameter", name: propName } }; + } + + // tokens.Place[0].field — handle the chain + if (ts.isElementAccessExpression(node.expression)) { + const elemAccess = node.expression; + if (ts.isPropertyAccessExpression(elemAccess.expression)) { + const placePropAccess = elemAccess.expression; + if ( + ts.isIdentifier(placePropAccess.expression) && + placePropAccess.expression.text === "tokens" + ) { + const placeName = placePropAccess.name.text; + const indexResult = emitIR( + elemAccess.argumentExpression, + context, + scope, + sourceFile, + ); + if (!indexResult.ok) return indexResult; + return { + ok: true, + ir: { + type: "tokenAccess", + place: placeName, + index: indexResult.ir, + field: propName, + }, + }; + } + } + } + + // Generic property access + const obj = emitIR(node.expression, context, scope, sourceFile); + if (!obj.ok) return obj; + return { + ok: true, + ir: { type: "propertyAccess", object: obj.ir, property: propName }, + }; + } + + // Element access: tokens.Place[0], arr[i] + if (ts.isElementAccessExpression(node)) { + const obj = emitIR(node.expression, context, scope, sourceFile); + if (!obj.ok) return obj; + const index = emitIR(node.argumentExpression, context, scope, sourceFile); + if (!index.ok) return index; + return { + ok: true, + ir: { type: "elementAccess", object: obj.ir, index: index.ir }, + }; + } + + // Call expression + if (ts.isCallExpression(node)) { + const callee = node.expression; + + // Math.fn(...) + if ( + ts.isPropertyAccessExpression(callee) && + ts.isIdentifier(callee.expression) && + callee.expression.text === "Math" + ) { + const fnName = callee.name.text; + + if (!SUPPORTED_MATH_FUNCTIONS.has(fnName)) { + return err( + `Unsupported Math function: Math.${fnName}`, + callee, + sourceFile, + ); + } + + const args: ExpressionIR[] = []; + for (const a of node.arguments) { + const r = emitIR(a, context, scope, sourceFile); + if (!r.ok) return r; + args.push(r.ir); + } + return { ok: true, ir: { type: "call", fn: fnName, args } }; + } + + // Distribution.Gaussian(m, s), etc. + if ( + ts.isPropertyAccessExpression(callee) && + ts.isIdentifier(callee.expression) && + callee.expression.text === "Distribution" + ) { + const distName = callee.name.text; + const supportedDistributions = ["Gaussian", "Uniform", "Lognormal"]; + if (!supportedDistributions.includes(distName)) { + return err( + `Unsupported distribution: Distribution.${distName}`, + callee, + sourceFile, + ); + } + + const args: ExpressionIR[] = []; + for (const a of node.arguments) { + const r = emitIR(a, context, scope, sourceFile); + if (!r.ok) return r; + args.push(r.ir); + } + return { + ok: true, + ir: { type: "distribution", distribution: distName, args }, + }; + } + + // Global built-in functions: Boolean(expr), Number(expr) + if (ts.isIdentifier(callee)) { + if (callee.text === "Boolean" && node.arguments.length === 1) { + const argResult = emitIR( + node.arguments[0]!, + context, + scope, + sourceFile, + ); + if (!argResult.ok) return argResult; + // Boolean(expr) → expr != 0 + return { + ok: true, + ir: { + type: "binary", + op: "!=", + left: argResult.ir, + right: { type: "number", value: "0" }, + }, + }; + } + + if (callee.text === "Number" && node.arguments.length === 1) { + return emitIR(node.arguments[0]!, context, scope, sourceFile); + } + } + + // .map(callback) + if ( + ts.isPropertyAccessExpression(callee) && + callee.name.text === "map" && + node.arguments.length === 1 + ) { + const callback = node.arguments[0]!; + + // Check if the target is a distribution (for derived distributions) + if (isDistributionExpression(callee.expression, scope)) { + return compileDerivedDistribution( + callee.expression, + callback, + context, + scope, + sourceFile, + ); + } + + if (ts.isArrowFunction(callback) || ts.isFunctionExpression(callback)) { + return compileMapCallToIR( + callee.expression, + callback, + context, + scope, + sourceFile, + ); + } + } + + return err( + `Unsupported function call: ${callee.getText(sourceFile)}`, + node, + sourceFile, + ); + } + + // Array literal expression [a, b, c] + if (ts.isArrayLiteralExpression(node)) { + const elements: ExpressionIR[] = []; + for (const elem of node.elements) { + const result = emitIR(elem, context, scope, sourceFile); + if (!result.ok) return result; + elements.push(result.ir); + } + return { ok: true, ir: { type: "array", elements } }; + } + + // Object literal expression { field: expr, ... } + if (ts.isObjectLiteralExpression(node)) { + const entries: { key: string; value: ExpressionIR }[] = []; + for (const prop of node.properties) { + if (!ts.isPropertyAssignment(prop)) { + return err( + `Unsupported object property kind: ${ts.SyntaxKind[prop.kind]}`, + prop, + sourceFile, + ); + } + const key = prop.name.getText(sourceFile); + const val = emitIR(prop.initializer, context, scope, sourceFile); + if (!val.ok) return val; + entries.push({ key, value: val.ir }); + } + return { ok: true, ir: { type: "object", entries } }; + } + + // Non-null assertion (x!) — just unwrap + if (ts.isNonNullExpression(node)) { + return emitIR(node.expression, context, scope, sourceFile); + } + + // Type assertion (x as T) — just unwrap + if (ts.isAsExpression(node)) { + return emitIR(node.expression, context, scope, sourceFile); + } + + return err( + `Unsupported syntax: ${ts.SyntaxKind[node.kind]}`, + node, + sourceFile, + ); +} diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts index 0ae9e4eef5d..719bc1f1a0a 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts @@ -423,6 +423,35 @@ describe("compileToSymPy", () => { }); }); + describe("derived distributions", () => { + it("should compile distribution.map(Math.cos) end-to-end", () => { + const result = compileToSymPy( + `export default Lambda(() => { + const angle = Distribution.Gaussian(0, 10); + return angle.map(Math.cos); + })`, + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: + "DerivedDistribution(sp.stats.Normal('X', 0, 10), lambda _x: sp.cos(_x))", + }); + }); + + it("should compile inline distribution.map(arrow)", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Distribution.Uniform(0, 1).map((x) => x * parameters.infection_rate))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: + "DerivedDistribution(sp.stats.Uniform('X', 0, 1), lambda _x: _x * infection_rate)", + }); + }); + }); + describe("global built-in functions", () => { it("should compile Boolean(expr) to sp.Ne(expr, 0)", () => { const result = compileToSymPy( diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts index 7d00e5a34ba..7cd6c783d3a 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts @@ -1,108 +1,31 @@ -import ts from "typescript"; - -import type { SDCPN, Transition } from "../../core/types/sdcpn"; - -/** - * Context for SymPy compilation, derived from the SDCPN model. - * Tells the compiler which identifiers are parameters vs. token fields. - */ -export type SymPyCompilationContext = { - parameterNames: Set; - /** Maps place name to its token field names */ - placeTokenFields: Map; - constructorFnName: string; -}; - -/** - * Builds a SymPyCompilationContext from an SDCPN model for a given transition. - */ -export function buildContextForTransition( - sdcpn: SDCPN, - transition: Transition, - constructorFnName: string, -): SymPyCompilationContext { - const parameterNames = new Set( - sdcpn.parameters.map((param) => param.variableName), - ); - const placeTokenFields = new Map(); - - const placeById = new Map(sdcpn.places.map((pl) => [pl.id, pl])); - const colorById = new Map(sdcpn.types.map((ct) => [ct.id, ct])); - - for (const arc of transition.inputArcs) { - const place = placeById.get(arc.placeId); - if (!place?.colorId) { - continue; - } - const color = colorById.get(place.colorId); - if (!color) { - continue; - } - placeTokenFields.set( - place.name, - color.elements.map((el) => el.name), - ); - } - - return { parameterNames, placeTokenFields, constructorFnName }; -} - -/** - * Builds a SymPyCompilationContext from an SDCPN model for a differential equation. - */ -export function buildContextForDifferentialEquation( - sdcpn: SDCPN, - colorId: string, -): SymPyCompilationContext { - const parameterNames = new Set( - sdcpn.parameters.map((param) => param.variableName), - ); - const placeTokenFields = new Map(); - - const color = sdcpn.types.find((ct) => ct.id === colorId); - if (color) { - // DE operates on tokens of its color type - placeTokenFields.set( - color.name, - color.elements.map((el) => el.name), - ); - } - - return { parameterNames, placeTokenFields, constructorFnName: "Dynamics" }; -} +import type { CompilationContext } from "./compile-to-ir"; +import { compileToIR } from "./compile-to-ir"; +import { irToSymPy } from "./ir-to-sympy"; + +// Re-export types and context builders for backward compatibility +export type { CompilationContext } from "./compile-to-ir"; +export type { IRResult } from "./compile-to-ir"; +export type { ExpressionIR } from "./expression-ir"; +export { + buildContextForDifferentialEquation, + buildContextForTransition, + compileToIR, +} from "./compile-to-ir"; +export { irToSymPy } from "./ir-to-sympy"; + +/** @deprecated Use {@link CompilationContext} instead. */ +export type SymPyCompilationContext = CompilationContext; export type SymPyResult = | { ok: true; sympyCode: string } | { ok: false; error: string; start: number; length: number }; -/** Shorthand for building an error result with position from a TS AST node. */ -function err( - error: string, - node: ts.Node, - sourceFile: ts.SourceFile, -): SymPyResult & { ok: false } { - return { - ok: false, - error, - start: node.getStart(sourceFile), - length: node.getWidth(sourceFile), - }; -} - -/** Error result for cases where no specific node is available. */ -function errNoPos(error: string): SymPyResult & { ok: false } { - return { ok: false, error, start: 0, length: 0 }; -} - /** * Compiles a Petrinaut TypeScript expression to SymPy Python code. * - * Expects code following the pattern: - * `export default ConstructorFn((params...) => expression)` - * - * Only a restricted subset of TypeScript is supported — pure expressions - * with arithmetic, Math functions, parameter/token access, and distributions. - * Anything outside this subset is rejected with a diagnostic. + * This is a convenience wrapper that composes two steps: + * 1. TypeScript → Expression IR ({@link compileToIR}) + * 2. Expression IR → SymPy ({@link irToSymPy}) * * @param code - The TypeScript expression code string * @param context - Compilation context with parameter names and token fields @@ -110,784 +33,9 @@ function errNoPos(error: string): SymPyResult & { ok: false } { */ export function compileToSymPy( code: string, - context: SymPyCompilationContext, -): SymPyResult { - const sourceFile = ts.createSourceFile( - "input.ts", - code, - ts.ScriptTarget.ES2015, - true, - ); - - // Find the default export - const exportAssignment = sourceFile.statements.find( - (stmt): stmt is ts.ExportAssignment => - ts.isExportAssignment(stmt) && !stmt.isExportEquals, - ); - - if (!exportAssignment) { - // Try export default as ExpressionStatement pattern - const exportDefault = sourceFile.statements.find((stmt) => { - if (ts.isExportAssignment(stmt)) { - return true; - } - // Handle "export default X(...)" which parses as ExportAssignment - return false; - }); - if (!exportDefault) { - return errNoPos("No default export found"); - } - } - - const exportExpr = exportAssignment!.expression; - - // Expect ConstructorFn(...) - if (!ts.isCallExpression(exportExpr)) { - return err( - `Expected ${context.constructorFnName}(...), got ${ts.SyntaxKind[exportExpr.kind]}`, - exportExpr, - sourceFile, - ); - } - - const callee = exportExpr.expression; - if (!ts.isIdentifier(callee) || callee.text !== context.constructorFnName) { - return err( - `Expected ${context.constructorFnName}(...), got ${callee.getText(sourceFile)}(...)`, - callee, - sourceFile, - ); - } - - if (exportExpr.arguments.length !== 1) { - return err( - `${context.constructorFnName} expects exactly one argument`, - exportExpr, - sourceFile, - ); - } - - const arg = exportExpr.arguments[0]!; - - // The argument should be an arrow function or function expression - if (!ts.isArrowFunction(arg) && !ts.isFunctionExpression(arg)) { - return err( - `Expected a function argument, got ${ts.SyntaxKind[arg.kind]}`, - arg, - sourceFile, - ); - } - - // Extract parameter names for the inner function - const localBindings = new Map(); - const innerParams = extractFunctionParams(arg, sourceFile); - - // Compile the body - const body = arg.body; - - if (ts.isBlock(body)) { - return compileBlock(body, context, localBindings, sourceFile); - } - - // Expression body — emit directly - const result = emitSymPy( - body, - context, - localBindings, - innerParams, - sourceFile, - ); - if (!result.ok) return result; - return { ok: true, sympyCode: result.sympyCode }; -} - -function extractFunctionParams( - fn: ts.ArrowFunction | ts.FunctionExpression, - sourceFile: ts.SourceFile, -): string[] { - return fn.parameters.map((p) => p.name.getText(sourceFile)); -} - -function compileBlock( - block: ts.Block, - context: SymPyCompilationContext, - localBindings: Map, - sourceFile: ts.SourceFile, -): SymPyResult { - const lines: string[] = []; - - for (const stmt of block.statements) { - if (ts.isVariableStatement(stmt)) { - for (const decl of stmt.declarationList.declarations) { - if (!decl.initializer) { - return err( - "Variable declaration without initializer", - decl, - sourceFile, - ); - } - if (!(stmt.declarationList.flags & ts.NodeFlags.Const)) { - return err( - "'let' and 'var' declarations are not supported, use 'const'", - stmt, - sourceFile, - ); - } - const name = decl.name.getText(sourceFile); - const valueResult = emitSymPy( - decl.initializer, - context, - localBindings, - [], - sourceFile, - ); - if (!valueResult.ok) return valueResult; - localBindings.set(name, valueResult.sympyCode); - lines.push(`${name} = ${valueResult.sympyCode}`); - } - } else if (ts.isReturnStatement(stmt)) { - if (!stmt.expression) { - return err("Empty return statement", stmt, sourceFile); - } - const result = emitSymPy( - stmt.expression, - context, - localBindings, - [], - sourceFile, - ); - if (!result.ok) return result; - lines.push(result.sympyCode); - } else if (ts.isExpressionStatement(stmt)) { - return err( - "Standalone expression has no effect — assign to a const or return it", - stmt, - sourceFile, - ); - } else { - return err( - `Unsupported statement: ${ts.SyntaxKind[stmt.kind]}`, - stmt, - sourceFile, - ); - } - } - - if (lines.length === 0) { - return err("Empty function body", block, sourceFile); - } - - return { ok: true, sympyCode: lines[lines.length - 1]! }; -} - -/** - * Compiles `collection.map(callback)` to a Python list comprehension. - * - * Handles two callback parameter styles: - * - Destructured: `({ x, y }) => ...` → binds each field as `_iter_x`, `_iter_y` - * - Simple identifier: `(token) => ...` → binds as-is - * - * Emits: `[ for _iter in ]` - */ -function compileMapCall( - collection: ts.Expression, - callback: ts.ArrowFunction | ts.FunctionExpression, - context: SymPyCompilationContext, - outerBindings: Map, - innerParams: string[], - sourceFile: ts.SourceFile, + context: CompilationContext, ): SymPyResult { - const iterVar = "_iter"; - const mapBindings = new Map(outerBindings); - - const param = callback.parameters[0]; - if (param) { - const paramName = param.name; - if (ts.isObjectBindingPattern(paramName)) { - // Destructured: ({ x, y, ... }) => ... - // Each field becomes a symbol like _iter_x, _iter_y - for (const element of paramName.elements) { - const fieldName = element.name.getText(sourceFile); - mapBindings.set(fieldName, `${iterVar}_${fieldName}`); - } - } else { - // Simple identifier: (token) => ... - mapBindings.set(paramName.getText(sourceFile), iterVar); - } - } - - // Compile the body - const body = callback.body; - let bodyResult: SymPyResult; - if (ts.isBlock(body)) { - bodyResult = compileBlock(body, context, mapBindings, sourceFile); - } else { - bodyResult = emitSymPy(body, context, mapBindings, innerParams, sourceFile); - } - if (!bodyResult.ok) return bodyResult; - - // Compile the collection expression - const collectionResult = emitSymPy( - collection, - context, - outerBindings, - innerParams, - sourceFile, - ); - if (!collectionResult.ok) return collectionResult; - - return { - ok: true, - sympyCode: `[${bodyResult.sympyCode} for ${iterVar} in ${collectionResult.sympyCode}]`, - }; -} - -const MATH_FUNCTION_MAP: Record = { - cos: "sp.cos", - sin: "sp.sin", - tan: "sp.tan", - acos: "sp.acos", - asin: "sp.asin", - atan: "sp.atan", - atan2: "sp.atan2", - sqrt: "sp.sqrt", - log: "sp.log", - exp: "sp.exp", - abs: "sp.Abs", - floor: "sp.floor", - ceil: "sp.ceiling", - pow: "sp.Pow", - min: "sp.Min", - max: "sp.Max", -}; - -const MATH_CONSTANT_MAP: Record = { - PI: "sp.pi", - E: "sp.E", - Infinity: "sp.oo", -}; - -function emitSymPy( - node: ts.Node, - context: SymPyCompilationContext, - localBindings: Map, - innerParams: string[], - sourceFile: ts.SourceFile, -): SymPyResult { - // Numeric literal - if (ts.isNumericLiteral(node)) { - return { ok: true, sympyCode: node.text }; - } - - // String literal — not supported in symbolic math - if (ts.isStringLiteral(node)) { - return err( - "String literals are not supported in symbolic expressions", - node, - sourceFile, - ); - } - - // Boolean literals - if (node.kind === ts.SyntaxKind.TrueKeyword) { - return { ok: true, sympyCode: "True" }; - } - if (node.kind === ts.SyntaxKind.FalseKeyword) { - return { ok: true, sympyCode: "False" }; - } - - // Identifier - if (ts.isIdentifier(node)) { - const name = node.text; - if (name === "Infinity") return { ok: true, sympyCode: "sp.oo" }; - if (localBindings.has(name)) { - return { ok: true, sympyCode: localBindings.get(name)! }; - } - if (context.parameterNames.has(name)) { - return { ok: true, sympyCode: name }; - } - // Could be a destructured token field or function param - return { ok: true, sympyCode: name }; - } - - // Parenthesized expression - if (ts.isParenthesizedExpression(node)) { - const inner = emitSymPy( - node.expression, - context, - localBindings, - innerParams, - sourceFile, - ); - if (!inner.ok) return inner; - return { ok: true, sympyCode: `(${inner.sympyCode})` }; - } - - // Prefix unary expression (-x, !x) - if (ts.isPrefixUnaryExpression(node)) { - const operand = emitSymPy( - node.operand, - context, - localBindings, - innerParams, - sourceFile, - ); - if (!operand.ok) return operand; - - switch (node.operator) { - case ts.SyntaxKind.MinusToken: - return { ok: true, sympyCode: `-(${operand.sympyCode})` }; - case ts.SyntaxKind.ExclamationToken: - return { ok: true, sympyCode: `sp.Not(${operand.sympyCode})` }; - case ts.SyntaxKind.PlusToken: - return operand; - default: - return err( - `Unsupported prefix operator: ${ts.SyntaxKind[node.operator]}`, - node, - sourceFile, - ); - } - } - - // Binary expression - if (ts.isBinaryExpression(node)) { - const left = emitSymPy( - node.left, - context, - localBindings, - innerParams, - sourceFile, - ); - if (!left.ok) return left; - const right = emitSymPy( - node.right, - context, - localBindings, - innerParams, - sourceFile, - ); - if (!right.ok) return right; - - switch (node.operatorToken.kind) { - case ts.SyntaxKind.PlusToken: - return { - ok: true, - sympyCode: `${left.sympyCode} + ${right.sympyCode}`, - }; - case ts.SyntaxKind.MinusToken: - return { - ok: true, - sympyCode: `${left.sympyCode} - ${right.sympyCode}`, - }; - case ts.SyntaxKind.AsteriskToken: - return { - ok: true, - sympyCode: `${left.sympyCode} * ${right.sympyCode}`, - }; - case ts.SyntaxKind.SlashToken: - return { - ok: true, - sympyCode: `${left.sympyCode} / ${right.sympyCode}`, - }; - case ts.SyntaxKind.AsteriskAsteriskToken: - return { - ok: true, - sympyCode: `${left.sympyCode}**${right.sympyCode}`, - }; - case ts.SyntaxKind.PercentToken: - return { - ok: true, - sympyCode: `sp.Mod(${left.sympyCode}, ${right.sympyCode})`, - }; - case ts.SyntaxKind.LessThanToken: - return { - ok: true, - sympyCode: `${left.sympyCode} < ${right.sympyCode}`, - }; - case ts.SyntaxKind.LessThanEqualsToken: - return { - ok: true, - sympyCode: `${left.sympyCode} <= ${right.sympyCode}`, - }; - case ts.SyntaxKind.GreaterThanToken: - return { - ok: true, - sympyCode: `${left.sympyCode} > ${right.sympyCode}`, - }; - case ts.SyntaxKind.GreaterThanEqualsToken: - return { - ok: true, - sympyCode: `${left.sympyCode} >= ${right.sympyCode}`, - }; - case ts.SyntaxKind.EqualsEqualsToken: - case ts.SyntaxKind.EqualsEqualsEqualsToken: - return { - ok: true, - sympyCode: `sp.Eq(${left.sympyCode}, ${right.sympyCode})`, - }; - case ts.SyntaxKind.ExclamationEqualsToken: - case ts.SyntaxKind.ExclamationEqualsEqualsToken: - return { - ok: true, - sympyCode: `sp.Ne(${left.sympyCode}, ${right.sympyCode})`, - }; - case ts.SyntaxKind.AmpersandAmpersandToken: - return { - ok: true, - sympyCode: `sp.And(${left.sympyCode}, ${right.sympyCode})`, - }; - case ts.SyntaxKind.BarBarToken: - return { - ok: true, - sympyCode: `sp.Or(${left.sympyCode}, ${right.sympyCode})`, - }; - default: - return err( - `Unsupported binary operator: ${node.operatorToken.getText(sourceFile)}`, - node.operatorToken, - sourceFile, - ); - } - } - - // Conditional (ternary) expression - if (ts.isConditionalExpression(node)) { - const condition = emitSymPy( - node.condition, - context, - localBindings, - innerParams, - sourceFile, - ); - if (!condition.ok) return condition; - const whenTrue = emitSymPy( - node.whenTrue, - context, - localBindings, - innerParams, - sourceFile, - ); - if (!whenTrue.ok) return whenTrue; - const whenFalse = emitSymPy( - node.whenFalse, - context, - localBindings, - innerParams, - sourceFile, - ); - if (!whenFalse.ok) return whenFalse; - return { - ok: true, - sympyCode: `sp.Piecewise((${whenTrue.sympyCode}, ${condition.sympyCode}), (${whenFalse.sympyCode}, True))`, - }; - } - - // Property access: parameters.x, tokens.Place[0].field, Math.PI - if (ts.isPropertyAccessExpression(node)) { - const propName = node.name.text; - - // Math constants: Math.PI, Math.E - if (ts.isIdentifier(node.expression) && node.expression.text === "Math") { - const constant = MATH_CONSTANT_MAP[propName]; - if (constant) return { ok: true, sympyCode: constant }; - // Math.method will be handled as part of a CallExpression - // Return a placeholder that the call expression handler will use - return { ok: true, sympyCode: `Math.${propName}` }; - } - - // parameters.x - if ( - ts.isIdentifier(node.expression) && - node.expression.text === "parameters" - ) { - return { ok: true, sympyCode: propName }; - } - - // tokens.Place[0].field — handle the chain - // First check: something.field where something is an element access - if (ts.isElementAccessExpression(node.expression)) { - // e.g., tokens.Space[0].x - const elemAccess = node.expression; - if (ts.isPropertyAccessExpression(elemAccess.expression)) { - const placePropAccess = elemAccess.expression; - if ( - ts.isIdentifier(placePropAccess.expression) && - placePropAccess.expression.text === "tokens" - ) { - const placeName = placePropAccess.name.text; - const indexExpr = elemAccess.argumentExpression; - const indexText = indexExpr.getText(sourceFile); - return { - ok: true, - sympyCode: `${placeName}_${indexText}_${propName}`, - }; - } - } - } - - // Generic property access — emit as dot access - const obj = emitSymPy( - node.expression, - context, - localBindings, - innerParams, - sourceFile, - ); - if (!obj.ok) return obj; - return { ok: true, sympyCode: `${obj.sympyCode}_${propName}` }; - } - - // Element access: tokens.Place[0], arr[i] - if (ts.isElementAccessExpression(node)) { - const obj = emitSymPy( - node.expression, - context, - localBindings, - innerParams, - sourceFile, - ); - if (!obj.ok) return obj; - const index = emitSymPy( - node.argumentExpression, - context, - localBindings, - innerParams, - sourceFile, - ); - if (!index.ok) return index; - return { ok: true, sympyCode: `${obj.sympyCode}_${index.sympyCode}` }; - } - - // Call expression: Math.cos(x), Math.hypot(a, b), Distribution.Gaussian(m, s) - if (ts.isCallExpression(node)) { - const callee = node.expression; - - // Math.fn(...) - if ( - ts.isPropertyAccessExpression(callee) && - ts.isIdentifier(callee.expression) && - callee.expression.text === "Math" - ) { - const fnName = callee.name.text; - - // Special case: Math.hypot(a, b) -> sp.sqrt(a**2 + b**2) - if (fnName === "hypot") { - const args: string[] = []; - for (const a of node.arguments) { - const r = emitSymPy( - a, - context, - localBindings, - innerParams, - sourceFile, - ); - if (!r.ok) return r; - args.push(r.sympyCode); - } - const sumOfSquares = args.map((a) => `(${a})**2`).join(" + "); - return { ok: true, sympyCode: `sp.sqrt(${sumOfSquares})` }; - } - - // Special case: Math.pow(a, b) -> a**b - if (fnName === "pow" && node.arguments.length === 2) { - const base = emitSymPy( - node.arguments[0]!, - context, - localBindings, - innerParams, - sourceFile, - ); - if (!base.ok) return base; - const exp = emitSymPy( - node.arguments[1]!, - context, - localBindings, - innerParams, - sourceFile, - ); - if (!exp.ok) return exp; - return { - ok: true, - sympyCode: `(${base.sympyCode})**(${exp.sympyCode})`, - }; - } - - const sympyFn = MATH_FUNCTION_MAP[fnName]; - if (!sympyFn) { - return err( - `Unsupported Math function: Math.${fnName}`, - callee, - sourceFile, - ); - } - - const args: string[] = []; - for (const a of node.arguments) { - const r = emitSymPy(a, context, localBindings, innerParams, sourceFile); - if (!r.ok) return r; - args.push(r.sympyCode); - } - return { ok: true, sympyCode: `${sympyFn}(${args.join(", ")})` }; - } - - // Distribution.Gaussian(m, s), Distribution.Uniform(a, b), Distribution.Lognormal(mu, sigma) - if ( - ts.isPropertyAccessExpression(callee) && - ts.isIdentifier(callee.expression) && - callee.expression.text === "Distribution" - ) { - const distName = callee.name.text; - const args: string[] = []; - for (const a of node.arguments) { - const r = emitSymPy(a, context, localBindings, innerParams, sourceFile); - if (!r.ok) return r; - args.push(r.sympyCode); - } - - switch (distName) { - case "Gaussian": - return { - ok: true, - sympyCode: `sp.stats.Normal('X', ${args.join(", ")})`, - }; - case "Uniform": - return { - ok: true, - sympyCode: `sp.stats.Uniform('X', ${args.join(", ")})`, - }; - case "Lognormal": - return { - ok: true, - sympyCode: `sp.stats.LogNormal('X', ${args.join(", ")})`, - }; - default: - return err( - `Unsupported distribution: Distribution.${distName}`, - callee, - sourceFile, - ); - } - } - - // Global built-in functions: Boolean(expr), Number(expr) - if (ts.isIdentifier(callee)) { - if (callee.text === "Boolean" && node.arguments.length === 1) { - const arg = emitSymPy( - node.arguments[0]!, - context, - localBindings, - innerParams, - sourceFile, - ); - if (!arg.ok) return arg; - return { ok: true, sympyCode: `sp.Ne(${arg.sympyCode}, 0)` }; - } - - if (callee.text === "Number" && node.arguments.length === 1) { - return emitSymPy( - node.arguments[0]!, - context, - localBindings, - innerParams, - sourceFile, - ); - } - } - - // .map(callback) on arrays/tokens — emit as Python list comprehension - if ( - ts.isPropertyAccessExpression(callee) && - callee.name.text === "map" && - node.arguments.length === 1 - ) { - const callback = node.arguments[0]!; - if (ts.isArrowFunction(callback) || ts.isFunctionExpression(callback)) { - return compileMapCall( - callee.expression, - callback, - context, - localBindings, - innerParams, - sourceFile, - ); - } - } - - return err( - `Unsupported function call: ${callee.getText(sourceFile)}`, - node, - sourceFile, - ); - } - - // Array literal expression [a, b, c] - if (ts.isArrayLiteralExpression(node)) { - const elements: string[] = []; - for (const elem of node.elements) { - const result = emitSymPy( - elem, - context, - localBindings, - innerParams, - sourceFile, - ); - if (!result.ok) return result; - elements.push(result.sympyCode); - } - return { ok: true, sympyCode: `[${elements.join(", ")}]` }; - } - - // Object literal expression { field: expr, ... } - if (ts.isObjectLiteralExpression(node)) { - const entries: string[] = []; - for (const prop of node.properties) { - if (!ts.isPropertyAssignment(prop)) { - return err( - `Unsupported object property kind: ${ts.SyntaxKind[prop.kind]}`, - prop, - sourceFile, - ); - } - const key = prop.name.getText(sourceFile); - const val = emitSymPy( - prop.initializer, - context, - localBindings, - innerParams, - sourceFile, - ); - if (!val.ok) return val; - entries.push(`'${key}': ${val.sympyCode}`); - } - return { ok: true, sympyCode: `{${entries.join(", ")}}` }; - } - - // Non-null assertion (x!) — just unwrap - if (ts.isNonNullExpression(node)) { - return emitSymPy( - node.expression, - context, - localBindings, - innerParams, - sourceFile, - ); - } - - // Type assertion (x as T) — just unwrap - if (ts.isAsExpression(node)) { - return emitSymPy( - node.expression, - context, - localBindings, - innerParams, - sourceFile, - ); - } - - return err( - `Unsupported syntax: ${ts.SyntaxKind[node.kind]}`, - node, - sourceFile, - ); + const irResult = compileToIR(code, context); + if (!irResult.ok) return irResult; + return { ok: true, sympyCode: irToSymPy(irResult.ir) }; } diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/expression-ir.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/expression-ir.ts new file mode 100644 index 00000000000..e5cc5d14cd5 --- /dev/null +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/expression-ir.ts @@ -0,0 +1,169 @@ +/** + * JSON-based intermediate representation for mathematical expressions. + * + * This IR captures the semantic structure of expressions parsed from + * TypeScript source code — bindings, expressions, and probability + * distributions — providing a JSON-serializable format that can be + * translated to various backends (e.g., SymPy). + */ + +export type ExpressionIR = + | NumberNode + | BooleanNode + | InfinityNode + | SymbolNode + | ParameterNode + | TokenAccessNode + | BinaryNode + | UnaryNode + | CallNode + | DistributionNode + | DerivedDistributionNode + | PiecewiseNode + | ArrayNode + | ObjectNode + | ListComprehensionNode + | LetNode + | PropertyAccessNode + | ElementAccessNode; + +export type NumberNode = { + type: "number"; + /** Exact string representation from source, e.g. "3.14" */ + value: string; +}; + +export type BooleanNode = { + type: "boolean"; + value: boolean; +}; + +export type InfinityNode = { + type: "infinity"; +}; + +/** A generic symbol (local binding, iterator variable, etc.) */ +export type SymbolNode = { + type: "symbol"; + name: string; +}; + +/** A model parameter reference (from `parameters.`) */ +export type ParameterNode = { + type: "parameter"; + name: string; +}; + +/** Token field access: `tokens.[].` */ +export type TokenAccessNode = { + type: "tokenAccess"; + place: string; + index: ExpressionIR; + field: string; +}; + +export type BinaryOp = + | "+" + | "-" + | "*" + | "/" + | "**" + | "%" + | "<" + | "<=" + | ">" + | ">=" + | "==" + | "!=" + | "&&" + | "||"; + +export type BinaryNode = { + type: "binary"; + op: BinaryOp; + left: ExpressionIR; + right: ExpressionIR; +}; + +export type UnaryOp = "-" | "!" | "+"; + +export type UnaryNode = { + type: "unary"; + op: UnaryOp; + operand: ExpressionIR; +}; + +/** A math function call (e.g. cos, sin, sqrt, hypot, pow, min, max) */ +export type CallNode = { + type: "call"; + fn: string; + args: ExpressionIR[]; +}; + +/** A probability distribution (e.g. Gaussian, Uniform, Lognormal) */ +export type DistributionNode = { + type: "distribution"; + distribution: string; + args: ExpressionIR[]; +}; + +/** + * A distribution transformed by a function: `dist.map(fn)`. + * + * Example: `Distribution.Gaussian(0, 10).map(Math.cos)` produces a + * derived distribution where samples are drawn from the base and then + * transformed through the body expression. + */ +export type DerivedDistributionNode = { + type: "derivedDistribution"; + distribution: ExpressionIR; + variable: string; + body: ExpressionIR; +}; + +/** Conditional expression (ternary) */ +export type PiecewiseNode = { + type: "piecewise"; + condition: ExpressionIR; + whenTrue: ExpressionIR; + whenFalse: ExpressionIR; +}; + +export type ArrayNode = { + type: "array"; + elements: ExpressionIR[]; +}; + +export type ObjectNode = { + type: "object"; + entries: { key: string; value: ExpressionIR }[]; +}; + +/** List comprehension from `.map()` calls */ +export type ListComprehensionNode = { + type: "listComprehension"; + body: ExpressionIR; + variable: string; + collection: ExpressionIR; +}; + +/** Scoped const bindings wrapping a body expression */ +export type LetNode = { + type: "let"; + bindings: { name: string; value: ExpressionIR }[]; + body: ExpressionIR; +}; + +/** Fallback property access: `.` */ +export type PropertyAccessNode = { + type: "propertyAccess"; + object: ExpressionIR; + property: string; +}; + +/** Fallback element access: `[]` */ +export type ElementAccessNode = { + type: "elementAccess"; + object: ExpressionIR; + index: ExpressionIR; +}; diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/ir-to-sympy.test.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/ir-to-sympy.test.ts new file mode 100644 index 00000000000..b7579ca5123 --- /dev/null +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/ir-to-sympy.test.ts @@ -0,0 +1,333 @@ +import { describe, expect, it } from "vitest"; + +import type { ExpressionIR } from "./expression-ir"; +import { irToSymPy } from "./ir-to-sympy"; + +describe("irToSymPy", () => { + describe("literals", () => { + it("should emit number", () => { + expect(irToSymPy({ type: "number", value: "42" })).toBe("42"); + }); + + it("should emit boolean true as True", () => { + expect(irToSymPy({ type: "boolean", value: true })).toBe("True"); + }); + + it("should emit boolean false as False", () => { + expect(irToSymPy({ type: "boolean", value: false })).toBe("False"); + }); + + it("should emit infinity as sp.oo", () => { + expect(irToSymPy({ type: "infinity" })).toBe("sp.oo"); + }); + }); + + describe("symbols and parameters", () => { + it("should emit symbol name", () => { + expect(irToSymPy({ type: "symbol", name: "x" })).toBe("x"); + }); + + it("should emit parameter name", () => { + expect(irToSymPy({ type: "parameter", name: "infection_rate" })).toBe( + "infection_rate", + ); + }); + + it("should emit Math constants", () => { + expect(irToSymPy({ type: "symbol", name: "PI" })).toBe("sp.pi"); + expect(irToSymPy({ type: "symbol", name: "E" })).toBe("sp.E"); + }); + }); + + describe("token access", () => { + it("should emit Place_index_field format", () => { + expect( + irToSymPy({ + type: "tokenAccess", + place: "Space", + index: { type: "number", value: "0" }, + field: "x", + }), + ).toBe("Space_0_x"); + }); + }); + + describe("binary operations", () => { + const left: ExpressionIR = { type: "number", value: "1" }; + const right: ExpressionIR = { type: "number", value: "2" }; + + it("should emit arithmetic operators", () => { + expect(irToSymPy({ type: "binary", op: "+", left, right })).toBe("1 + 2"); + expect(irToSymPy({ type: "binary", op: "*", left, right })).toBe("1 * 2"); + expect(irToSymPy({ type: "binary", op: "**", left, right })).toBe("1**2"); + }); + + it("should emit modulo as sp.Mod", () => { + expect(irToSymPy({ type: "binary", op: "%", left, right })).toBe( + "sp.Mod(1, 2)", + ); + }); + + it("should emit equality as sp.Eq", () => { + expect(irToSymPy({ type: "binary", op: "==", left, right })).toBe( + "sp.Eq(1, 2)", + ); + }); + + it("should emit inequality as sp.Ne", () => { + expect(irToSymPy({ type: "binary", op: "!=", left, right })).toBe( + "sp.Ne(1, 2)", + ); + }); + + it("should emit logical operators", () => { + expect(irToSymPy({ type: "binary", op: "&&", left, right })).toBe( + "sp.And(1, 2)", + ); + expect(irToSymPy({ type: "binary", op: "||", left, right })).toBe( + "sp.Or(1, 2)", + ); + }); + }); + + describe("unary operations", () => { + it("should emit negation", () => { + expect( + irToSymPy({ + type: "unary", + op: "-", + operand: { type: "symbol", name: "x" }, + }), + ).toBe("-(x)"); + }); + + it("should emit logical not", () => { + expect( + irToSymPy({ + type: "unary", + op: "!", + operand: { type: "boolean", value: true }, + }), + ).toBe("sp.Not(True)"); + }); + }); + + describe("function calls", () => { + it("should emit math functions with sp. prefix", () => { + expect( + irToSymPy({ + type: "call", + fn: "cos", + args: [{ type: "symbol", name: "x" }], + }), + ).toBe("sp.cos(x)"); + }); + + it("should emit hypot as sqrt of sum of squares", () => { + expect( + irToSymPy({ + type: "call", + fn: "hypot", + args: [ + { type: "symbol", name: "a" }, + { type: "symbol", name: "b" }, + ], + }), + ).toBe("sp.sqrt((a)**2 + (b)**2)"); + }); + + it("should emit pow as exponentiation", () => { + expect( + irToSymPy({ + type: "call", + fn: "pow", + args: [ + { type: "symbol", name: "a" }, + { type: "number", value: "2" }, + ], + }), + ).toBe("(a)**(2)"); + }); + }); + + describe("distributions", () => { + it("should emit Gaussian as sp.stats.Normal", () => { + expect( + irToSymPy({ + type: "distribution", + distribution: "Gaussian", + args: [ + { type: "number", value: "0" }, + { type: "number", value: "1" }, + ], + }), + ).toBe("sp.stats.Normal('X', 0, 1)"); + }); + + it("should emit Lognormal as sp.stats.LogNormal", () => { + expect( + irToSymPy({ + type: "distribution", + distribution: "Lognormal", + args: [ + { type: "number", value: "0" }, + { type: "number", value: "1" }, + ], + }), + ).toBe("sp.stats.LogNormal('X', 0, 1)"); + }); + }); + + describe("derived distributions", () => { + it("should emit DerivedDistribution with lambda", () => { + expect( + irToSymPy({ + type: "derivedDistribution", + distribution: { + type: "distribution", + distribution: "Gaussian", + args: [ + { type: "number", value: "0" }, + { type: "number", value: "10" }, + ], + }, + variable: "_x", + body: { + type: "call", + fn: "cos", + args: [{ type: "symbol", name: "_x" }], + }, + }), + ).toBe( + "DerivedDistribution(sp.stats.Normal('X', 0, 10), lambda _x: sp.cos(_x))", + ); + }); + }); + + describe("piecewise", () => { + it("should emit sp.Piecewise", () => { + expect( + irToSymPy({ + type: "piecewise", + condition: { + type: "binary", + op: ">", + left: { type: "symbol", name: "x" }, + right: { type: "number", value: "0" }, + }, + whenTrue: { type: "symbol", name: "x" }, + whenFalse: { type: "number", value: "0" }, + }), + ).toBe("sp.Piecewise((x, x > 0), (0, True))"); + }); + }); + + describe("collections", () => { + it("should emit array as Python list", () => { + expect( + irToSymPy({ + type: "array", + elements: [ + { type: "number", value: "1" }, + { type: "number", value: "2" }, + ], + }), + ).toBe("[1, 2]"); + }); + + it("should emit object as Python dict", () => { + expect( + irToSymPy({ + type: "object", + entries: [ + { key: "x", value: { type: "number", value: "1" } }, + { key: "y", value: { type: "number", value: "2" } }, + ], + }), + ).toBe("{'x': 1, 'y': 2}"); + }); + }); + + describe("list comprehension", () => { + it("should emit Python list comprehension", () => { + expect( + irToSymPy({ + type: "listComprehension", + variable: "_iter", + collection: { type: "symbol", name: "tokens" }, + body: { + type: "binary", + op: "+", + left: { type: "symbol", name: "_iter" }, + right: { type: "number", value: "1" }, + }, + }), + ).toBe("[_iter + 1 for _iter in tokens]"); + }); + }); + + describe("let bindings", () => { + it("should inline single binding", () => { + expect( + irToSymPy({ + type: "let", + bindings: [ + { + name: "mu", + value: { type: "parameter", name: "gravitational_constant" }, + }, + ], + body: { + type: "binary", + op: "*", + left: { type: "symbol", name: "mu" }, + right: { type: "number", value: "2" }, + }, + }), + ).toBe("gravitational_constant * 2"); + }); + + it("should inline chained bindings", () => { + expect( + irToSymPy({ + type: "let", + bindings: [ + { name: "a", value: { type: "parameter", name: "infection_rate" } }, + { + name: "b", + value: { type: "parameter", name: "recovery_rate" }, + }, + ], + body: { + type: "binary", + op: "+", + left: { type: "symbol", name: "a" }, + right: { type: "symbol", name: "b" }, + }, + }), + ).toBe("infection_rate + recovery_rate"); + }); + }); + + describe("property and element access fallbacks", () => { + it("should emit property access with underscore", () => { + expect( + irToSymPy({ + type: "propertyAccess", + object: { type: "symbol", name: "obj" }, + property: "field", + }), + ).toBe("obj_field"); + }); + + it("should emit element access with underscore", () => { + expect( + irToSymPy({ + type: "elementAccess", + object: { type: "symbol", name: "arr" }, + index: { type: "number", value: "0" }, + }), + ).toBe("arr_0"); + }); + }); +}); diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/ir-to-sympy.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/ir-to-sympy.ts new file mode 100644 index 00000000000..f9b896351ee --- /dev/null +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/ir-to-sympy.ts @@ -0,0 +1,207 @@ +import type { ExpressionIR } from "./expression-ir"; + +const MATH_FN_MAP: Record = { + cos: "sp.cos", + sin: "sp.sin", + tan: "sp.tan", + acos: "sp.acos", + asin: "sp.asin", + atan: "sp.atan", + atan2: "sp.atan2", + sqrt: "sp.sqrt", + log: "sp.log", + exp: "sp.exp", + abs: "sp.Abs", + floor: "sp.floor", + ceil: "sp.ceiling", + min: "sp.Min", + max: "sp.Max", +}; + +const MATH_CONSTANT_MAP: Record = { + PI: "sp.pi", + E: "sp.E", +}; + +const DISTRIBUTION_MAP: Record = { + Gaussian: "sp.stats.Normal", + Uniform: "sp.stats.Uniform", + Lognormal: "sp.stats.LogNormal", +}; + +/** + * Converts an expression IR node to SymPy Python code. + * + * Let-bindings are inlined: the binding's value replaces all references + * to the binding name in the body expression. + */ +export function irToSymPy( + node: ExpressionIR, + env: Map = new Map(), +): string { + switch (node.type) { + case "number": + return node.value; + + case "boolean": + return node.value ? "True" : "False"; + + case "infinity": + return "sp.oo"; + + case "symbol": { + const constant = MATH_CONSTANT_MAP[node.name]; + if (constant) return constant; + return env.get(node.name) ?? node.name; + } + + case "parameter": + return node.name; + + case "tokenAccess": { + const index = irToSymPy(node.index, env); + return `${node.place}_${index}_${node.field}`; + } + + case "binary": { + const left = irToSymPy(node.left, env); + const right = irToSymPy(node.right, env); + return emitBinaryOp(node.op, left, right); + } + + case "unary": { + const operand = irToSymPy(node.operand, env); + switch (node.op) { + case "-": + return `-(${operand})`; + case "!": + return `sp.Not(${operand})`; + case "+": + return operand; + } + break; + } + + case "call": + return emitCall(node.fn, node.args, env); + + case "distribution": { + const distFn = DISTRIBUTION_MAP[node.distribution]; + const args = node.args.map((a) => irToSymPy(a, env)); + return `${distFn}('X', ${args.join(", ")})`; + } + + case "derivedDistribution": { + const dist = irToSymPy(node.distribution, env); + const localEnv = new Map(env); + localEnv.set(node.variable, node.variable); + const body = irToSymPy(node.body, localEnv); + return `DerivedDistribution(${dist}, lambda ${node.variable}: ${body})`; + } + + case "piecewise": { + const condition = irToSymPy(node.condition, env); + const whenTrue = irToSymPy(node.whenTrue, env); + const whenFalse = irToSymPy(node.whenFalse, env); + return `sp.Piecewise((${whenTrue}, ${condition}), (${whenFalse}, True))`; + } + + case "array": { + const elements = node.elements.map((e) => irToSymPy(e, env)); + return `[${elements.join(", ")}]`; + } + + case "object": { + const entries = node.entries.map( + (e) => `'${e.key}': ${irToSymPy(e.value, env)}`, + ); + return `{${entries.join(", ")}}`; + } + + case "listComprehension": { + const body = irToSymPy(node.body, env); + const collection = irToSymPy(node.collection, env); + return `[${body} for ${node.variable} in ${collection}]`; + } + + case "let": { + const localEnv = new Map(env); + for (const binding of node.bindings) { + localEnv.set(binding.name, irToSymPy(binding.value, localEnv)); + } + return irToSymPy(node.body, localEnv); + } + + case "propertyAccess": { + const obj = irToSymPy(node.object, env); + return `${obj}_${node.property}`; + } + + case "elementAccess": { + const obj = irToSymPy(node.object, env); + const index = irToSymPy(node.index, env); + return `${obj}_${index}`; + } + } +} + +function emitBinaryOp(op: string, left: string, right: string): string { + switch (op) { + case "+": + return `${left} + ${right}`; + case "-": + return `${left} - ${right}`; + case "*": + return `${left} * ${right}`; + case "/": + return `${left} / ${right}`; + case "**": + return `${left}**${right}`; + case "%": + return `sp.Mod(${left}, ${right})`; + case "<": + return `${left} < ${right}`; + case "<=": + return `${left} <= ${right}`; + case ">": + return `${left} > ${right}`; + case ">=": + return `${left} >= ${right}`; + case "==": + return `sp.Eq(${left}, ${right})`; + case "!=": + return `sp.Ne(${left}, ${right})`; + case "&&": + return `sp.And(${left}, ${right})`; + case "||": + return `sp.Or(${left}, ${right})`; + default: + return `${left} ${op} ${right}`; + } +} + +function emitCall( + fn: string, + args: ExpressionIR[], + env: Map, +): string { + const compiledArgs = args.map((a) => irToSymPy(a, env)); + + // Math.hypot(a, b) → sp.sqrt(a**2 + b**2) + if (fn === "hypot") { + const sumOfSquares = compiledArgs.map((a) => `(${a})**2`).join(" + "); + return `sp.sqrt(${sumOfSquares})`; + } + + // Math.pow(a, b) → a**b + if (fn === "pow" && compiledArgs.length === 2) { + return `(${compiledArgs[0]!})**(${compiledArgs[1]!})`; + } + + const sympyFn = MATH_FN_MAP[fn]; + if (sympyFn) { + return `${sympyFn}(${compiledArgs.join(", ")})`; + } + + return `${fn}(${compiledArgs.join(", ")})`; +} From d16b505677826aa65aca924935a482751655d49e Mon Sep 17 00:00:00 2001 From: Chris Feijoo Date: Wed, 18 Mar 2026 14:09:43 +0100 Subject: [PATCH 07/14] FE-514: Move expression IR and SymPy compiler to src/expression Reorganize into src/expression/{ts-to-ir,ir-to-sympy} and remove the compile-to-sympy convenience wrapper. The LSP checker now uses compileToIR directly and reports "Invalid expression" instead of "SymPy" in diagnostics. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../simulator => expression}/expression-ir.ts | 0 .../ir-to-sympy}/ir-to-sympy.test.ts | 2 +- .../ir-to-sympy}/ir-to-sympy.ts | 2 +- .../ts-to-ir}/compile-to-ir.stories.tsx | 0 .../ts-to-ir}/compile-to-ir.test.ts | 2 +- .../ts-to-ir}/compile-to-ir.ts | 2 +- .../petrinaut/src/lsp/lib/checker.ts | 56 +- .../src/lsp/worker/language-server.worker.ts | 2 +- .../simulator/compile-to-sympy.test.ts | 821 ------------------ .../simulation/simulator/compile-to-sympy.ts | 41 - .../src/views/Editor/lib/export-sympy.ts | 14 +- 11 files changed, 45 insertions(+), 897 deletions(-) rename libs/@hashintel/petrinaut/src/{simulation/simulator => expression}/expression-ir.ts (100%) rename libs/@hashintel/petrinaut/src/{simulation/simulator => expression/ir-to-sympy}/ir-to-sympy.test.ts (99%) rename libs/@hashintel/petrinaut/src/{simulation/simulator => expression/ir-to-sympy}/ir-to-sympy.ts (98%) rename libs/@hashintel/petrinaut/src/{simulation/simulator => expression/ts-to-ir}/compile-to-ir.stories.tsx (100%) rename libs/@hashintel/petrinaut/src/{simulation/simulator => expression/ts-to-ir}/compile-to-ir.test.ts (99%) rename libs/@hashintel/petrinaut/src/{simulation/simulator => expression/ts-to-ir}/compile-to-ir.ts (99%) delete mode 100644 libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts delete mode 100644 libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/expression-ir.ts b/libs/@hashintel/petrinaut/src/expression/expression-ir.ts similarity index 100% rename from libs/@hashintel/petrinaut/src/simulation/simulator/expression-ir.ts rename to libs/@hashintel/petrinaut/src/expression/expression-ir.ts diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/ir-to-sympy.test.ts b/libs/@hashintel/petrinaut/src/expression/ir-to-sympy/ir-to-sympy.test.ts similarity index 99% rename from libs/@hashintel/petrinaut/src/simulation/simulator/ir-to-sympy.test.ts rename to libs/@hashintel/petrinaut/src/expression/ir-to-sympy/ir-to-sympy.test.ts index b7579ca5123..f5fb265267f 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/ir-to-sympy.test.ts +++ b/libs/@hashintel/petrinaut/src/expression/ir-to-sympy/ir-to-sympy.test.ts @@ -1,6 +1,6 @@ import { describe, expect, it } from "vitest"; -import type { ExpressionIR } from "./expression-ir"; +import type { ExpressionIR } from "../expression-ir"; import { irToSymPy } from "./ir-to-sympy"; describe("irToSymPy", () => { diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/ir-to-sympy.ts b/libs/@hashintel/petrinaut/src/expression/ir-to-sympy/ir-to-sympy.ts similarity index 98% rename from libs/@hashintel/petrinaut/src/simulation/simulator/ir-to-sympy.ts rename to libs/@hashintel/petrinaut/src/expression/ir-to-sympy/ir-to-sympy.ts index f9b896351ee..fbb54ecf1ad 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/ir-to-sympy.ts +++ b/libs/@hashintel/petrinaut/src/expression/ir-to-sympy/ir-to-sympy.ts @@ -1,4 +1,4 @@ -import type { ExpressionIR } from "./expression-ir"; +import type { ExpressionIR } from "../expression-ir"; const MATH_FN_MAP: Record = { cos: "sp.cos", diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.stories.tsx b/libs/@hashintel/petrinaut/src/expression/ts-to-ir/compile-to-ir.stories.tsx similarity index 100% rename from libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.stories.tsx rename to libs/@hashintel/petrinaut/src/expression/ts-to-ir/compile-to-ir.stories.tsx diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.test.ts b/libs/@hashintel/petrinaut/src/expression/ts-to-ir/compile-to-ir.test.ts similarity index 99% rename from libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.test.ts rename to libs/@hashintel/petrinaut/src/expression/ts-to-ir/compile-to-ir.test.ts index 6a2956dab66..26a36f8b3f4 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.test.ts +++ b/libs/@hashintel/petrinaut/src/expression/ts-to-ir/compile-to-ir.test.ts @@ -1,8 +1,8 @@ import { describe, expect, it } from "vitest"; +import type { ExpressionIR } from "../expression-ir"; import type { CompilationContext } from "./compile-to-ir"; import { compileToIR } from "./compile-to-ir"; -import type { ExpressionIR } from "./expression-ir"; const defaultContext: CompilationContext = { parameterNames: new Set([ diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.ts b/libs/@hashintel/petrinaut/src/expression/ts-to-ir/compile-to-ir.ts similarity index 99% rename from libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.ts rename to libs/@hashintel/petrinaut/src/expression/ts-to-ir/compile-to-ir.ts index bc59ba1dc17..97f93192ecb 100644 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-ir.ts +++ b/libs/@hashintel/petrinaut/src/expression/ts-to-ir/compile-to-ir.ts @@ -1,7 +1,7 @@ import ts from "typescript"; import type { SDCPN, Transition } from "../../core/types/sdcpn"; -import type { BinaryOp, ExpressionIR, UnaryOp } from "./expression-ir"; +import type { BinaryOp, ExpressionIR, UnaryOp } from "../expression-ir"; /** * Context for compilation, derived from the SDCPN model. diff --git a/libs/@hashintel/petrinaut/src/lsp/lib/checker.ts b/libs/@hashintel/petrinaut/src/lsp/lib/checker.ts index c456d8e8576..54ae0a3847e 100644 --- a/libs/@hashintel/petrinaut/src/lsp/lib/checker.ts +++ b/libs/@hashintel/petrinaut/src/lsp/lib/checker.ts @@ -4,9 +4,9 @@ import type { SDCPN } from "../../core/types/sdcpn"; import { buildContextForDifferentialEquation, buildContextForTransition, - compileToSymPy, - type SymPyResult, -} from "../../simulation/simulator/compile-to-sympy"; + compileToIR, + type IRResult, +} from "../../expression/ts-to-ir/compile-to-ir"; import type { SDCPNLanguageServer } from "./create-sdcpn-language-service"; import { getItemFilePath } from "./file-paths"; @@ -31,22 +31,22 @@ export type SDCPNCheckResult = { isValid: boolean; /** TypeScript error diagnostics grouped by item */ itemDiagnostics: SDCPNDiagnostic[]; - /** SymPy compilation warning diagnostics (informational, do not affect validity) */ - sympyDiagnostics: SDCPNDiagnostic[]; + /** Expression IR diagnostics (unsupported math expressions, do not affect validity) */ + expressionDiagnostics: SDCPNDiagnostic[]; }; /** - * Creates a synthetic ts.Diagnostic from a SymPy compilation error result. - * Uses category 0 (Warning) since SymPy compilation failures are informational - * — the TypeScript code may still be valid, just not convertible to SymPy. + * Creates a synthetic ts.Diagnostic from an expression IR compilation error. + * Uses category 0 (Warning) since the TypeScript code may still be valid, + * just not representable as a pure mathematical expression. */ -function makeSymPyDiagnostic( - result: SymPyResult & { ok: false }, +function makeExpressionDiagnostic( + result: IRResult & { ok: false }, ): ts.Diagnostic { return { category: 0, // Warning - code: 99000, // Custom code for SymPy diagnostics - messageText: `SymPy: ${result.error}`, + code: 99000, + messageText: `Invalid expression: ${result.error}`, file: undefined, start: result.start, length: result.length, @@ -54,17 +54,17 @@ function makeSymPyDiagnostic( } /** - * Appends a SymPy diagnostic to the item diagnostics list, merging with + * Appends an expression diagnostic to the item diagnostics list, merging with * any existing entry for the same item. */ -function appendSymPyDiagnostic( +function appendExpressionDiagnostic( itemDiagnostics: SDCPNDiagnostic[], itemId: string, itemType: ItemType, filePath: string, - result: SymPyResult & { ok: false }, + result: IRResult & { ok: false }, ): void { - const diag = makeSymPyDiagnostic(result); + const diag = makeExpressionDiagnostic(result); const existing = itemDiagnostics.find( (di) => di.itemId === itemId && di.itemType === itemType, ); @@ -76,20 +76,20 @@ function appendSymPyDiagnostic( } /** - * Runs SymPy compilation on all SDCPN code expressions and appends - * any errors as warning diagnostics. + * Validates all SDCPN code expressions as mathematical expressions by + * compiling them to the expression IR, appending any errors as warnings. */ -function checkSymPyCompilation(sdcpn: SDCPN): SDCPNDiagnostic[] { +function checkExpressions(sdcpn: SDCPN): SDCPNDiagnostic[] { const itemDiagnostics: SDCPNDiagnostic[] = []; // Check differential equations for (const de of sdcpn.differentialEquations) { const ctx = buildContextForDifferentialEquation(sdcpn, de.colorId); - const result = compileToSymPy(de.code, ctx); + const result = compileToIR(de.code, ctx); if (!result.ok) { const filePath = getItemFilePath("differential-equation-code", { id: de.id, }); - appendSymPyDiagnostic( + appendExpressionDiagnostic( itemDiagnostics, de.id, "differential-equation", @@ -102,12 +102,12 @@ function checkSymPyCompilation(sdcpn: SDCPN): SDCPNDiagnostic[] { // Check transition lambdas and kernels for (const transition of sdcpn.transitions) { const lambdaCtx = buildContextForTransition(sdcpn, transition, "Lambda"); - const lambdaResult = compileToSymPy(transition.lambdaCode, lambdaCtx); + const lambdaResult = compileToIR(transition.lambdaCode, lambdaCtx); if (!lambdaResult.ok) { const filePath = getItemFilePath("transition-lambda-code", { transitionId: transition.id, }); - appendSymPyDiagnostic( + appendExpressionDiagnostic( itemDiagnostics, transition.id, "transition-lambda", @@ -129,7 +129,7 @@ function checkSymPyCompilation(sdcpn: SDCPN): SDCPNDiagnostic[] { transition, "TransitionKernel", ); - const kernelResult = compileToSymPy( + const kernelResult = compileToIR( transition.transitionKernelCode, kernelCtx, ); @@ -137,7 +137,7 @@ function checkSymPyCompilation(sdcpn: SDCPN): SDCPNDiagnostic[] { const filePath = getItemFilePath("transition-kernel-code", { transitionId: transition.id, }); - appendSymPyDiagnostic( + appendExpressionDiagnostic( itemDiagnostics, transition.id, "transition-kernel", @@ -235,12 +235,12 @@ export function checkSDCPN( } } - // Run SymPy compilation checks on all code expressions - const sympyDiagnostics = checkSymPyCompilation(sdcpn); + // Validate expressions as mathematical IR + const expressionDiagnostics = checkExpressions(sdcpn); return { isValid: itemDiagnostics.length === 0, itemDiagnostics, - sympyDiagnostics, + expressionDiagnostics, }; } diff --git a/libs/@hashintel/petrinaut/src/lsp/worker/language-server.worker.ts b/libs/@hashintel/petrinaut/src/lsp/worker/language-server.worker.ts index 1636a286ed0..31cf8dd7a84 100644 --- a/libs/@hashintel/petrinaut/src/lsp/worker/language-server.worker.ts +++ b/libs/@hashintel/petrinaut/src/lsp/worker/language-server.worker.ts @@ -62,7 +62,7 @@ function publishAllDiagnostics(sdcpn: SDCPN): void { const result = checkSDCPN(sdcpn, server); const allDiagnostics = [ ...result.itemDiagnostics, - ...result.sympyDiagnostics, + ...result.expressionDiagnostics, ]; const params: PublishDiagnosticsParams[] = allDiagnostics.map((item) => { const uri = filePathToUri(item.filePath); diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts deleted file mode 100644 index 719bc1f1a0a..00000000000 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts +++ /dev/null @@ -1,821 +0,0 @@ -import { describe, expect, it } from "vitest"; - -import { - compileToSymPy, - type SymPyCompilationContext, -} from "./compile-to-sympy"; - -const defaultContext: SymPyCompilationContext = { - parameterNames: new Set([ - "infection_rate", - "recovery_rate", - "gravitational_constant", - "earth_radius", - "satellite_radius", - "crash_threshold", - ]), - placeTokenFields: new Map([ - ["Space", ["x", "y", "direction", "velocity"]], - ["Susceptible", []], - ["Infected", []], - ]), - constructorFnName: "Lambda", -}; - -function dynamicsContext(): SymPyCompilationContext { - return { ...defaultContext, constructorFnName: "Dynamics" }; -} - -function kernelContext(): SymPyCompilationContext { - return { ...defaultContext, constructorFnName: "TransitionKernel" }; -} - -describe("compileToSymPy", () => { - describe("basic expressions", () => { - it("should compile a numeric literal", () => { - const result = compileToSymPy( - "export default Lambda(() => 1)", - defaultContext, - ); - expect(result).toEqual({ ok: true, sympyCode: "1" }); - }); - - it("should compile a decimal literal", () => { - const result = compileToSymPy( - "export default Lambda(() => 3.14)", - defaultContext, - ); - expect(result).toEqual({ ok: true, sympyCode: "3.14" }); - }); - - it("should compile boolean true", () => { - const result = compileToSymPy( - "export default Lambda(() => true)", - defaultContext, - ); - expect(result).toEqual({ ok: true, sympyCode: "True" }); - }); - - it("should compile boolean false", () => { - const result = compileToSymPy( - "export default Lambda(() => false)", - defaultContext, - ); - expect(result).toEqual({ ok: true, sympyCode: "False" }); - }); - - it("should compile Infinity", () => { - const result = compileToSymPy( - "export default Lambda(() => Infinity)", - defaultContext, - ); - expect(result).toEqual({ ok: true, sympyCode: "sp.oo" }); - }); - }); - - describe("parameter access", () => { - it("should compile parameters.x to symbol x", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => parameters.infection_rate)", - defaultContext, - ); - expect(result).toEqual({ ok: true, sympyCode: "infection_rate" }); - }); - - it("should compile parameters in arithmetic", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => parameters.infection_rate * 2)", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "infection_rate * 2", - }); - }); - }); - - describe("binary arithmetic", () => { - it("should compile addition", () => { - const result = compileToSymPy( - "export default Lambda(() => 1 + 2)", - defaultContext, - ); - expect(result).toEqual({ ok: true, sympyCode: "1 + 2" }); - }); - - it("should compile subtraction", () => { - const result = compileToSymPy( - "export default Lambda(() => 5 - 3)", - defaultContext, - ); - expect(result).toEqual({ ok: true, sympyCode: "5 - 3" }); - }); - - it("should compile multiplication", () => { - const result = compileToSymPy( - "export default Lambda(() => 2 * 3)", - defaultContext, - ); - expect(result).toEqual({ ok: true, sympyCode: "2 * 3" }); - }); - - it("should compile division", () => { - const result = compileToSymPy( - "export default Lambda(() => 1 / 3)", - defaultContext, - ); - expect(result).toEqual({ ok: true, sympyCode: "1 / 3" }); - }); - - it("should compile power operator", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => parameters.satellite_radius ** 2)", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "satellite_radius**2", - }); - }); - - it("should compile modulo", () => { - const result = compileToSymPy( - "export default Lambda(() => 10 % 3)", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "sp.Mod(10, 3)", - }); - }); - }); - - describe("comparison operators", () => { - it("should compile less than", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => parameters.infection_rate < 5)", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "infection_rate < 5", - }); - }); - - it("should compile greater than or equal", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => parameters.infection_rate >= 1)", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "infection_rate >= 1", - }); - }); - - it("should compile strict equality to Eq", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => parameters.infection_rate === 3)", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "sp.Eq(infection_rate, 3)", - }); - }); - - it("should compile inequality to Ne", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => parameters.infection_rate !== 0)", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "sp.Ne(infection_rate, 0)", - }); - }); - }); - - describe("logical operators", () => { - it("should compile && to sp.And", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => parameters.infection_rate > 0 && parameters.recovery_rate > 0)", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "sp.And(infection_rate > 0, recovery_rate > 0)", - }); - }); - - it("should compile || to sp.Or", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => parameters.infection_rate === 0 || parameters.recovery_rate === 0)", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "sp.Or(sp.Eq(infection_rate, 0), sp.Eq(recovery_rate, 0))", - }); - }); - }); - - describe("prefix unary operators", () => { - it("should compile negation", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => -parameters.infection_rate)", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "-(infection_rate)", - }); - }); - - it("should compile logical not", () => { - const result = compileToSymPy( - "export default Lambda(() => !true)", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "sp.Not(True)", - }); - }); - }); - - describe("Math functions", () => { - it("should compile Math.cos", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => Math.cos(parameters.infection_rate))", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "sp.cos(infection_rate)", - }); - }); - - it("should compile Math.sin", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => Math.sin(parameters.infection_rate))", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "sp.sin(infection_rate)", - }); - }); - - it("should compile Math.sqrt", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => Math.sqrt(parameters.infection_rate))", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "sp.sqrt(infection_rate)", - }); - }); - - it("should compile Math.log", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => Math.log(parameters.infection_rate))", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "sp.log(infection_rate)", - }); - }); - - it("should compile Math.exp", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => Math.exp(parameters.infection_rate))", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "sp.exp(infection_rate)", - }); - }); - - it("should compile Math.abs", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => Math.abs(parameters.infection_rate))", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "sp.Abs(infection_rate)", - }); - }); - - it("should compile Math.pow to exponentiation", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => Math.pow(parameters.infection_rate, 2))", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "(infection_rate)**(2)", - }); - }); - - it("should compile Math.hypot to sqrt of sum of squares", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => Math.hypot(parameters.infection_rate, parameters.recovery_rate))", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "sp.sqrt((infection_rate)**2 + (recovery_rate)**2)", - }); - }); - - it("should compile Math.PI", () => { - const result = compileToSymPy( - "export default Lambda(() => Math.PI)", - defaultContext, - ); - expect(result).toEqual({ ok: true, sympyCode: "sp.pi" }); - }); - - it("should compile Math.E", () => { - const result = compileToSymPy( - "export default Lambda(() => Math.E)", - defaultContext, - ); - expect(result).toEqual({ ok: true, sympyCode: "sp.E" }); - }); - }); - - describe("token access", () => { - it("should compile tokens.Place[0].field to symbol", () => { - const result = compileToSymPy( - "export default Lambda((tokens) => tokens.Space[0].x)", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "Space_0_x", - }); - }); - - it("should compile token field comparison", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => tokens.Space[0].velocity < parameters.crash_threshold)", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "Space_0_velocity < crash_threshold", - }); - }); - }); - - describe("conditional (ternary) expression", () => { - it("should compile to Piecewise", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => parameters.infection_rate > 1 ? parameters.infection_rate : 0)", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: - "sp.Piecewise((infection_rate, infection_rate > 1), (0, True))", - }); - }); - }); - - describe("Distribution calls", () => { - it("should compile Distribution.Gaussian", () => { - const result = compileToSymPy( - "export default Lambda(() => Distribution.Gaussian(0, 1))", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "sp.stats.Normal('X', 0, 1)", - }); - }); - - it("should compile Distribution.Uniform", () => { - const result = compileToSymPy( - "export default Lambda(() => Distribution.Uniform(0, 1))", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "sp.stats.Uniform('X', 0, 1)", - }); - }); - - it("should compile Distribution.Lognormal", () => { - const result = compileToSymPy( - "export default Lambda(() => Distribution.Lognormal(0, 1))", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "sp.stats.LogNormal('X', 0, 1)", - }); - }); - }); - - describe("derived distributions", () => { - it("should compile distribution.map(Math.cos) end-to-end", () => { - const result = compileToSymPy( - `export default Lambda(() => { - const angle = Distribution.Gaussian(0, 10); - return angle.map(Math.cos); - })`, - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: - "DerivedDistribution(sp.stats.Normal('X', 0, 10), lambda _x: sp.cos(_x))", - }); - }); - - it("should compile inline distribution.map(arrow)", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => Distribution.Uniform(0, 1).map((x) => x * parameters.infection_rate))", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: - "DerivedDistribution(sp.stats.Uniform('X', 0, 1), lambda _x: _x * infection_rate)", - }); - }); - }); - - describe("global built-in functions", () => { - it("should compile Boolean(expr) to sp.Ne(expr, 0)", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => Boolean(parameters.infection_rate))", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "sp.Ne(infection_rate, 0)", - }); - }); - - it("should compile Boolean with arithmetic expression", () => { - const result = compileToSymPy( - "export default Lambda(() => Boolean(1 + 2))", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "sp.Ne(1 + 2, 0)", - }); - }); - - it("should compile Number(expr) as identity", () => { - const result = compileToSymPy( - "export default Lambda(() => Number(true))", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "True", - }); - }); - - it("should compile Boolean in block body with return", () => { - const result = compileToSymPy( - `export default Lambda((tokens, parameters) => { - const sum = parameters.infection_rate + parameters.recovery_rate; - return Boolean(sum); - })`, - defaultContext, - ); - expect(result.ok).toBe(true); - if (result.ok) { - expect(result.sympyCode).toContain("sp.Ne"); - } - }); - }); - - describe("block body with const and return", () => { - it("should compile block body with const bindings", () => { - const result = compileToSymPy( - `export default Dynamics((tokens, parameters) => { - const mu = parameters.gravitational_constant; - return mu * 2; - })`, - dynamicsContext(), - ); - expect(result).toEqual({ - ok: true, - sympyCode: "gravitational_constant * 2", - }); - }); - - it("should compile block body with multiple const bindings", () => { - const result = compileToSymPy( - `export default Dynamics((tokens, parameters) => { - const a = parameters.infection_rate; - const b = parameters.recovery_rate; - return a + b; - })`, - dynamicsContext(), - ); - expect(result).toEqual({ - ok: true, - sympyCode: "infection_rate + recovery_rate", - }); - }); - }); - - describe("real-world expressions", () => { - it("should compile SIR infection rate lambda", () => { - const result = compileToSymPy( - "export default Lambda((tokens, parameters) => parameters.infection_rate)", - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "infection_rate", - }); - }); - - it("should compile satellite crash predicate lambda", () => { - const result = compileToSymPy( - `export default Lambda((tokens, parameters) => { - const distance = Math.hypot(tokens.Space[0].x, tokens.Space[0].y); - return distance < parameters.earth_radius + parameters.crash_threshold + parameters.satellite_radius; - })`, - defaultContext, - ); - expect(result.ok).toBe(true); - if (result.ok) { - expect(result.sympyCode).toContain("sp.sqrt"); - expect(result.sympyCode).toContain("<"); - expect(result.sympyCode).toContain("earth_radius"); - } - }); - - it("should compile orbital dynamics expression", () => { - const result = compileToSymPy( - `export default Dynamics((tokens, parameters) => { - const mu = parameters.gravitational_constant; - const r = Math.hypot(tokens.Space[0].x, tokens.Space[0].y); - const ax = (-mu * tokens.Space[0].x) / (r * r * r); - return ax; - })`, - dynamicsContext(), - ); - expect(result.ok).toBe(true); - if (result.ok) { - expect(result.sympyCode).toContain("gravitational_constant"); - expect(result.sympyCode).toContain("Space_0_x"); - } - }); - - it("should compile transition kernel with object literal", () => { - const result = compileToSymPy( - `export default TransitionKernel((tokens) => { - return { - x: tokens.Space[0].x, - y: tokens.Space[0].y, - velocity: 0, - direction: 0 - }; - })`, - kernelContext(), - ); - expect(result.ok).toBe(true); - if (result.ok) { - expect(result.sympyCode).toContain("'x': Space_0_x"); - expect(result.sympyCode).toContain("'y': Space_0_y"); - expect(result.sympyCode).toContain("'velocity': 0"); - } - }); - - it("should compile transition kernel with array of objects", () => { - const result = compileToSymPy( - `export default TransitionKernel((tokens) => { - return { - Debris: [ - { - x: tokens.Space[0].x, - y: tokens.Space[0].y, - velocity: 0, - direction: 0 - }, - { - x: tokens.Space[1].x, - y: tokens.Space[1].y, - velocity: 0, - direction: 0 - }, - ] - }; - })`, - kernelContext(), - ); - expect(result.ok).toBe(true); - if (result.ok) { - expect(result.sympyCode).toContain("'Debris': ["); - expect(result.sympyCode).toContain("'x': Space_0_x"); - expect(result.sympyCode).toContain("'x': Space_1_x"); - } - }); - - it("should compile simple array literal", () => { - const result = compileToSymPy( - "export default Lambda(() => [1, 2, 3])", - defaultContext, - ); - expect(result).toEqual({ ok: true, sympyCode: "[1, 2, 3]" }); - }); - }); - - describe(".map() list comprehension", () => { - it("should compile tokens.map with destructured params", () => { - const result = compileToSymPy( - `export default Dynamics((tokens, parameters) => { - const mu = parameters.gravitational_constant; - return tokens.map(({ x, y, direction, velocity }) => { - const r = Math.hypot(x, y); - const ax = (-mu * x) / (r * r * r); - const ay = (-mu * y) / (r * r * r); - return { - x: velocity * Math.cos(direction), - y: velocity * Math.sin(direction), - direction: (-ax * Math.sin(direction) + ay * Math.cos(direction)) / velocity, - velocity: ax * Math.cos(direction) + ay * Math.sin(direction), - }; - }); - })`, - dynamicsContext(), - ); - expect(result.ok).toBe(true); - if (result.ok) { - expect(result.sympyCode).toContain("for _iter in tokens"); - expect(result.sympyCode).toContain("_iter_x"); - expect(result.sympyCode).toContain("_iter_velocity"); - expect(result.sympyCode).toContain("sp.cos(_iter_direction)"); - } - }); - - it("should compile simple .map with identifier param", () => { - const result = compileToSymPy( - `export default Lambda((tokens) => tokens.map((token) => token + 1))`, - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "[_iter + 1 for _iter in tokens]", - }); - }); - - it("should compile .map with expression body", () => { - const result = compileToSymPy( - `export default Lambda((tokens, parameters) => tokens.map(({ x }) => x * parameters.infection_rate))`, - defaultContext, - ); - expect(result).toEqual({ - ok: true, - sympyCode: "[_iter_x * infection_rate for _iter in tokens]", - }); - }); - }); - - describe("error handling", () => { - it("should reject code without default export", () => { - const result = compileToSymPy( - "const x = Lambda(() => 1);", - defaultContext, - ); - expect(result.ok).toBe(false); - if (!result.ok) { - expect(result.error).toContain("No default export"); - expect(result.start).toBe(0); - expect(result.length).toBe(0); - } - }); - - it("should reject wrong constructor function name with position", () => { - const code = "export default WrongName(() => 1)"; - const result = compileToSymPy(code, defaultContext); - expect(result.ok).toBe(false); - if (!result.ok) { - expect(result.error).toContain("Expected Lambda(...)"); - // "WrongName" starts at position 15 - expect(result.start).toBe(code.indexOf("WrongName")); - expect(result.length).toBe("WrongName".length); - } - }); - - it("should reject unsupported Math function with position", () => { - const code = "export default Lambda(() => Math.random())"; - const result = compileToSymPy(code, defaultContext); - expect(result.ok).toBe(false); - if (!result.ok) { - expect(result.error).toContain("Unsupported Math function"); - // Points to the "Math.random" callee - expect(result.start).toBe(code.indexOf("Math.random")); - expect(result.length).toBe("Math.random".length); - } - }); - - it("should reject if statements in block body with position", () => { - const code = `export default Lambda((tokens, parameters) => { - if (parameters.infection_rate > 1) { - return parameters.infection_rate; - } - return 0; - })`; - const result = compileToSymPy(code, defaultContext); - expect(result.ok).toBe(false); - if (!result.ok) { - expect(result.error).toContain("Unsupported statement"); - expect(result.start).toBe(code.indexOf("if")); - expect(result.length).toBeGreaterThan(0); - } - }); - - it("should reject let declarations with position", () => { - const code = `export default Lambda(() => { - let x = 1; - return x; - })`; - const result = compileToSymPy(code, defaultContext); - expect(result.ok).toBe(false); - if (!result.ok) { - expect(result.error).toContain("let"); - expect(result.start).toBe(code.indexOf("let x = 1;")); - expect(result.length).toBe("let x = 1;".length); - } - }); - - it("should reject var declarations", () => { - const code = `export default Lambda(() => { - var x = 1; - return x; - })`; - const result = compileToSymPy(code, defaultContext); - expect(result.ok).toBe(false); - if (!result.ok) { - expect(result.error).toContain("var"); - expect(result.error).toContain("use 'const'"); - } - }); - - it("should reject string literals with position", () => { - const code = `export default Lambda(() => "hello")`; - const result = compileToSymPy(code, defaultContext); - expect(result.ok).toBe(false); - if (!result.ok) { - expect(result.error).toContain("String literals"); - expect(result.start).toBe(code.indexOf('"hello"')); - expect(result.length).toBe('"hello"'.length); - } - }); - - it("should reject unsupported function calls with position", () => { - const code = `export default Lambda(() => console.log(1))`; - const result = compileToSymPy(code, defaultContext); - expect(result.ok).toBe(false); - if (!result.ok) { - expect(result.error).toContain("Unsupported function call"); - expect(result.start).toBe(code.indexOf("console.log(1)")); - expect(result.length).toBe("console.log(1)".length); - } - }); - - it("should reject standalone expression statements", () => { - const code = `export default Lambda((tokensByPlace, parameters) => { - const a = Boolean(1 + 2); - Boolean(1 + 2); - return Boolean(1 + 2); - })`; - const result = compileToSymPy(code, defaultContext); - expect(result.ok).toBe(false); - if (!result.ok) { - expect(result.error).toContain("Standalone expression has no effect"); - // The standalone expression is the second line in the block - const standalonePos = code.indexOf("\n Boolean(1 + 2);") + 11; - expect(result.start).toBe(standalonePos); - expect(result.length).toBe("Boolean(1 + 2);".length); - } - }); - - it("should reject unsupported binary operator with position", () => { - const code = `export default Lambda(() => 1 << 2)`; - const result = compileToSymPy(code, defaultContext); - expect(result.ok).toBe(false); - if (!result.ok) { - expect(result.error).toContain("Unsupported binary operator"); - expect(result.start).toBe(code.indexOf("<<")); - expect(result.length).toBe("<<".length); - } - }); - }); -}); diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts deleted file mode 100644 index 7cd6c783d3a..00000000000 --- a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts +++ /dev/null @@ -1,41 +0,0 @@ -import type { CompilationContext } from "./compile-to-ir"; -import { compileToIR } from "./compile-to-ir"; -import { irToSymPy } from "./ir-to-sympy"; - -// Re-export types and context builders for backward compatibility -export type { CompilationContext } from "./compile-to-ir"; -export type { IRResult } from "./compile-to-ir"; -export type { ExpressionIR } from "./expression-ir"; -export { - buildContextForDifferentialEquation, - buildContextForTransition, - compileToIR, -} from "./compile-to-ir"; -export { irToSymPy } from "./ir-to-sympy"; - -/** @deprecated Use {@link CompilationContext} instead. */ -export type SymPyCompilationContext = CompilationContext; - -export type SymPyResult = - | { ok: true; sympyCode: string } - | { ok: false; error: string; start: number; length: number }; - -/** - * Compiles a Petrinaut TypeScript expression to SymPy Python code. - * - * This is a convenience wrapper that composes two steps: - * 1. TypeScript → Expression IR ({@link compileToIR}) - * 2. Expression IR → SymPy ({@link irToSymPy}) - * - * @param code - The TypeScript expression code string - * @param context - Compilation context with parameter names and token fields - * @returns Either `{ ok: true, sympyCode }` or `{ ok: false, error }` - */ -export function compileToSymPy( - code: string, - context: CompilationContext, -): SymPyResult { - const irResult = compileToIR(code, context); - if (!irResult.ok) return irResult; - return { ok: true, sympyCode: irToSymPy(irResult.ir) }; -} diff --git a/libs/@hashintel/petrinaut/src/views/Editor/lib/export-sympy.ts b/libs/@hashintel/petrinaut/src/views/Editor/lib/export-sympy.ts index 281549292d3..f5a95468e5c 100644 --- a/libs/@hashintel/petrinaut/src/views/Editor/lib/export-sympy.ts +++ b/libs/@hashintel/petrinaut/src/views/Editor/lib/export-sympy.ts @@ -1,9 +1,10 @@ import type { SDCPN } from "../../../core/types/sdcpn"; +import { irToSymPy } from "../../../expression/ir-to-sympy/ir-to-sympy"; import { buildContextForDifferentialEquation, buildContextForTransition, - compileToSymPy, -} from "../../../simulation/simulator/compile-to-sympy"; + compileToIR, +} from "../../../expression/ts-to-ir/compile-to-ir"; type SymPyExpression = { name: string; @@ -12,6 +13,15 @@ type SymPyExpression = { error: string | null; }; +function compileToSymPy( + code: string, + ctx: ReturnType, +): { ok: true; sympyCode: string } | { ok: false; error: string } { + const irResult = compileToIR(code, ctx); + if (!irResult.ok) return irResult; + return { ok: true, sympyCode: irToSymPy(irResult.ir) }; +} + /** * Converts all expressions in an SDCPN model to SymPy and produces a JSON * export containing both the original model and the SymPy representations. From a0b3f0a7767569f9c9d7d927edb67c1b2f65aac1 Mon Sep 17 00:00:00 2001 From: Chris Feijoo Date: Wed, 18 Mar 2026 18:53:55 +0100 Subject: [PATCH 08/14] FE-514: Add expression output split view with IR/SymPy format selector Add a `showExpressionOutput` user setting (default: false) with a toggle in the viewport settings dialog. When enabled, a resizable split panel appears below the Lambda and Transition Results code editors showing the compiled output. A floating format selector allows switching between expression IR (JSON) and SymPy (Python). Co-Authored-By: Claude Opus 4.6 (1M context) --- .../expression/expression-output-panel.tsx | 68 ++++++++++++++++ .../expression/use-expression-ir-output.ts | 66 ++++++++++++++++ .../src/state/user-settings-context.ts | 4 + .../src/state/user-settings-provider.tsx | 2 + .../transition-firing-time/subview.tsx | 77 +++++++++++++++---- .../subviews/transition-results/subview.tsx | 77 +++++++++++++++---- .../components/viewport-settings-dialog.tsx | 11 +++ 7 files changed, 279 insertions(+), 26 deletions(-) create mode 100644 libs/@hashintel/petrinaut/src/expression/expression-output-panel.tsx create mode 100644 libs/@hashintel/petrinaut/src/expression/use-expression-ir-output.ts diff --git a/libs/@hashintel/petrinaut/src/expression/expression-output-panel.tsx b/libs/@hashintel/petrinaut/src/expression/expression-output-panel.tsx new file mode 100644 index 00000000000..c664694ebb3 --- /dev/null +++ b/libs/@hashintel/petrinaut/src/expression/expression-output-panel.tsx @@ -0,0 +1,68 @@ +import { css } from "@hashintel/ds-helpers/css"; +import { useState } from "react"; + +import { Select } from "../components/select"; +import { CodeEditor } from "../monaco/code-editor"; +import type { + ExpressionOutput, + ExpressionOutputFormat, +} from "./use-expression-ir-output"; + +const containerStyle = css({ + position: "relative", + height: "full", +}); + +const selectContainerStyle = css({ + display: "flex", + alignItems: "center", + gap: "1", + position: "absolute", + bottom: "0", + right: "0", + zIndex: "[10]", + backdropFilter: "[blur(20px)]", + p: "1", + pl: "2", + borderTopLeftRadius: "sm", +}); + +const selectLabelStyle = css({ + fontSize: "xs", + fontWeight: "medium", + color: "neutral.s80", +}); + +const selectStyle = css({ + width: "[90px]", +}); + +export const ExpressionOutputPanel: React.FC<{ + output: ExpressionOutput; +}> = ({ output }) => { + const [format, setFormat] = useState("ir"); + + return ( +
+
+ Target + Date: Wed, 18 Mar 2026 20:02:57 +0100 Subject: [PATCH 09/14] FE-514: Add OCaml and Lean output targets, improve SymPy output quality MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rewrite SymPy emitter to produce clean Python: explicit `from X import Y` imports (no `sp.` prefix), hoisted named distributions from let-bindings, inline anonymous distributions, `return` statement, and auto-indented dicts/lists when >80 chars. DerivedDistribution replaced with direct symbolic substitution. - Add OCaml backend (ir-to-ocaml) with Float operators, let-in bindings, Distribution functions, auto-indented records/lists, and direct distribution substitution (no Distribution.map wrapper). - Add Lean 4 backend (ir-to-lean) with Mathlib imports, Real.cos/sin/sqrt, Unicode operators (≤, ≥, ≠, ∧, ∨, ¬), gaussianReal/uniformOn distributions, and ℝ type annotations on let-bindings. - Wire all four targets (IR, SymPy, OCaml, Lean) into the expression output panel selector. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../expression/expression-output-panel.tsx | 2 + .../src/expression/ir-to-lean/ir-to-lean.ts | 322 +++++++++++++ .../src/expression/ir-to-ocaml/ir-to-ocaml.ts | 228 +++++++++ .../ir-to-sympy/ir-to-sympy.test.ts | 143 ++++-- .../src/expression/ir-to-sympy/ir-to-sympy.ts | 442 +++++++++++------- .../expression/use-expression-ir-output.ts | 10 +- 6 files changed, 939 insertions(+), 208 deletions(-) create mode 100644 libs/@hashintel/petrinaut/src/expression/ir-to-lean/ir-to-lean.ts create mode 100644 libs/@hashintel/petrinaut/src/expression/ir-to-ocaml/ir-to-ocaml.ts diff --git a/libs/@hashintel/petrinaut/src/expression/expression-output-panel.tsx b/libs/@hashintel/petrinaut/src/expression/expression-output-panel.tsx index c664694ebb3..f293fb793a4 100644 --- a/libs/@hashintel/petrinaut/src/expression/expression-output-panel.tsx +++ b/libs/@hashintel/petrinaut/src/expression/expression-output-panel.tsx @@ -53,6 +53,8 @@ export const ExpressionOutputPanel: React.FC<{ options={[ { value: "ir", label: "IR" }, { value: "sympy", label: "SymPy" }, + { value: "ocaml", label: "OCaml" }, + { value: "lean", label: "Lean" }, ]} size="xs" portal={false} diff --git a/libs/@hashintel/petrinaut/src/expression/ir-to-lean/ir-to-lean.ts b/libs/@hashintel/petrinaut/src/expression/ir-to-lean/ir-to-lean.ts new file mode 100644 index 00000000000..627ef2aed14 --- /dev/null +++ b/libs/@hashintel/petrinaut/src/expression/ir-to-lean/ir-to-lean.ts @@ -0,0 +1,322 @@ +import type { ExpressionIR } from "../expression-ir"; + +const INDENT = " "; +function pad(level: number): string { + return INDENT.repeat(level); +} + +const MATH_FN_MAP: Record = { + cos: "Real.cos", + sin: "Real.sin", + tan: "Real.tan", + acos: "Real.arccos", + asin: "Real.arcsin", + atan: "Real.arctan", + atan2: "Real.arctan2", + sqrt: "Real.sqrt", + log: "Real.log", + exp: "Real.exp", + abs: "abs", + floor: "⌊·⌋", + ceil: "⌈·⌉", + min: "min", + max: "max", +}; + +const MATH_CONSTANT_MAP: Record = { + PI: "Real.pi", + E: "Real.exp 1", +}; + +const DISTRIBUTION_MAP: Record = { + Gaussian: { + fn: "gaussianReal", + import: "Mathlib.Probability.Distributions.Gaussian", + }, + Uniform: { + fn: "uniformOn (Set.Icc", + import: "Mathlib.Probability.Distributions.Uniform", + }, + Lognormal: { + fn: "lognormalReal", + import: "Mathlib.Probability.Distributions.Gaussian", + }, +}; + +/** + * Accumulates imports and let-bindings for clean Lean 4 output. + */ +class LeanEmitter { + readonly imports = new Set(); + readonly statements: string[] = []; + private readonly env: Map; + + constructor(env?: Map) { + this.env = new Map(env); + } + + private addImport(module: string): void { + this.imports.add(module); + } + + setEnv(name: string, value: string): void { + this.env.set(name, value); + } + + emit(node: ExpressionIR, indent = 0): string { + switch (node.type) { + case "number": + return node.value; + + case "boolean": + return node.value ? "True" : "False"; + + case "infinity": { + this.addImport("Mathlib.Order.BoundedOrder"); + return "⊤"; + } + + case "symbol": { + const constant = MATH_CONSTANT_MAP[node.name]; + if (constant) { + this.addImport( + "Mathlib.Analysis.SpecialFunctions.Trigonometric.Basic", + ); + return constant; + } + return this.env.get(node.name) ?? node.name; + } + + case "parameter": + return node.name; + + case "tokenAccess": { + const index = this.emit(node.index); + return `${node.place}_${index}_${node.field}`; + } + + case "binary": { + const left = this.emit(node.left, indent); + const right = this.emit(node.right, indent); + return this.emitBinaryOp(node.op, left, right); + } + + case "unary": { + const operand = this.emit(node.operand, indent); + switch (node.op) { + case "-": + return `-(${operand})`; + case "!": + return `¬(${operand})`; + case "+": + return operand; + } + break; + } + + case "call": + return this.emitCall(node.fn, node.args, indent); + + case "distribution": { + const dist = DISTRIBUTION_MAP[node.distribution]; + if (!dist) return `${node.distribution}(?)`; + this.addImport(dist.import); + const args = node.args.map((a) => this.emit(a, indent)); + if (node.distribution === "Uniform") { + return `uniformOn (Set.Icc ${args[0]!} ${args[1]!})`; + } + return `${dist.fn} ${args.join(" ")}`; + } + + case "derivedDistribution": { + const dist = this.emit(node.distribution, indent); + this.setEnv(node.variable, dist); + return this.emit(node.body, indent); + } + + case "piecewise": { + const condition = this.emit(node.condition, indent); + const whenTrue = this.emit(node.whenTrue, indent); + const whenFalse = this.emit(node.whenFalse, indent); + return `if ${condition} then ${whenTrue} else ${whenFalse}`; + } + + case "array": { + const elements = node.elements.map((e) => this.emit(e, indent)); + const flat = `[${elements.join(", ")}]`; + if (flat.length <= 80 || elements.length <= 1) return flat; + const inner = pad(indent + 1); + const outer = pad(indent); + return `[\n${inner}${elements.join(`,\n${inner}`)}\n${outer}]`; + } + + case "object": { + const entries = node.entries.map( + (e) => `${e.key} := ${this.emit(e.value, indent + 1)}`, + ); + const flat = `{ ${entries.join(", ")} }`; + if (flat.length <= 80 || entries.length <= 1) return flat; + const inner = pad(indent + 1); + const outer = pad(indent); + return `{\n${inner}${entries.join(`,\n${inner}`)}\n${outer}}`; + } + + case "listComprehension": { + const body = this.emit(node.body, indent); + const collection = this.emit(node.collection, indent); + return `${collection}.map (fun ${node.variable} => ${body})`; + } + + case "let": { + for (const binding of node.bindings) { + if (binding.value.type === "distribution") { + this.emitNamedDistribution(binding.value, binding.name); + } else { + const value = this.emit(binding.value, indent); + this.statements.push(`let ${binding.name} : ℝ := ${value}`); + } + this.env.set(binding.name, binding.name); + } + return this.emit(node.body, indent); + } + + case "propertyAccess": { + const obj = this.emit(node.object, indent); + return `${obj}.${node.property}`; + } + + case "elementAccess": { + const obj = this.emit(node.object, indent); + const index = this.emit(node.index, indent); + return `${obj}[${index}]!`; + } + } + } + + emitNamedDistribution( + node: ExpressionIR & { type: "distribution" }, + name: string, + ): void { + const dist = DISTRIBUTION_MAP[node.distribution]; + if (!dist) return; + this.addImport(dist.import); + const args = node.args.map((a) => this.emit(a)); + if (node.distribution === "Uniform") { + this.statements.push( + `let ${name} := uniformOn (Set.Icc ${args[0]!} ${args[1]!})`, + ); + } else { + this.statements.push(`let ${name} := ${dist.fn} ${args.join(" ")}`); + } + } + + private emitCall(fn: string, args: ExpressionIR[], indent: number): string { + const compiledArgs = args.map((a) => this.emit(a, indent)); + + if (fn === "hypot") { + this.addImport("Mathlib.Analysis.SpecialFunctions.Pow.Real"); + const sumOfSquares = compiledArgs.map((a) => `(${a}) ^ 2`).join(" + "); + return `Real.sqrt (${sumOfSquares})`; + } + + if (fn === "pow" && compiledArgs.length === 2) { + return `(${compiledArgs[0]!}) ^ (${compiledArgs[1]!})`; + } + + if (fn === "floor") { + return `⌊${compiledArgs[0]!}⌋`; + } + + if (fn === "ceil") { + return `⌈${compiledArgs[0]!}⌉`; + } + + const leanFn = MATH_FN_MAP[fn]; + if (leanFn) { + if (fn === "cos" || fn === "sin" || fn === "tan") { + this.addImport("Mathlib.Analysis.SpecialFunctions.Trigonometric.Basic"); + } else if (fn === "acos" || fn === "asin" || fn === "atan") { + this.addImport( + "Mathlib.Analysis.SpecialFunctions.Trigonometric.Inverse", + ); + } else if (fn === "sqrt" || fn === "log" || fn === "exp") { + this.addImport("Mathlib.Analysis.SpecialFunctions.Pow.Real"); + } + return `${leanFn} ${compiledArgs.map((a) => `(${a})`).join(" ")}`; + } + + return `${fn} ${compiledArgs.map((a) => `(${a})`).join(" ")}`; + } + + private emitBinaryOp(op: string, left: string, right: string): string { + switch (op) { + case "+": + return `${left} + ${right}`; + case "-": + return `${left} - ${right}`; + case "*": + return `${left} * ${right}`; + case "/": + return `${left} / ${right}`; + case "**": + return `${left} ^ ${right}`; + case "%": + return `${left} % ${right}`; + case "<": + return `${left} < ${right}`; + case "<=": + return `${left} ≤ ${right}`; + case ">": + return `${left} > ${right}`; + case ">=": + return `${left} ≥ ${right}`; + case "==": + return `${left} = ${right}`; + case "!=": + return `${left} ≠ ${right}`; + case "&&": + return `${left} ∧ ${right}`; + case "||": + return `${left} ∨ ${right}`; + default: + return `${left} ${op} ${right}`; + } + } + + renderImports(): string[] { + const sorted = [...this.imports].sort(); + return sorted.map((m) => `import ${m}`); + } +} + +/** + * Converts an expression IR node to Lean 4 code with Mathlib. + * + * Produces clean Lean with explicit imports, hoisted distributions, + * and let-bindings as variable assignments. + */ +export function irToLean( + node: ExpressionIR, + env: Map = new Map(), +): string { + const emitter = new LeanEmitter(env); + const expr = emitter.emit(node); + + const imports = emitter.renderImports(); + const parts: string[] = []; + + if (imports.length > 0) { + parts.push(imports.join("\n")); + } + + // Always open Real and ProbabilityTheory for cleaner code + parts.push("open Real ProbabilityTheory"); + + if (emitter.statements.length > 0) { + parts.push(emitter.statements.join("\n")); + } + + parts.push(`return ${expr}`); + + return parts.join("\n\n"); +} diff --git a/libs/@hashintel/petrinaut/src/expression/ir-to-ocaml/ir-to-ocaml.ts b/libs/@hashintel/petrinaut/src/expression/ir-to-ocaml/ir-to-ocaml.ts new file mode 100644 index 00000000000..dd50db15540 --- /dev/null +++ b/libs/@hashintel/petrinaut/src/expression/ir-to-ocaml/ir-to-ocaml.ts @@ -0,0 +1,228 @@ +import type { ExpressionIR } from "../expression-ir"; + +const INDENT = " "; +function pad(level: number): string { + return INDENT.repeat(level); +} + +const MATH_FN_MAP: Record = { + cos: "Float.cos", + sin: "Float.sin", + tan: "Float.tan", + acos: "Float.acos", + asin: "Float.asin", + atan: "Float.atan", + atan2: "Float.atan2", + sqrt: "Float.sqrt", + log: "Float.log", + exp: "Float.exp", + abs: "Float.abs", + floor: "Float.round ~dir:`Down", + ceil: "Float.round ~dir:`Up", + min: "Float.min", + max: "Float.max", +}; + +const MATH_CONSTANT_MAP: Record = { + PI: "Float.pi", + E: "Float.(exp 1.0)", +}; + +const DISTRIBUTION_MAP: Record = { + Gaussian: "Distribution.gaussian", + Uniform: "Distribution.uniform", + Lognormal: "Distribution.lognormal", +}; + +/** + * Converts an expression IR node to OCaml code. + * + * Let-bindings are emitted as OCaml `let ... in` expressions. + */ +export function irToOCaml( + node: ExpressionIR, + env: Map = new Map(), + indent = 0, +): string { + switch (node.type) { + case "number": { + const val = node.value; + // OCaml float literals require a dot + return val.includes(".") ? val : `${val}.`; + } + + case "boolean": + return node.value ? "true" : "false"; + + case "infinity": + return "Float.infinity"; + + case "symbol": { + const constant = MATH_CONSTANT_MAP[node.name]; + if (constant) return constant; + return env.get(node.name) ?? node.name; + } + + case "parameter": + return node.name; + + case "tokenAccess": { + const index = irToOCaml(node.index, env); + return `${node.place}_${index}_${node.field}`; + } + + case "binary": { + const left = irToOCaml(node.left, env); + const right = irToOCaml(node.right, env); + return emitBinaryOp(node.op, left, right); + } + + case "unary": { + const operand = irToOCaml(node.operand, env); + switch (node.op) { + case "-": + return `Float.neg (${operand})`; + case "!": + return `not (${operand})`; + case "+": + return operand; + } + break; + } + + case "call": + return emitCall(node.fn, node.args, env); + + case "distribution": { + const distFn = DISTRIBUTION_MAP[node.distribution]; + const args = node.args.map((a) => irToOCaml(a, env)); + return `${distFn} ${args.map((a) => `(${a})`).join(" ")}`; + } + + case "derivedDistribution": { + const dist = irToOCaml(node.distribution, env, indent); + const localEnv = new Map(env); + localEnv.set(node.variable, dist); + return irToOCaml(node.body, localEnv, indent); + } + + case "piecewise": { + const condition = irToOCaml(node.condition, env); + const whenTrue = irToOCaml(node.whenTrue, env); + const whenFalse = irToOCaml(node.whenFalse, env); + return `if ${condition} then ${whenTrue} else ${whenFalse}`; + } + + case "array": { + const elements = node.elements.map((e) => irToOCaml(e, env, indent)); + const flat = `[${elements.join("; ")}]`; + if (flat.length <= 80 || elements.length <= 1) return flat; + const inner = pad(indent + 1); + const outer = pad(indent); + return `[\n${inner}${elements.join(`;\n${inner}`)}\n${outer}]`; + } + + case "object": { + const entries = node.entries.map( + (e) => `${e.key} = ${irToOCaml(e.value, env, indent + 1)}`, + ); + const flat = `{ ${entries.join("; ")} }`; + if (flat.length <= 80 || entries.length <= 1) return flat; + const inner = pad(indent + 1); + const outer = pad(indent); + return `{\n${inner}${entries.join(`;\n${inner}`)}\n${outer}}`; + } + + case "listComprehension": { + const body = irToOCaml(node.body, env); + const collection = irToOCaml(node.collection, env); + return `List.map (fun ${node.variable} -> ${body}) ${collection}`; + } + + case "let": { + let result = ""; + const localEnv = new Map(env); + for (const binding of node.bindings) { + const value = irToOCaml(binding.value, localEnv); + result += `let ${binding.name} = ${value} in\n`; + localEnv.set(binding.name, binding.name); + } + result += irToOCaml(node.body, localEnv); + return result; + } + + case "propertyAccess": { + const obj = irToOCaml(node.object, env); + return `(${obj}).${node.property}`; + } + + case "elementAccess": { + const obj = irToOCaml(node.object, env); + const index = irToOCaml(node.index, env); + return `List.nth (${obj}) (${index})`; + } + } +} + +function emitBinaryOp(op: string, left: string, right: string): string { + switch (op) { + case "+": + return `(${left}) +. (${right})`; + case "-": + return `(${left}) -. (${right})`; + case "*": + return `(${left}) *. (${right})`; + case "/": + return `(${left}) /. (${right})`; + case "**": + return `(${left}) ** (${right})`; + case "%": + return `Float.mod_float (${left}) (${right})`; + case "<": + return `Float.( < ) (${left}) (${right})`; + case "<=": + return `Float.( <= ) (${left}) (${right})`; + case ">": + return `Float.( > ) (${left}) (${right})`; + case ">=": + return `Float.( >= ) (${left}) (${right})`; + case "==": + return `Float.( = ) (${left}) (${right})`; + case "!=": + return `Float.( <> ) (${left}) (${right})`; + case "&&": + return `(${left}) && (${right})`; + case "||": + return `(${left}) || (${right})`; + default: + return `(${left}) ${op} (${right})`; + } +} + +function emitCall( + fn: string, + args: ExpressionIR[], + env: Map, +): string { + const compiledArgs = args.map((a) => irToOCaml(a, env)); + + // Math.hypot(a, b) → Float.sqrt (a *. a +. b *. b) + if (fn === "hypot") { + const sumOfSquares = compiledArgs + .map((a) => `(${a}) *. (${a})`) + .join(" +. "); + return `Float.sqrt (${sumOfSquares})`; + } + + // Math.pow(a, b) → a ** b + if (fn === "pow" && compiledArgs.length === 2) { + return `(${compiledArgs[0]!}) ** (${compiledArgs[1]!})`; + } + + const ocamlFn = MATH_FN_MAP[fn]; + if (ocamlFn) { + return `${ocamlFn} (${compiledArgs.join(") (")})`; + } + + return `${fn} (${compiledArgs.join(") (")})`; +} diff --git a/libs/@hashintel/petrinaut/src/expression/ir-to-sympy/ir-to-sympy.test.ts b/libs/@hashintel/petrinaut/src/expression/ir-to-sympy/ir-to-sympy.test.ts index f5fb265267f..ef89d8ad325 100644 --- a/libs/@hashintel/petrinaut/src/expression/ir-to-sympy/ir-to-sympy.test.ts +++ b/libs/@hashintel/petrinaut/src/expression/ir-to-sympy/ir-to-sympy.test.ts @@ -6,36 +6,42 @@ import { irToSymPy } from "./ir-to-sympy"; describe("irToSymPy", () => { describe("literals", () => { it("should emit number", () => { - expect(irToSymPy({ type: "number", value: "42" })).toBe("42"); + expect(irToSymPy({ type: "number", value: "42" })).toBe("return 42"); }); it("should emit boolean true as True", () => { - expect(irToSymPy({ type: "boolean", value: true })).toBe("True"); + expect(irToSymPy({ type: "boolean", value: true })).toBe("return True"); }); it("should emit boolean false as False", () => { - expect(irToSymPy({ type: "boolean", value: false })).toBe("False"); + expect(irToSymPy({ type: "boolean", value: false })).toBe("return False"); }); - it("should emit infinity as sp.oo", () => { - expect(irToSymPy({ type: "infinity" })).toBe("sp.oo"); + it("should emit infinity with import", () => { + expect(irToSymPy({ type: "infinity" })).toBe( + "from sympy import oo\n\nreturn oo", + ); }); }); describe("symbols and parameters", () => { it("should emit symbol name", () => { - expect(irToSymPy({ type: "symbol", name: "x" })).toBe("x"); + expect(irToSymPy({ type: "symbol", name: "x" })).toBe("return x"); }); it("should emit parameter name", () => { expect(irToSymPy({ type: "parameter", name: "infection_rate" })).toBe( - "infection_rate", + "return infection_rate", ); }); - it("should emit Math constants", () => { - expect(irToSymPy({ type: "symbol", name: "PI" })).toBe("sp.pi"); - expect(irToSymPy({ type: "symbol", name: "E" })).toBe("sp.E"); + it("should emit Math constants with import", () => { + expect(irToSymPy({ type: "symbol", name: "PI" })).toBe( + "from sympy import pi\n\nreturn pi", + ); + expect(irToSymPy({ type: "symbol", name: "E" })).toBe( + "from sympy import E\n\nreturn E", + ); }); }); @@ -48,7 +54,7 @@ describe("irToSymPy", () => { index: { type: "number", value: "0" }, field: "x", }), - ).toBe("Space_0_x"); + ).toBe("return Space_0_x"); }); }); @@ -57,35 +63,41 @@ describe("irToSymPy", () => { const right: ExpressionIR = { type: "number", value: "2" }; it("should emit arithmetic operators", () => { - expect(irToSymPy({ type: "binary", op: "+", left, right })).toBe("1 + 2"); - expect(irToSymPy({ type: "binary", op: "*", left, right })).toBe("1 * 2"); - expect(irToSymPy({ type: "binary", op: "**", left, right })).toBe("1**2"); + expect(irToSymPy({ type: "binary", op: "+", left, right })).toBe( + "return 1 + 2", + ); + expect(irToSymPy({ type: "binary", op: "*", left, right })).toBe( + "return 1 * 2", + ); + expect(irToSymPy({ type: "binary", op: "**", left, right })).toBe( + "return 1**2", + ); }); - it("should emit modulo as sp.Mod", () => { + it("should emit modulo as Mod with import", () => { expect(irToSymPy({ type: "binary", op: "%", left, right })).toBe( - "sp.Mod(1, 2)", + "from sympy import Mod\n\nreturn Mod(1, 2)", ); }); - it("should emit equality as sp.Eq", () => { + it("should emit equality as Eq with import", () => { expect(irToSymPy({ type: "binary", op: "==", left, right })).toBe( - "sp.Eq(1, 2)", + "from sympy import Eq\n\nreturn Eq(1, 2)", ); }); - it("should emit inequality as sp.Ne", () => { + it("should emit inequality as Ne with import", () => { expect(irToSymPy({ type: "binary", op: "!=", left, right })).toBe( - "sp.Ne(1, 2)", + "from sympy import Ne\n\nreturn Ne(1, 2)", ); }); - it("should emit logical operators", () => { + it("should emit logical operators with imports", () => { expect(irToSymPy({ type: "binary", op: "&&", left, right })).toBe( - "sp.And(1, 2)", + "from sympy import And\n\nreturn And(1, 2)", ); expect(irToSymPy({ type: "binary", op: "||", left, right })).toBe( - "sp.Or(1, 2)", + "from sympy import Or\n\nreturn Or(1, 2)", ); }); }); @@ -98,29 +110,29 @@ describe("irToSymPy", () => { op: "-", operand: { type: "symbol", name: "x" }, }), - ).toBe("-(x)"); + ).toBe("return -(x)"); }); - it("should emit logical not", () => { + it("should emit logical not with import", () => { expect( irToSymPy({ type: "unary", op: "!", operand: { type: "boolean", value: true }, }), - ).toBe("sp.Not(True)"); + ).toBe("from sympy import Not\n\nreturn Not(True)"); }); }); describe("function calls", () => { - it("should emit math functions with sp. prefix", () => { + it("should emit math functions with import", () => { expect( irToSymPy({ type: "call", fn: "cos", args: [{ type: "symbol", name: "x" }], }), - ).toBe("sp.cos(x)"); + ).toBe("from sympy import cos\n\nreturn cos(x)"); }); it("should emit hypot as sqrt of sum of squares", () => { @@ -133,7 +145,7 @@ describe("irToSymPy", () => { { type: "symbol", name: "b" }, ], }), - ).toBe("sp.sqrt((a)**2 + (b)**2)"); + ).toBe("from sympy import sqrt\n\nreturn sqrt((a)**2 + (b)**2)"); }); it("should emit pow as exponentiation", () => { @@ -146,12 +158,12 @@ describe("irToSymPy", () => { { type: "number", value: "2" }, ], }), - ).toBe("(a)**(2)"); + ).toBe("return (a)**(2)"); }); }); describe("distributions", () => { - it("should emit Gaussian as sp.stats.Normal", () => { + it("should emit inline Gaussian as Normal", () => { expect( irToSymPy({ type: "distribution", @@ -161,10 +173,10 @@ describe("irToSymPy", () => { { type: "number", value: "1" }, ], }), - ).toBe("sp.stats.Normal('X', 0, 1)"); + ).toBe("from sympy.stats import Normal\n\nreturn Normal(0, 1)"); }); - it("should emit Lognormal as sp.stats.LogNormal", () => { + it("should emit inline Lognormal as LogNormal", () => { expect( irToSymPy({ type: "distribution", @@ -174,12 +186,42 @@ describe("irToSymPy", () => { { type: "number", value: "1" }, ], }), - ).toBe("sp.stats.LogNormal('X', 0, 1)"); + ).toBe("from sympy.stats import LogNormal\n\nreturn LogNormal(0, 1)"); + }); + + it("should emit named distribution from let-binding with symbol name", () => { + expect( + irToSymPy({ + type: "let", + bindings: [ + { + name: "angle", + value: { + type: "distribution", + distribution: "Uniform", + args: [ + { type: "number", value: "0" }, + { type: "number", value: "1" }, + ], + }, + }, + ], + body: { type: "symbol", name: "angle" }, + }), + ).toBe( + [ + "from sympy.stats import Uniform", + "", + "angle = Uniform('angle', 0, 1)", + "", + "return angle", + ].join("\n"), + ); }); }); describe("derived distributions", () => { - it("should emit DerivedDistribution with lambda", () => { + it("should substitute inline distribution into body", () => { expect( irToSymPy({ type: "derivedDistribution", @@ -199,13 +241,18 @@ describe("irToSymPy", () => { }, }), ).toBe( - "DerivedDistribution(sp.stats.Normal('X', 0, 10), lambda _x: sp.cos(_x))", + [ + "from sympy import cos", + "from sympy.stats import Normal", + "", + "return cos(Normal(0, 10))", + ].join("\n"), ); }); }); describe("piecewise", () => { - it("should emit sp.Piecewise", () => { + it("should emit Piecewise with import", () => { expect( irToSymPy({ type: "piecewise", @@ -218,7 +265,9 @@ describe("irToSymPy", () => { whenTrue: { type: "symbol", name: "x" }, whenFalse: { type: "number", value: "0" }, }), - ).toBe("sp.Piecewise((x, x > 0), (0, True))"); + ).toBe( + "from sympy import Piecewise\n\nreturn Piecewise((x, x > 0), (0, True))", + ); }); }); @@ -232,7 +281,7 @@ describe("irToSymPy", () => { { type: "number", value: "2" }, ], }), - ).toBe("[1, 2]"); + ).toBe("return [1, 2]"); }); it("should emit object as Python dict", () => { @@ -244,7 +293,7 @@ describe("irToSymPy", () => { { key: "y", value: { type: "number", value: "2" } }, ], }), - ).toBe("{'x': 1, 'y': 2}"); + ).toBe("return {'x': 1, 'y': 2}"); }); }); @@ -262,12 +311,12 @@ describe("irToSymPy", () => { right: { type: "number", value: "1" }, }, }), - ).toBe("[_iter + 1 for _iter in tokens]"); + ).toBe("return [_iter + 1 for _iter in tokens]"); }); }); describe("let bindings", () => { - it("should inline single binding", () => { + it("should emit single binding as assignment", () => { expect( irToSymPy({ type: "let", @@ -284,10 +333,10 @@ describe("irToSymPy", () => { right: { type: "number", value: "2" }, }, }), - ).toBe("gravitational_constant * 2"); + ).toBe("mu = gravitational_constant\n\nreturn mu * 2"); }); - it("should inline chained bindings", () => { + it("should emit chained bindings as assignments", () => { expect( irToSymPy({ type: "let", @@ -305,7 +354,7 @@ describe("irToSymPy", () => { right: { type: "symbol", name: "b" }, }, }), - ).toBe("infection_rate + recovery_rate"); + ).toBe("a = infection_rate\nb = recovery_rate\n\nreturn a + b"); }); }); @@ -317,7 +366,7 @@ describe("irToSymPy", () => { object: { type: "symbol", name: "obj" }, property: "field", }), - ).toBe("obj_field"); + ).toBe("return obj_field"); }); it("should emit element access with underscore", () => { @@ -327,7 +376,7 @@ describe("irToSymPy", () => { object: { type: "symbol", name: "arr" }, index: { type: "number", value: "0" }, }), - ).toBe("arr_0"); + ).toBe("return arr_0"); }); }); }); diff --git a/libs/@hashintel/petrinaut/src/expression/ir-to-sympy/ir-to-sympy.ts b/libs/@hashintel/petrinaut/src/expression/ir-to-sympy/ir-to-sympy.ts index fbb54ecf1ad..16aab28199b 100644 --- a/libs/@hashintel/petrinaut/src/expression/ir-to-sympy/ir-to-sympy.ts +++ b/libs/@hashintel/petrinaut/src/expression/ir-to-sympy/ir-to-sympy.ts @@ -1,207 +1,329 @@ import type { ExpressionIR } from "../expression-ir"; -const MATH_FN_MAP: Record = { - cos: "sp.cos", - sin: "sp.sin", - tan: "sp.tan", - acos: "sp.acos", - asin: "sp.asin", - atan: "sp.atan", - atan2: "sp.atan2", - sqrt: "sp.sqrt", - log: "sp.log", - exp: "sp.exp", - abs: "sp.Abs", - floor: "sp.floor", - ceil: "sp.ceiling", - min: "sp.Min", - max: "sp.Max", +/** Maps IR function names to SymPy names and their `sympy` module path. */ +const MATH_FN_MAP: Record = { + cos: { name: "cos", module: "sympy" }, + sin: { name: "sin", module: "sympy" }, + tan: { name: "tan", module: "sympy" }, + acos: { name: "acos", module: "sympy" }, + asin: { name: "asin", module: "sympy" }, + atan: { name: "atan", module: "sympy" }, + atan2: { name: "atan2", module: "sympy" }, + sqrt: { name: "sqrt", module: "sympy" }, + log: { name: "log", module: "sympy" }, + exp: { name: "exp", module: "sympy" }, + abs: { name: "Abs", module: "sympy" }, + floor: { name: "floor", module: "sympy" }, + ceil: { name: "ceiling", module: "sympy" }, + min: { name: "Min", module: "sympy" }, + max: { name: "Max", module: "sympy" }, }; -const MATH_CONSTANT_MAP: Record = { - PI: "sp.pi", - E: "sp.E", +const MATH_CONSTANT_MAP: Record = { + PI: { name: "pi", module: "sympy" }, + E: { name: "E", module: "sympy" }, }; -const DISTRIBUTION_MAP: Record = { - Gaussian: "sp.stats.Normal", - Uniform: "sp.stats.Uniform", - Lognormal: "sp.stats.LogNormal", +const DISTRIBUTION_MAP: Record< + string, + { name: string; module: "sympy.stats" } +> = { + Gaussian: { name: "Normal", module: "sympy.stats" }, + Uniform: { name: "Uniform", module: "sympy.stats" }, + Lognormal: { name: "LogNormal", module: "sympy.stats" }, }; +/** Symbols that need `from sympy import ...` */ +const SYMPY_SYMBOLS: Record = { + Piecewise: { name: "Piecewise", module: "sympy" }, + Mod: { name: "Mod", module: "sympy" }, + Eq: { name: "Eq", module: "sympy" }, + Ne: { name: "Ne", module: "sympy" }, + And: { name: "And", module: "sympy" }, + Or: { name: "Or", module: "sympy" }, + Not: { name: "Not", module: "sympy" }, + oo: { name: "oo", module: "sympy" }, +}; + +const INDENT = " "; +function pad(level: number): string { + return INDENT.repeat(level); +} + /** - * Converts an expression IR node to SymPy Python code. - * - * Let-bindings are inlined: the binding's value replaces all references - * to the binding name in the body expression. + * Accumulates top-level variable assignments and import tracking so that + * distributions and let-bindings are hoisted out of the final expression. */ -export function irToSymPy( - node: ExpressionIR, - env: Map = new Map(), -): string { - switch (node.type) { - case "number": - return node.value; +class SymPyEmitter { + /** Ordered list of `name = value` assignment lines */ + readonly statements: string[] = []; + /** Tracks imports: module → set of names */ + readonly imports = new Map>(); + private readonly env: Map; - case "boolean": - return node.value ? "True" : "False"; + constructor(env?: Map) { + this.env = new Map(env); + } - case "infinity": - return "sp.oo"; + /** Create a child scope that shares the same statements list and imports. */ + child(): SymPyEmitter { + const child = new SymPyEmitter(this.env); + (child as { statements: string[] }).statements = this.statements; + (child as { imports: Map> }).imports = this.imports; + return child; + } - case "symbol": { - const constant = MATH_CONSTANT_MAP[node.name]; - if (constant) return constant; - return env.get(node.name) ?? node.name; + private addImport(module: string, name: string): void { + let names = this.imports.get(module); + if (!names) { + names = new Set(); + this.imports.set(module, names); } + names.add(name); + } - case "parameter": - return node.name; + /** Use a sympy function/constant, registering its import. */ + private use(entry: { name: string; module: string }): string { + this.addImport(entry.module, entry.name); + return entry.name; + } - case "tokenAccess": { - const index = irToSymPy(node.index, env); - return `${node.place}_${index}_${node.field}`; - } + setEnv(name: string, value: string): void { + this.env.set(name, value); + } - case "binary": { - const left = irToSymPy(node.left, env); - const right = irToSymPy(node.right, env); - return emitBinaryOp(node.op, left, right); - } + emit(node: ExpressionIR, indent = 0): string { + switch (node.type) { + case "number": + return node.value; + + case "boolean": + return node.value ? "True" : "False"; - case "unary": { - const operand = irToSymPy(node.operand, env); - switch (node.op) { - case "-": - return `-(${operand})`; - case "!": - return `sp.Not(${operand})`; - case "+": - return operand; + case "infinity": + return this.use(SYMPY_SYMBOLS["oo"]!); + + case "symbol": { + const constant = MATH_CONSTANT_MAP[node.name]; + if (constant) return this.use(constant); + return this.env.get(node.name) ?? node.name; } - break; - } - case "call": - return emitCall(node.fn, node.args, env); + case "parameter": + return node.name; - case "distribution": { - const distFn = DISTRIBUTION_MAP[node.distribution]; - const args = node.args.map((a) => irToSymPy(a, env)); - return `${distFn}('X', ${args.join(", ")})`; - } + case "tokenAccess": { + const index = this.emit(node.index); + return `${node.place}_${index}_${node.field}`; + } - case "derivedDistribution": { - const dist = irToSymPy(node.distribution, env); - const localEnv = new Map(env); - localEnv.set(node.variable, node.variable); - const body = irToSymPy(node.body, localEnv); - return `DerivedDistribution(${dist}, lambda ${node.variable}: ${body})`; - } + case "binary": { + const left = this.emit(node.left); + const right = this.emit(node.right); + return this.emitBinaryOp(node.op, left, right); + } - case "piecewise": { - const condition = irToSymPy(node.condition, env); - const whenTrue = irToSymPy(node.whenTrue, env); - const whenFalse = irToSymPy(node.whenFalse, env); - return `sp.Piecewise((${whenTrue}, ${condition}), (${whenFalse}, True))`; - } + case "unary": { + const operand = this.emit(node.operand); + switch (node.op) { + case "-": + return `-(${operand})`; + case "!": + return `${this.use(SYMPY_SYMBOLS["Not"]!)}(${operand})`; + case "+": + return operand; + } + break; + } - case "array": { - const elements = node.elements.map((e) => irToSymPy(e, env)); - return `[${elements.join(", ")}]`; - } + case "call": + return this.emitCall(node.fn, node.args); + + case "distribution": { + // Inline: just emit the call directly (no hoisting, no symbol name) + const dist = DISTRIBUTION_MAP[node.distribution]; + const distFn = dist ? this.use(dist) : node.distribution; + const args = node.args.map((a) => this.emit(a)); + return `${distFn}(${args.join(", ")})`; + } + + case "derivedDistribution": { + const dist = this.emit(node.distribution); + this.setEnv(node.variable, dist); + return this.emit(node.body); + } - case "object": { - const entries = node.entries.map( - (e) => `'${e.key}': ${irToSymPy(e.value, env)}`, - ); - return `{${entries.join(", ")}}`; + case "piecewise": { + const pw = this.use(SYMPY_SYMBOLS["Piecewise"]!); + const condition = this.emit(node.condition); + const whenTrue = this.emit(node.whenTrue); + const whenFalse = this.emit(node.whenFalse); + return `${pw}((${whenTrue}, ${condition}), (${whenFalse}, True))`; + } + + case "array": { + const elements = node.elements.map((e) => this.emit(e, indent)); + const flat = `[${elements.join(", ")}]`; + if (flat.length <= 80 || elements.length <= 1) return flat; + const inner = pad(indent + 1); + const outer = pad(indent); + return `[\n${inner}${elements.join(`,\n${inner}`)},\n${outer}]`; + } + + case "object": { + const entries = node.entries.map( + (e) => `'${e.key}': ${this.emit(e.value, indent + 1)}`, + ); + const flat = `{${entries.join(", ")}}`; + if (flat.length <= 80 || entries.length <= 1) return flat; + const inner = pad(indent + 1); + const outer = pad(indent); + return `{\n${inner}${entries.join(`,\n${inner}`)},\n${outer}}`; + } + + case "listComprehension": { + const collection = this.emit(node.collection); + const childEmitter = this.child(); + childEmitter.setEnv(node.variable, node.variable); + const body = childEmitter.emit(node.body); + return `[${body} for ${node.variable} in ${collection}]`; + } + + case "let": { + for (const binding of node.bindings) { + if (binding.value.type === "distribution") { + this.emitNamedDistribution(binding.value, binding.name); + } else { + const value = this.emit(binding.value); + this.statements.push(`${binding.name} = ${value}`); + } + this.env.set(binding.name, binding.name); + } + return this.emit(node.body); + } + + case "propertyAccess": { + const obj = this.emit(node.object); + return `${obj}_${node.property}`; + } + + case "elementAccess": { + const obj = this.emit(node.object); + const index = this.emit(node.index); + return `${obj}_${index}`; + } } + } + + /** + * Emits a named distribution as a hoisted assignment. + * Called from `let` bindings where the user gave a name. + */ + emitNamedDistribution( + node: ExpressionIR & { type: "distribution" }, + name: string, + ): void { + const dist = DISTRIBUTION_MAP[node.distribution]; + const distFn = dist ? this.use(dist) : node.distribution; + const args = node.args.map((a) => this.emit(a)); + this.statements.push(`${name} = ${distFn}('${name}', ${args.join(", ")})`); + } - case "listComprehension": { - const body = irToSymPy(node.body, env); - const collection = irToSymPy(node.collection, env); - return `[${body} for ${node.variable} in ${collection}]`; + private emitCall(fn: string, args: ExpressionIR[]): string { + const compiledArgs = args.map((a) => this.emit(a)); + + if (fn === "hypot") { + const sqrtFn = this.use(MATH_FN_MAP["sqrt"]!); + const sumOfSquares = compiledArgs.map((a) => `(${a})**2`).join(" + "); + return `${sqrtFn}(${sumOfSquares})`; } - case "let": { - const localEnv = new Map(env); - for (const binding of node.bindings) { - localEnv.set(binding.name, irToSymPy(binding.value, localEnv)); - } - return irToSymPy(node.body, localEnv); + if (fn === "pow" && compiledArgs.length === 2) { + return `(${compiledArgs[0]!})**(${compiledArgs[1]!})`; } - case "propertyAccess": { - const obj = irToSymPy(node.object, env); - return `${obj}_${node.property}`; + const entry = MATH_FN_MAP[fn]; + if (entry) { + const sympyFn = this.use(entry); + return `${sympyFn}(${compiledArgs.join(", ")})`; } - case "elementAccess": { - const obj = irToSymPy(node.object, env); - const index = irToSymPy(node.index, env); - return `${obj}_${index}`; + return `${fn}(${compiledArgs.join(", ")})`; + } + + private emitBinaryOp(op: string, left: string, right: string): string { + switch (op) { + case "+": + return `${left} + ${right}`; + case "-": + return `${left} - ${right}`; + case "*": + return `${left} * ${right}`; + case "/": + return `${left} / ${right}`; + case "**": + return `${left}**${right}`; + case "%": + return `${this.use(SYMPY_SYMBOLS["Mod"]!)}(${left}, ${right})`; + case "<": + return `${left} < ${right}`; + case "<=": + return `${left} <= ${right}`; + case ">": + return `${left} > ${right}`; + case ">=": + return `${left} >= ${right}`; + case "==": + return `${this.use(SYMPY_SYMBOLS["Eq"]!)}(${left}, ${right})`; + case "!=": + return `${this.use(SYMPY_SYMBOLS["Ne"]!)}(${left}, ${right})`; + case "&&": + return `${this.use(SYMPY_SYMBOLS["And"]!)}(${left}, ${right})`; + case "||": + return `${this.use(SYMPY_SYMBOLS["Or"]!)}(${left}, ${right})`; + default: + return `${left} ${op} ${right}`; } } -} -function emitBinaryOp(op: string, left: string, right: string): string { - switch (op) { - case "+": - return `${left} + ${right}`; - case "-": - return `${left} - ${right}`; - case "*": - return `${left} * ${right}`; - case "/": - return `${left} / ${right}`; - case "**": - return `${left}**${right}`; - case "%": - return `sp.Mod(${left}, ${right})`; - case "<": - return `${left} < ${right}`; - case "<=": - return `${left} <= ${right}`; - case ">": - return `${left} > ${right}`; - case ">=": - return `${left} >= ${right}`; - case "==": - return `sp.Eq(${left}, ${right})`; - case "!=": - return `sp.Ne(${left}, ${right})`; - case "&&": - return `sp.And(${left}, ${right})`; - case "||": - return `sp.Or(${left}, ${right})`; - default: - return `${left} ${op} ${right}`; + /** Renders collected imports as `from import ...` lines. */ + renderImports(): string[] { + const lines: string[] = []; + // Sort modules for deterministic output + const modules = [...this.imports.keys()].sort(); + for (const module of modules) { + const names = [...this.imports.get(module)!].sort(); + lines.push(`from ${module} import ${names.join(", ")}`); + } + return lines; } } -function emitCall( - fn: string, - args: ExpressionIR[], - env: Map, +/** + * Converts an expression IR node to SymPy Python code. + * + * Produces clean Python with explicit imports, hoisted distributions, + * and let-bindings as variable assignments. + */ +export function irToSymPy( + node: ExpressionIR, + env: Map = new Map(), ): string { - const compiledArgs = args.map((a) => irToSymPy(a, env)); + const emitter = new SymPyEmitter(env); + const expr = emitter.emit(node); - // Math.hypot(a, b) → sp.sqrt(a**2 + b**2) - if (fn === "hypot") { - const sumOfSquares = compiledArgs.map((a) => `(${a})**2`).join(" + "); - return `sp.sqrt(${sumOfSquares})`; - } + const imports = emitter.renderImports(); + const parts: string[] = []; - // Math.pow(a, b) → a**b - if (fn === "pow" && compiledArgs.length === 2) { - return `(${compiledArgs[0]!})**(${compiledArgs[1]!})`; + if (imports.length > 0) { + parts.push(imports.join("\n")); } - const sympyFn = MATH_FN_MAP[fn]; - if (sympyFn) { - return `${sympyFn}(${compiledArgs.join(", ")})`; + if (emitter.statements.length > 0) { + parts.push(emitter.statements.join("\n")); } - return `${fn}(${compiledArgs.join(", ")})`; + parts.push(`return ${expr}`); + + return parts.join("\n\n"); } diff --git a/libs/@hashintel/petrinaut/src/expression/use-expression-ir-output.ts b/libs/@hashintel/petrinaut/src/expression/use-expression-ir-output.ts index 18a6df38099..9f1cde1feca 100644 --- a/libs/@hashintel/petrinaut/src/expression/use-expression-ir-output.ts +++ b/libs/@hashintel/petrinaut/src/expression/use-expression-ir-output.ts @@ -3,17 +3,21 @@ import { use, useMemo } from "react"; import type { Transition } from "../core/types/sdcpn"; import { SDCPNContext } from "../state/sdcpn-context"; import { UserSettingsContext } from "../state/user-settings-context"; +import { irToLean } from "./ir-to-lean/ir-to-lean"; +import { irToOCaml } from "./ir-to-ocaml/ir-to-ocaml"; import { irToSymPy } from "./ir-to-sympy/ir-to-sympy"; import { buildContextForTransition, compileToIR, } from "./ts-to-ir/compile-to-ir"; -export type ExpressionOutputFormat = "ir" | "sympy"; +export type ExpressionOutputFormat = "ir" | "sympy" | "ocaml" | "lean"; export type ExpressionOutput = { ir: string; sympy: string; + ocaml: string; + lean: string; }; /** @@ -45,6 +49,8 @@ export function useExpressionOutput( return { ir: JSON.stringify(result.ir, null, 2), sympy: irToSymPy(result.ir), + ocaml: irToOCaml(result.ir), + lean: irToLean(result.ir), }; } const errorJson = JSON.stringify( @@ -55,6 +61,8 @@ export function useExpressionOutput( return { ir: errorJson, sympy: `# Error: ${result.error}`, + ocaml: `(* Error: ${result.error} *)`, + lean: `-- Error: ${result.error}`, }; }, [ showExpressionOutput, From 526e34e8d2cbdf683522e2c23235dc8745c91061 Mon Sep 17 00:00:00 2001 From: Chris Feijoo Date: Wed, 18 Mar 2026 20:17:32 +0100 Subject: [PATCH 10/14] FE-514: Add syntax highlighting to expression output panel Import Python and F# language contributions in the Monaco provider. Map output formats to Monaco languages: JSON for IR, Python for SymPy, F# for OCaml and Lean. Use key-based remount to ensure language switches correctly when the user changes format. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../petrinaut/src/expression/expression-output-panel.tsx | 9 +++++++++ libs/@hashintel/petrinaut/src/monaco/provider.tsx | 6 ++++++ 2 files changed, 15 insertions(+) diff --git a/libs/@hashintel/petrinaut/src/expression/expression-output-panel.tsx b/libs/@hashintel/petrinaut/src/expression/expression-output-panel.tsx index f293fb793a4..15dac71dca5 100644 --- a/libs/@hashintel/petrinaut/src/expression/expression-output-panel.tsx +++ b/libs/@hashintel/petrinaut/src/expression/expression-output-panel.tsx @@ -8,6 +8,13 @@ import type { ExpressionOutputFormat, } from "./use-expression-ir-output"; +const FORMAT_LANGUAGE: Record = { + ir: "json", + sympy: "python", + ocaml: "fsharp", + lean: "fsharp", +}; + const containerStyle = css({ position: "relative", height: "full", @@ -61,7 +68,9 @@ export const ExpressionOutputPanel: React.FC<{ />
diff --git a/libs/@hashintel/petrinaut/src/monaco/provider.tsx b/libs/@hashintel/petrinaut/src/monaco/provider.tsx index 0eff5ff7995..a76a7be44b4 100644 --- a/libs/@hashintel/petrinaut/src/monaco/provider.tsx +++ b/libs/@hashintel/petrinaut/src/monaco/provider.tsx @@ -19,6 +19,12 @@ async function initMonaco(): Promise { import( "monaco-editor/esm/vs/basic-languages/typescript/typescript.contribution.js" ), + import( + "monaco-editor/esm/vs/basic-languages/python/python.contribution.js" + ), + import( + "monaco-editor/esm/vs/basic-languages/fsharp/fsharp.contribution.js" + ), ]); window.MonacoEnvironment = { From 9ad7e48ee75ff96265256c5dcea463aba2effd69 Mon Sep 17 00:00:00 2001 From: Chris Feijoo Date: Wed, 18 Mar 2026 20:21:06 +0100 Subject: [PATCH 11/14] FE-514: Add expression output panel to differential equation properties Extract shared `formatIRResult` and `useCompileToOutput` helpers from the expression output hook. Add `useDiffEqExpressionOutput` for differential equations using `buildContextForDifferentialEquation`. Show the same resizable split panel with format selector in the differential equation Code section when the setting is enabled. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../expression/use-expression-ir-output.ts | 109 ++++++++++++------ .../subviews/main.tsx | 93 ++++++++++++--- 2 files changed, 148 insertions(+), 54 deletions(-) diff --git a/libs/@hashintel/petrinaut/src/expression/use-expression-ir-output.ts b/libs/@hashintel/petrinaut/src/expression/use-expression-ir-output.ts index 9f1cde1feca..fdbc5f70938 100644 --- a/libs/@hashintel/petrinaut/src/expression/use-expression-ir-output.ts +++ b/libs/@hashintel/petrinaut/src/expression/use-expression-ir-output.ts @@ -1,12 +1,14 @@ import { use, useMemo } from "react"; -import type { Transition } from "../core/types/sdcpn"; +import type { DifferentialEquation, Transition } from "../core/types/sdcpn"; import { SDCPNContext } from "../state/sdcpn-context"; import { UserSettingsContext } from "../state/user-settings-context"; import { irToLean } from "./ir-to-lean/ir-to-lean"; import { irToOCaml } from "./ir-to-ocaml/ir-to-ocaml"; import { irToSymPy } from "./ir-to-sympy/ir-to-sympy"; +import type { CompilationContext, IRResult } from "./ts-to-ir/compile-to-ir"; import { + buildContextForDifferentialEquation, buildContextForTransition, compileToIR, } from "./ts-to-ir/compile-to-ir"; @@ -20,9 +22,42 @@ export type ExpressionOutput = { lean: string; }; +function formatIRResult(result: IRResult): ExpressionOutput { + if (result.ok) { + return { + ir: JSON.stringify(result.ir, null, 2), + sympy: irToSymPy(result.ir), + ocaml: irToOCaml(result.ir), + lean: irToLean(result.ir), + }; + } + const errorJson = JSON.stringify( + { error: result.error, start: result.start, length: result.length }, + null, + 2, + ); + return { + ir: errorJson, + sympy: `# Error: ${result.error}`, + ocaml: `(* Error: ${result.error} *)`, + lean: `-- Error: ${result.error}`, + }; +} + +function useCompileToOutput( + code: string, + ctx: CompilationContext, + enabled: boolean, +): ExpressionOutput | null { + return useMemo(() => { + if (!enabled) return null; + return formatIRResult(compileToIR(code, ctx)); + }, [enabled, code, ctx]); +} + /** - * Compiles a transition's code to expression IR and SymPy, returning - * both formatted strings, or `null` when the setting is disabled. + * Compiles a transition's code to all output formats, + * or `null` when the setting is disabled. */ export function useExpressionOutput( transition: Transition, @@ -36,39 +71,41 @@ export function useExpressionOutput( ? transition.lambdaCode : transition.transitionKernelCode; - return useMemo(() => { - if (!showExpressionOutput) return null; + const ctx = useMemo( + () => + buildContextForTransition( + petriNetDefinition, + transition, + constructorFnName, + ), + [petriNetDefinition, transition, constructorFnName], + ); - const ctx = buildContextForTransition( - petriNetDefinition, - transition, - constructorFnName, - ); - const result = compileToIR(code, ctx); - if (result.ok) { - return { - ir: JSON.stringify(result.ir, null, 2), - sympy: irToSymPy(result.ir), - ocaml: irToOCaml(result.ir), - lean: irToLean(result.ir), - }; - } - const errorJson = JSON.stringify( - { error: result.error, start: result.start, length: result.length }, - null, - 2, - ); - return { - ir: errorJson, - sympy: `# Error: ${result.error}`, - ocaml: `(* Error: ${result.error} *)`, - lean: `-- Error: ${result.error}`, - }; - }, [ + return useCompileToOutput(code, ctx, showExpressionOutput); +} + +/** + * Compiles a differential equation's code to all output formats, + * or `null` when the setting is disabled. + */ +export function useDiffEqExpressionOutput( + differentialEquation: DifferentialEquation, +): ExpressionOutput | null { + const { showExpressionOutput } = use(UserSettingsContext); + const { petriNetDefinition } = use(SDCPNContext); + + const ctx = useMemo( + () => + buildContextForDifferentialEquation( + petriNetDefinition, + differentialEquation.colorId, + ), + [petriNetDefinition, differentialEquation.colorId], + ); + + return useCompileToOutput( + differentialEquation.code, + ctx, showExpressionOutput, - petriNetDefinition, - transition, - constructorFnName, - code, - ]); + ); } diff --git a/libs/@hashintel/petrinaut/src/views/Editor/panels/PropertiesPanel/differential-equation-properties/subviews/main.tsx b/libs/@hashintel/petrinaut/src/views/Editor/panels/PropertiesPanel/differential-equation-properties/subviews/main.tsx index 7be7e63cecb..0b7ef40a23e 100644 --- a/libs/@hashintel/petrinaut/src/views/Editor/panels/PropertiesPanel/differential-equation-properties/subviews/main.tsx +++ b/libs/@hashintel/petrinaut/src/views/Editor/panels/PropertiesPanel/differential-equation-properties/subviews/main.tsx @@ -1,5 +1,6 @@ import { css } from "@hashintel/ds-helpers/css"; import { useState } from "react"; +import { Group, Panel, Separator } from "react-resizable-panels"; import { TbDotsVertical, TbSparkles } from "react-icons/tb"; import { Button } from "../../../../../../components/button"; @@ -16,6 +17,8 @@ import { DEFAULT_DIFFERENTIAL_EQUATION_CODE, generateDefaultDifferentialEquationCode, } from "../../../../../../core/default-codes"; +import { ExpressionOutputPanel } from "../../../../../../expression/expression-output-panel"; +import { useDiffEqExpressionOutput } from "../../../../../../expression/use-expression-ir-output"; import { CodeEditor } from "../../../../../../monaco/code-editor"; import { getDocumentUri } from "../../../../../../monaco/editor-paths"; import { useIsReadOnly } from "../../../../../../state/use-is-read-only"; @@ -90,12 +93,36 @@ const aiIconStyle = css({ fontSize: "base", }); +const panelGroupStyle = css({ + flex: "[1]", + minHeight: "[0]", +}); + +const panelStyle = css({ + height: "full", +}); + +const resizeHandleStyle = css({ + borderTopWidth: "thin", + borderTopColor: "neutral.a20", + cursor: "ns-resize", + backgroundColor: "[transparent]", + transition: "[background-color 0.15s ease]", + "&[data-separator=hover]": { + backgroundColor: "neutral.a40", + }, + "&[data-separator=active]": { + backgroundColor: "blue.s60", + }, +}); + const DiffEqMainContent: React.FC = () => { const { differentialEquation, types, places, updateDifferentialEquation } = useDiffEqPropertiesContext(); const [showConfirmDialog, setShowConfirmDialog] = useState(false); const [pendingTypeId, setPendingTypeId] = useState(null); const isReadOnly = useIsReadOnly(); + const expressionOutput = useDiffEqExpressionOutput(differentialEquation); const placesUsingEquation = places.filter((place) => { if (!place.differentialEquationId) { @@ -258,25 +285,55 @@ const DiffEqMainContent: React.FC = () => { )}
- { - updateDifferentialEquation( + {expressionOutput !== null ? ( + + + { + updateDifferentialEquation( + differentialEquation.id, + (existingEquation) => { + existingEquation.code = newCode ?? ""; + }, + ); + }} + options={{ readOnly: isReadOnly }} + tooltip={isReadOnly ? UI_MESSAGES.READ_ONLY_MODE : undefined} + /> + + + + + + + ) : ( + { - existingEquation.code = newCode ?? ""; - }, - ); - }} - options={{ readOnly: isReadOnly }} - tooltip={isReadOnly ? UI_MESSAGES.READ_ONLY_MODE : undefined} - /> + )} + language="typescript" + value={differentialEquation.code} + height="100%" + onChange={(newCode) => { + updateDifferentialEquation( + differentialEquation.id, + (existingEquation) => { + existingEquation.code = newCode ?? ""; + }, + ); + }} + options={{ readOnly: isReadOnly }} + tooltip={isReadOnly ? UI_MESSAGES.READ_ONLY_MODE : undefined} + /> + )}
); From b8e3990bd93fd5321e83dde7a62697ea87fdfd68 Mon Sep 17 00:00:00 2001 From: Chris Feijoo Date: Wed, 18 Mar 2026 20:27:40 +0100 Subject: [PATCH 12/14] FE-514: Fix .map() compilation to use proper variable names MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Destructured .map(({ x, y }) => ...) now emits propertyAccess nodes (token.x, token.y) instead of prefixed symbols (_iter_x, _iter_y). The iteration variable is derived from the collection name by singularizing it (tokens → token). Simple .map((token) => ...) uses the user's parameter name directly. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../expression/ts-to-ir/compile-to-ir.test.ts | 14 ++++-- .../src/expression/ts-to-ir/compile-to-ir.ts | 47 +++++++++++++++++-- 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/libs/@hashintel/petrinaut/src/expression/ts-to-ir/compile-to-ir.test.ts b/libs/@hashintel/petrinaut/src/expression/ts-to-ir/compile-to-ir.test.ts index 26a36f8b3f4..745fb7a1757 100644 --- a/libs/@hashintel/petrinaut/src/expression/ts-to-ir/compile-to-ir.test.ts +++ b/libs/@hashintel/petrinaut/src/expression/ts-to-ir/compile-to-ir.test.ts @@ -361,7 +361,7 @@ describe("compileToIR", () => { }); describe(".map() list comprehension", () => { - it("should compile .map with destructured params", () => { + it("should compile .map with destructured params as propertyAccess", () => { const result = compileToIR( `export default Lambda((tokens, parameters) => tokens.map(({ x }) => x * parameters.infection_rate))`, defaultContext, @@ -370,12 +370,16 @@ describe("compileToIR", () => { if (result.ok) { expect(result.ir).toEqual({ type: "listComprehension", - variable: "_iter", + variable: "token", collection: { type: "symbol", name: "tokens" }, body: { type: "binary", op: "*", - left: { type: "symbol", name: "_iter_x" }, + left: { + type: "propertyAccess", + object: { type: "symbol", name: "token" }, + property: "x", + }, right: { type: "parameter", name: "infection_rate" }, }, }); @@ -391,9 +395,9 @@ describe("compileToIR", () => { if (result.ok) { expect(result.ir).toEqual({ type: "listComprehension", - variable: "_iter", + variable: "token", collection: { type: "symbol", name: "tokens" }, - body: { type: "symbol", name: "_iter" }, + body: { type: "symbol", name: "token" }, }); } }); diff --git a/libs/@hashintel/petrinaut/src/expression/ts-to-ir/compile-to-ir.ts b/libs/@hashintel/petrinaut/src/expression/ts-to-ir/compile-to-ir.ts index 97f93192ecb..41d6a2da0d3 100644 --- a/libs/@hashintel/petrinaut/src/expression/ts-to-ir/compile-to-ir.ts +++ b/libs/@hashintel/petrinaut/src/expression/ts-to-ir/compile-to-ir.ts @@ -285,6 +285,21 @@ function compileBlockToIR( /** * Compiles `collection.map(callback)` to a list comprehension IR node. */ +/** + * Derives a singular iteration variable name from the collection expression. + * `tokens` → `token`, `items` → `item`, fallback → `item`. + */ +function deriveIterVarName( + collection: ts.Expression, + sourceFile: ts.SourceFile, +): string { + const text = collection.getText(sourceFile); + if (text.endsWith("s") && text.length > 1) { + return text.slice(0, -1); + } + return "item"; +} + function compileMapCallToIR( collection: ts.Expression, callback: ts.ArrowFunction | ts.FunctionExpression, @@ -292,7 +307,6 @@ function compileMapCallToIR( outerScope: Scope, sourceFile: ts.SourceFile, ): IRResult { - const iterVar = "_iter"; const mapScope: Scope = { localBindingNames: new Set(outerScope.localBindingNames), symbolOverrides: new Map(outerScope.symbolOverrides), @@ -300,16 +314,28 @@ function compileMapCallToIR( }; const param = callback.parameters[0]; + let iterVar: string; + if (param) { const paramName = param.name; if (ts.isObjectBindingPattern(paramName)) { + // Destructured: ({ x, y }) => ... + // Use a derived name and map fields to propertyAccess on it + iterVar = deriveIterVarName(collection, sourceFile); for (const element of paramName.elements) { const fieldName = element.name.getText(sourceFile); - mapScope.symbolOverrides.set(fieldName, `${iterVar}_${fieldName}`); + mapScope.symbolOverrides.set( + fieldName, + `\0propAccess:${iterVar}:${fieldName}`, + ); } } else { - mapScope.symbolOverrides.set(paramName.getText(sourceFile), iterVar); + // Simple: (token) => ... + iterVar = paramName.getText(sourceFile); + mapScope.symbolOverrides.set(iterVar, iterVar); } + } else { + iterVar = deriveIterVarName(collection, sourceFile); } // Compile the body @@ -529,9 +555,22 @@ function emitIR( return { ok: true, ir: { type: "infinity" } }; } if (scope.symbolOverrides.has(name)) { + const override = scope.symbolOverrides.get(name)!; + // Destructured .map() fields are encoded as propertyAccess sentinels + if (override.startsWith("\0propAccess:")) { + const [, objName, property] = override.split(":"); + return { + ok: true, + ir: { + type: "propertyAccess", + object: { type: "symbol", name: objName! }, + property: property!, + }, + }; + } return { ok: true, - ir: { type: "symbol", name: scope.symbolOverrides.get(name)! }, + ir: { type: "symbol", name: override }, }; } if (scope.localBindingNames.has(name)) { From a06df40213b5ed8f069751aae97a281eddde0beb Mon Sep 17 00:00:00 2001 From: Chris Feijoo Date: Wed, 18 Mar 2026 20:34:49 +0100 Subject: [PATCH 13/14] Update bottom picker --- .../petrinaut/src/expression/expression-output-panel.tsx | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/libs/@hashintel/petrinaut/src/expression/expression-output-panel.tsx b/libs/@hashintel/petrinaut/src/expression/expression-output-panel.tsx index 15dac71dca5..4e68645a84d 100644 --- a/libs/@hashintel/petrinaut/src/expression/expression-output-panel.tsx +++ b/libs/@hashintel/petrinaut/src/expression/expression-output-panel.tsx @@ -25,13 +25,14 @@ const selectContainerStyle = css({ alignItems: "center", gap: "1", position: "absolute", - bottom: "0", - right: "0", + bottom: "[1px]", + right: "[1px]", zIndex: "[10]", backdropFilter: "[blur(20px)]", p: "1", pl: "2", borderTopLeftRadius: "sm", + borderBottomRightRadius: "sm", }); const selectLabelStyle = css({ From 5026f2fc58c7dabe53046361810d659e1aba1719 Mon Sep 17 00:00:00 2001 From: Chris Feijoo Date: Wed, 18 Mar 2026 20:38:13 +0100 Subject: [PATCH 14/14] FE-514: Mark expression output setting as experimental Co-Authored-By: Claude Opus 4.6 (1M context) --- .../views/SDCPN/components/viewport-settings-dialog.tsx | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/libs/@hashintel/petrinaut/src/views/SDCPN/components/viewport-settings-dialog.tsx b/libs/@hashintel/petrinaut/src/views/SDCPN/components/viewport-settings-dialog.tsx index d7108d7532a..e0dacaa6ded 100644 --- a/libs/@hashintel/petrinaut/src/views/SDCPN/components/viewport-settings-dialog.tsx +++ b/libs/@hashintel/petrinaut/src/views/SDCPN/components/viewport-settings-dialog.tsx @@ -170,7 +170,12 @@ export const ViewportSettingsDialog: React.FC = ({ /> + Expression output{" "} + Experimental + + } description="Show a read-only panel with the compiled expression IR next to code editors" >