diff --git a/libs/@hashintel/petrinaut/src/lsp/lib/checker.ts b/libs/@hashintel/petrinaut/src/lsp/lib/checker.ts index 3c5029c5fe3..375a488b7e2 100644 --- a/libs/@hashintel/petrinaut/src/lsp/lib/checker.ts +++ b/libs/@hashintel/petrinaut/src/lsp/lib/checker.ts @@ -1,6 +1,12 @@ import type ts from "typescript"; import type { SDCPN } from "../../core/types/sdcpn"; +import { + buildContextForDifferentialEquation, + buildContextForTransition, + compileToSymPy, + type SymPyResult, +} from "../../simulation/simulator/compile-to-sympy"; import type { SDCPNLanguageServer } from "./create-sdcpn-language-service"; import { getItemFilePath } from "./file-paths"; @@ -27,6 +33,113 @@ export type SDCPNCheckResult = { itemDiagnostics: SDCPNDiagnostic[]; }; +/** + * Creates a synthetic ts.Diagnostic from a SymPy compilation error result. + * Uses category 0 (Warning) since SymPy compilation failures are informational + * — the TypeScript code may still be valid, just not convertible to SymPy. + */ +function makeSymPyDiagnostic( + result: SymPyResult & { ok: false }, +): ts.Diagnostic { + return { + category: 0, // Warning + code: 99000, // Custom code for SymPy diagnostics + messageText: `SymPy: ${result.error}`, + file: undefined, + start: result.start, + length: result.length, + }; +} + +/** + * Appends a SymPy diagnostic to the item diagnostics list, merging with + * any existing entry for the same item. + */ +function appendSymPyDiagnostic( + itemDiagnostics: SDCPNDiagnostic[], + itemId: string, + itemType: ItemType, + filePath: string, + result: SymPyResult & { ok: false }, +): void { + const diag = makeSymPyDiagnostic(result); + const existing = itemDiagnostics.find( + (di) => di.itemId === itemId && di.itemType === itemType, + ); + if (existing) { + existing.diagnostics.push(diag); + } else { + itemDiagnostics.push({ itemId, itemType, filePath, diagnostics: [diag] }); + } +} + +/** + * Runs SymPy compilation on all SDCPN code expressions and appends + * any errors as warning diagnostics. + */ +function checkSymPyCompilation( + sdcpn: SDCPN, + itemDiagnostics: SDCPNDiagnostic[], +): void { + // Check differential equations + for (const de of sdcpn.differentialEquations) { + const ctx = buildContextForDifferentialEquation(sdcpn, de.colorId); + const result = compileToSymPy(de.code, ctx); + if (!result.ok) { + const filePath = getItemFilePath("differential-equation-code", { + id: de.id, + }); + appendSymPyDiagnostic( + itemDiagnostics, + de.id, + "differential-equation", + filePath, + result, + ); + } + } + + // Check transition lambdas and kernels + for (const transition of sdcpn.transitions) { + const lambdaCtx = buildContextForTransition(sdcpn, transition, "Lambda"); + const lambdaResult = compileToSymPy(transition.lambdaCode, lambdaCtx); + if (!lambdaResult.ok) { + const filePath = getItemFilePath("transition-lambda-code", { + transitionId: transition.id, + }); + appendSymPyDiagnostic( + itemDiagnostics, + transition.id, + "transition-lambda", + filePath, + lambdaResult, + ); + } + + const kernelCtx = buildContextForTransition( + sdcpn, + transition, + "TransitionKernel", + ); + const kernelResult = compileToSymPy( + transition.transitionKernelCode, + kernelCtx, + ); + if (!kernelResult.ok) { + const filePath = getItemFilePath("transition-kernel-code", { + transitionId: transition.id, + }); + appendSymPyDiagnostic( + itemDiagnostics, + transition.id, + "transition-kernel", + filePath, + kernelResult, + ); + } + } +} + /** * Checks the validity of an SDCPN by running TypeScript validation * on all user-provided code (transitions and differential equations). @@ -111,6 +224,9 @@ export function checkSDCPN( } } + // Run SymPy compilation checks on all code expressions + checkSymPyCompilation(sdcpn, itemDiagnostics); + return { isValid: itemDiagnostics.length === 0, itemDiagnostics, diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts new file mode 100644 index 00000000000..0ae9e4eef5d --- /dev/null +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.test.ts @@ -0,0 +1,792 @@ +import { describe, expect, it } from "vitest"; + +import { + compileToSymPy, + type SymPyCompilationContext, +} from "./compile-to-sympy"; + +const defaultContext: SymPyCompilationContext = { + parameterNames: new Set([ + "infection_rate", + "recovery_rate", + "gravitational_constant", + "earth_radius", + "satellite_radius", + "crash_threshold", + ]), + placeTokenFields: new Map([ + ["Space", ["x", "y", "direction", "velocity"]], + ["Susceptible", []], + ["Infected", []], + ]), + constructorFnName: "Lambda", +}; + +function dynamicsContext(): SymPyCompilationContext { + return { ...defaultContext, constructorFnName: "Dynamics" }; +} + +function kernelContext(): SymPyCompilationContext { + return { ...defaultContext, constructorFnName: "TransitionKernel" }; +} + +describe("compileToSymPy", () => { + describe("basic expressions", () => { + it("should compile a numeric literal", () => { + const result = compileToSymPy( + "export default Lambda(() => 1)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "1" }); + }); + + it("should compile a decimal literal", () => { + const result = compileToSymPy( + "export default Lambda(() => 3.14)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "3.14" }); + }); + + it("should compile boolean true", () => { + const result = compileToSymPy( + "export default Lambda(() => true)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "True" }); + }); + + it("should compile boolean false", () => { + const result = compileToSymPy( + "export default Lambda(() => false)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "False" }); + }); + + it("should compile Infinity", () => { + const result = compileToSymPy( + "export default Lambda(() => Infinity)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "sp.oo" }); + }); + }); + + describe("parameter access", () => { + it("should compile parameters.x to symbol x", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "infection_rate" }); + }); + + it("should compile parameters in arithmetic", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate * 2)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "infection_rate * 2", + }); + }); + }); + + describe("binary arithmetic", () => { + it("should compile addition", () => { + const result = compileToSymPy( + "export default Lambda(() => 1 + 2)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "1 + 2" }); + }); + + it("should compile subtraction", () => { + const result = compileToSymPy( + "export default Lambda(() => 5 - 3)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "5 - 3" }); + }); + + it("should compile multiplication", () => { + const result = compileToSymPy( + "export default Lambda(() => 2 * 3)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "2 * 3" }); + }); + + it("should compile division", () => { + const result = compileToSymPy( + "export default Lambda(() => 1 / 3)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "1 / 3" }); + }); + + it("should compile power operator", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.satellite_radius ** 2)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "satellite_radius**2", + }); + }); + + it("should compile modulo", () => { + const result = compileToSymPy( + "export default Lambda(() => 10 % 3)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.Mod(10, 3)", + }); + }); + }); + + describe("comparison operators", () => { + it("should compile less than", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate < 5)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "infection_rate < 5", + }); + }); + + it("should compile greater than or equal", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate >= 1)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "infection_rate >= 1", + }); + }); + + it("should compile strict equality to Eq", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate === 3)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.Eq(infection_rate, 3)", + }); + }); + + it("should compile inequality to Ne", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate !== 0)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.Ne(infection_rate, 0)", + }); + }); + }); + + describe("logical operators", () => { + it("should compile && to sp.And", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate > 0 && parameters.recovery_rate > 0)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.And(infection_rate > 0, recovery_rate > 0)", + }); + }); + + it("should compile || to sp.Or", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate === 0 || parameters.recovery_rate === 0)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.Or(sp.Eq(infection_rate, 0), sp.Eq(recovery_rate, 0))", + }); + }); + }); + + describe("prefix unary operators", () => { + it("should compile negation", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => -parameters.infection_rate)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "-(infection_rate)", + }); + }); + + it("should compile logical not", () => { + const result = compileToSymPy( + "export default Lambda(() => !true)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.Not(True)", + }); + }); + }); + + describe("Math functions", () => { + it("should compile Math.cos", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Math.cos(parameters.infection_rate))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.cos(infection_rate)", + }); + }); + + it("should compile Math.sin", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Math.sin(parameters.infection_rate))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.sin(infection_rate)", + }); + }); + + it("should compile Math.sqrt", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Math.sqrt(parameters.infection_rate))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.sqrt(infection_rate)", + }); + }); + + it("should compile Math.log", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Math.log(parameters.infection_rate))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.log(infection_rate)", + }); + }); + + it("should compile Math.exp", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Math.exp(parameters.infection_rate))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.exp(infection_rate)", + }); + }); + + it("should compile Math.abs", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Math.abs(parameters.infection_rate))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.Abs(infection_rate)", + }); + }); + + it("should compile Math.pow to exponentiation", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Math.pow(parameters.infection_rate, 2))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "(infection_rate)**(2)", + }); + }); + + it("should compile Math.hypot to sqrt of sum of squares", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Math.hypot(parameters.infection_rate, parameters.recovery_rate))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.sqrt((infection_rate)**2 + (recovery_rate)**2)", + }); + }); + + it("should compile Math.PI", () => { + const result = compileToSymPy( + "export default Lambda(() => Math.PI)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "sp.pi" }); + }); + + it("should compile Math.E", () => { + const result = compileToSymPy( + "export default Lambda(() => Math.E)", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "sp.E" }); + }); + }); + + describe("token access", () => { + it("should compile tokens.Place[0].field to symbol", () => { + const result = compileToSymPy( + "export default Lambda((tokens) => tokens.Space[0].x)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "Space_0_x", + }); + }); + + it("should compile token field comparison", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => tokens.Space[0].velocity < parameters.crash_threshold)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "Space_0_velocity < crash_threshold", + }); + }); + }); + + describe("conditional (ternary) expression", () => { + it("should compile to Piecewise", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate > 1 ? parameters.infection_rate : 0)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: + "sp.Piecewise((infection_rate, infection_rate > 1), (0, True))", + }); + }); + }); + + describe("Distribution calls", () => { + it("should compile Distribution.Gaussian", () => { + const result = compileToSymPy( + "export default Lambda(() => Distribution.Gaussian(0, 1))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.stats.Normal('X', 0, 1)", + }); + }); + + it("should compile Distribution.Uniform", () => { + const result = compileToSymPy( + "export default Lambda(() => Distribution.Uniform(0, 1))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.stats.Uniform('X', 0, 1)", + }); + }); + + it("should compile Distribution.Lognormal", () => { + const result = compileToSymPy( + "export default Lambda(() => Distribution.Lognormal(0, 1))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.stats.LogNormal('X', 0, 1)", + }); + }); + }); + + describe("global built-in functions", () => { + it("should compile Boolean(expr) to sp.Ne(expr, 0)", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => Boolean(parameters.infection_rate))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.Ne(infection_rate, 0)", + }); + }); + + it("should compile Boolean with arithmetic expression", () => { + const result = compileToSymPy( + "export default Lambda(() => Boolean(1 + 2))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "sp.Ne(1 + 2, 0)", + }); + }); + + it("should compile Number(expr) as identity", () => { + const result = compileToSymPy( + "export default Lambda(() => Number(true))", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "True", + }); + }); + + it("should compile Boolean in block body with return", () => { + const result = compileToSymPy( + `export default Lambda((tokens, parameters) => { + const sum = parameters.infection_rate + parameters.recovery_rate; + return Boolean(sum); + })`, + defaultContext, + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.sympyCode).toContain("sp.Ne"); + } + }); + }); + + describe("block body with const and return", () => { + it("should compile block body with const bindings", () => { + const result = compileToSymPy( + `export default Dynamics((tokens, parameters) => { + const mu = parameters.gravitational_constant; + return mu * 2; + })`, + dynamicsContext(), + ); + expect(result).toEqual({ + ok: true, + sympyCode: "gravitational_constant * 2", + }); + }); + + it("should compile block body with multiple const bindings", () => { + const result = compileToSymPy( + `export default Dynamics((tokens, parameters) => { + const a = parameters.infection_rate; + const b = parameters.recovery_rate; + return a + b; + })`, + dynamicsContext(), + ); + expect(result).toEqual({ + ok: true, + sympyCode: "infection_rate + recovery_rate", + }); + }); + }); + + describe("real-world expressions", () => { + it("should compile SIR infection rate lambda", () => { + const result = compileToSymPy( + "export default Lambda((tokens, parameters) => parameters.infection_rate)", + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "infection_rate", + }); + }); + + it("should compile satellite crash predicate lambda", () => { + const result = compileToSymPy( + `export default Lambda((tokens, parameters) => { + const distance = Math.hypot(tokens.Space[0].x, tokens.Space[0].y); + return distance < parameters.earth_radius + parameters.crash_threshold + parameters.satellite_radius; + })`, + defaultContext, + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.sympyCode).toContain("sp.sqrt"); + expect(result.sympyCode).toContain("<"); + expect(result.sympyCode).toContain("earth_radius"); + } + }); + + it("should compile orbital dynamics expression", () => { + const result = compileToSymPy( + `export default Dynamics((tokens, parameters) => { + const mu = parameters.gravitational_constant; + const r = Math.hypot(tokens.Space[0].x, tokens.Space[0].y); + const ax = (-mu * tokens.Space[0].x) / (r * r * r); + return ax; + })`, + dynamicsContext(), + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.sympyCode).toContain("gravitational_constant"); + expect(result.sympyCode).toContain("Space_0_x"); + } + }); + + it("should compile transition kernel with object literal", () => { + const result = compileToSymPy( + `export default TransitionKernel((tokens) => { + return { + x: tokens.Space[0].x, + y: tokens.Space[0].y, + velocity: 0, + direction: 0 + }; + })`, + kernelContext(), + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.sympyCode).toContain("'x': Space_0_x"); + expect(result.sympyCode).toContain("'y': Space_0_y"); + expect(result.sympyCode).toContain("'velocity': 0"); + } + }); + + it("should compile transition kernel with array of objects", () => { + const result = compileToSymPy( + `export default TransitionKernel((tokens) => { + return { + Debris: [ + { + x: tokens.Space[0].x, + y: tokens.Space[0].y, + velocity: 0, + direction: 0 + }, + { + x: tokens.Space[1].x, + y: tokens.Space[1].y, + velocity: 0, + direction: 0 + }, + ] + }; + })`, + kernelContext(), + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.sympyCode).toContain("'Debris': ["); + expect(result.sympyCode).toContain("'x': Space_0_x"); + expect(result.sympyCode).toContain("'x': Space_1_x"); + } + }); + + it("should compile simple array literal", () => { + const result = compileToSymPy( + "export default Lambda(() => [1, 2, 3])", + defaultContext, + ); + expect(result).toEqual({ ok: true, sympyCode: "[1, 2, 3]" }); + }); + }); + + describe(".map() list comprehension", () => { + it("should compile tokens.map with destructured params", () => { + const result = compileToSymPy( + `export default Dynamics((tokens, parameters) => { + const mu = parameters.gravitational_constant; + return tokens.map(({ x, y, direction, velocity }) => { + const r = Math.hypot(x, y); + const ax = (-mu * x) / (r * r * r); + const ay = (-mu * y) / (r * r * r); + return { + x: velocity * Math.cos(direction), + y: velocity * Math.sin(direction), + direction: (-ax * Math.sin(direction) + ay * Math.cos(direction)) / velocity, + velocity: ax * Math.cos(direction) + ay * Math.sin(direction), + }; + }); + })`, + dynamicsContext(), + ); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.sympyCode).toContain("for _iter in tokens"); + expect(result.sympyCode).toContain("_iter_x"); + expect(result.sympyCode).toContain("_iter_velocity"); + expect(result.sympyCode).toContain("sp.cos(_iter_direction)"); + } + }); + + it("should compile simple .map with identifier param", () => { + const result = compileToSymPy( + `export default Lambda((tokens) => tokens.map((token) => token + 1))`, + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "[_iter + 1 for _iter in tokens]", + }); + }); + + it("should compile .map with expression body", () => { + const result = compileToSymPy( + `export default Lambda((tokens, parameters) => tokens.map(({ x }) => x * parameters.infection_rate))`, + defaultContext, + ); + expect(result).toEqual({ + ok: true, + sympyCode: "[_iter_x * infection_rate for _iter in tokens]", + }); + }); + }); + + describe("error handling", () => { + it("should reject code without default export", () => { + const result = compileToSymPy( + "const x = Lambda(() => 1);", + defaultContext, + ); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("No default export"); + expect(result.start).toBe(0); + expect(result.length).toBe(0); + } + }); + + it("should reject wrong constructor function name with position", () => { + const code = "export default WrongName(() => 1)"; + const result = compileToSymPy(code, defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("Expected Lambda(...)"); + // "WrongName" starts at position 15 + expect(result.start).toBe(code.indexOf("WrongName")); + expect(result.length).toBe("WrongName".length); + } + }); + + it("should reject unsupported Math function with position", () => { + const code = "export default Lambda(() => Math.random())"; + const result = compileToSymPy(code, defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("Unsupported Math function"); + // Points to the "Math.random" callee + expect(result.start).toBe(code.indexOf("Math.random")); + expect(result.length).toBe("Math.random".length); + } + }); + + it("should reject if statements in block body with position", () => { + const code = `export default Lambda((tokens, parameters) => { + if (parameters.infection_rate > 1) { + return parameters.infection_rate; + } + return 0; + })`; + const result = compileToSymPy(code, defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("Unsupported statement"); + expect(result.start).toBe(code.indexOf("if")); + expect(result.length).toBeGreaterThan(0); + } + }); + + it("should reject let declarations with position", () => { + const code = `export default Lambda(() => { + let x = 1; + return x; + })`; + const result = compileToSymPy(code, defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("let"); + expect(result.start).toBe(code.indexOf("let x = 1;")); + expect(result.length).toBe("let x = 1;".length); + } + }); + + it("should reject var declarations", () => { + const code = `export default Lambda(() => { + var x = 1; + return x; + })`; + const result = compileToSymPy(code, defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("var"); + expect(result.error).toContain("use 'const'"); + } + }); + + it("should reject string literals with position", () => { + const code = `export default Lambda(() => "hello")`; + const result = compileToSymPy(code, defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("String literals"); + expect(result.start).toBe(code.indexOf('"hello"')); + expect(result.length).toBe('"hello"'.length); + } + }); + + it("should reject unsupported function calls with position", () => { + const code = `export default Lambda(() => console.log(1))`; + const result = compileToSymPy(code, defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("Unsupported function call"); + expect(result.start).toBe(code.indexOf("console.log(1)")); + expect(result.length).toBe("console.log(1)".length); + } + }); + + it("should reject standalone expression statements", () => { + const code = `export default Lambda((tokensByPlace, parameters) => { + const a = Boolean(1 + 2); + Boolean(1 + 2); + return Boolean(1 + 2); + })`; + const result = compileToSymPy(code, defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("Standalone expression has no effect"); + // The standalone expression is the second line in the block + const standalonePos = code.indexOf("\n Boolean(1 + 2);") + 11; + expect(result.start).toBe(standalonePos); + expect(result.length).toBe("Boolean(1 + 2);".length); + } + }); + + it("should reject unsupported binary operator with position", () => { + const code = `export default Lambda(() => 1 << 2)`; + const result = compileToSymPy(code, defaultContext); + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain("Unsupported binary operator"); + expect(result.start).toBe(code.indexOf("<<")); + expect(result.length).toBe("<<".length); + } + }); + }); +}); diff --git a/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts new file mode 100644 index 00000000000..7d00e5a34ba --- /dev/null +++ b/libs/@hashintel/petrinaut/src/simulation/simulator/compile-to-sympy.ts @@ -0,0 +1,893 @@ +import ts from "typescript"; + +import type { SDCPN, Transition } from "../../core/types/sdcpn"; + +/** + * Context for SymPy compilation, derived from the SDCPN model. + * Tells the compiler which identifiers are parameters vs. token fields. + */ +export type SymPyCompilationContext = { + parameterNames: Set; + /** Maps place name to its token field names */ + placeTokenFields: Map; + constructorFnName: string; +}; + +/** + * Builds a SymPyCompilationContext from an SDCPN model for a given transition. + */ +export function buildContextForTransition( + sdcpn: SDCPN, + transition: Transition, + constructorFnName: string, +): SymPyCompilationContext { + const parameterNames = new Set( + sdcpn.parameters.map((param) => param.variableName), + ); + const placeTokenFields = new Map(); + + const placeById = new Map(sdcpn.places.map((pl) => [pl.id, pl])); + const colorById = new Map(sdcpn.types.map((ct) => [ct.id, ct])); + + for (const arc of transition.inputArcs) { + const place = placeById.get(arc.placeId); + if (!place?.colorId) { + continue; + } + const color = colorById.get(place.colorId); + if (!color) { + continue; + } + placeTokenFields.set( + place.name, + color.elements.map((el) => el.name), + ); + } + + return { parameterNames, placeTokenFields, constructorFnName }; +} + +/** + * Builds a SymPyCompilationContext from an SDCPN model for a differential equation. + */ +export function buildContextForDifferentialEquation( + sdcpn: SDCPN, + colorId: string, +): SymPyCompilationContext { + const parameterNames = new Set( + sdcpn.parameters.map((param) => param.variableName), + ); + const placeTokenFields = new Map(); + + const color = sdcpn.types.find((ct) => ct.id === colorId); + if (color) { + // DE operates on tokens of its color type + placeTokenFields.set( + color.name, + color.elements.map((el) => el.name), + ); + } + + return { parameterNames, placeTokenFields, constructorFnName: "Dynamics" }; +} + +export type SymPyResult = + | { ok: true; sympyCode: string } + | { ok: false; error: string; start: number; length: number }; + +/** Shorthand for building an error result with position from a TS AST node. */ +function err( + error: string, + node: ts.Node, + sourceFile: ts.SourceFile, +): SymPyResult & { ok: false } { + return { + ok: false, + error, + start: node.getStart(sourceFile), + length: node.getWidth(sourceFile), + }; +} + +/** Error result for cases where no specific node is available. */ +function errNoPos(error: string): SymPyResult & { ok: false } { + return { ok: false, error, start: 0, length: 0 }; +} + +/** + * Compiles a Petrinaut TypeScript expression to SymPy Python code. + * + * Expects code following the pattern: + * `export default ConstructorFn((params...) => expression)` + * + * Only a restricted subset of TypeScript is supported — pure expressions + * with arithmetic, Math functions, parameter/token access, and distributions. + * Anything outside this subset is rejected with a diagnostic. + * + * @param code - The TypeScript expression code string + * @param context - Compilation context with parameter names and token fields + * @returns Either `{ ok: true, sympyCode }` or `{ ok: false, error }` + */ +export function compileToSymPy( + code: string, + context: SymPyCompilationContext, +): SymPyResult { + const sourceFile = ts.createSourceFile( + "input.ts", + code, + ts.ScriptTarget.ES2015, + true, + ); + + // Find the default export + const exportAssignment = sourceFile.statements.find( + (stmt): stmt is ts.ExportAssignment => + ts.isExportAssignment(stmt) && !stmt.isExportEquals, + ); + + if (!exportAssignment) { + // Try export default as ExpressionStatement pattern + const exportDefault = sourceFile.statements.find((stmt) => { + if (ts.isExportAssignment(stmt)) { + return true; + } + // Handle "export default X(...)" which parses as ExportAssignment + return false; + }); + if (!exportDefault) { + return errNoPos("No default export found"); + } + } + + const exportExpr = exportAssignment!.expression; + + // Expect ConstructorFn(...) + if (!ts.isCallExpression(exportExpr)) { + return err( + `Expected ${context.constructorFnName}(...), got ${ts.SyntaxKind[exportExpr.kind]}`, + exportExpr, + sourceFile, + ); + } + + const callee = exportExpr.expression; + if (!ts.isIdentifier(callee) || callee.text !== context.constructorFnName) { + return err( + `Expected ${context.constructorFnName}(...), got ${callee.getText(sourceFile)}(...)`, + callee, + sourceFile, + ); + } + + if (exportExpr.arguments.length !== 1) { + return err( + `${context.constructorFnName} expects exactly one argument`, + exportExpr, + sourceFile, + ); + } + + const arg = exportExpr.arguments[0]!; + + // The argument should be an arrow function or function expression + if (!ts.isArrowFunction(arg) && !ts.isFunctionExpression(arg)) { + return err( + `Expected a function argument, got ${ts.SyntaxKind[arg.kind]}`, + arg, + sourceFile, + ); + } + + // Extract parameter names for the inner function + const localBindings = new Map(); + const innerParams = extractFunctionParams(arg, sourceFile); + + // Compile the body + const body = arg.body; + + if (ts.isBlock(body)) { + return compileBlock(body, context, localBindings, sourceFile); + } + + // Expression body — emit directly + const result = emitSymPy( + body, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!result.ok) return result; + return { ok: true, sympyCode: result.sympyCode }; +} + +function extractFunctionParams( + fn: ts.ArrowFunction | ts.FunctionExpression, + sourceFile: ts.SourceFile, +): string[] { + return fn.parameters.map((p) => p.name.getText(sourceFile)); +} + +function compileBlock( + block: ts.Block, + context: SymPyCompilationContext, + localBindings: Map, + sourceFile: ts.SourceFile, +): SymPyResult { + const lines: string[] = []; + + for (const stmt of block.statements) { + if (ts.isVariableStatement(stmt)) { + for (const decl of stmt.declarationList.declarations) { + if (!decl.initializer) { + return err( + "Variable declaration without initializer", + decl, + sourceFile, + ); + } + if (!(stmt.declarationList.flags & ts.NodeFlags.Const)) { + return err( + "'let' and 'var' declarations are not supported, use 'const'", + stmt, + sourceFile, + ); + } + const name = decl.name.getText(sourceFile); + const valueResult = emitSymPy( + decl.initializer, + context, + localBindings, + [], + sourceFile, + ); + if (!valueResult.ok) return valueResult; + localBindings.set(name, valueResult.sympyCode); + lines.push(`${name} = ${valueResult.sympyCode}`); + } + } else if (ts.isReturnStatement(stmt)) { + if (!stmt.expression) { + return err("Empty return statement", stmt, sourceFile); + } + const result = emitSymPy( + stmt.expression, + context, + localBindings, + [], + sourceFile, + ); + if (!result.ok) return result; + lines.push(result.sympyCode); + } else if (ts.isExpressionStatement(stmt)) { + return err( + "Standalone expression has no effect — assign to a const or return it", + stmt, + sourceFile, + ); + } else { + return err( + `Unsupported statement: ${ts.SyntaxKind[stmt.kind]}`, + stmt, + sourceFile, + ); + } + } + + if (lines.length === 0) { + return err("Empty function body", block, sourceFile); + } + + return { ok: true, sympyCode: lines[lines.length - 1]! }; +} + +/** + * Compiles `collection.map(callback)` to a Python list comprehension. + * + * Handles two callback parameter styles: + * - Destructured: `({ x, y }) => ...` → binds each field as `_iter_x`, `_iter_y` + * - Simple identifier: `(token) => ...` → binds as-is + * + * Emits: `[ for _iter in ]` + */ +function compileMapCall( + collection: ts.Expression, + callback: ts.ArrowFunction | ts.FunctionExpression, + context: SymPyCompilationContext, + outerBindings: Map, + innerParams: string[], + sourceFile: ts.SourceFile, +): SymPyResult { + const iterVar = "_iter"; + const mapBindings = new Map(outerBindings); + + const param = callback.parameters[0]; + if (param) { + const paramName = param.name; + if (ts.isObjectBindingPattern(paramName)) { + // Destructured: ({ x, y, ... }) => ... + // Each field becomes a symbol like _iter_x, _iter_y + for (const element of paramName.elements) { + const fieldName = element.name.getText(sourceFile); + mapBindings.set(fieldName, `${iterVar}_${fieldName}`); + } + } else { + // Simple identifier: (token) => ... + mapBindings.set(paramName.getText(sourceFile), iterVar); + } + } + + // Compile the body + const body = callback.body; + let bodyResult: SymPyResult; + if (ts.isBlock(body)) { + bodyResult = compileBlock(body, context, mapBindings, sourceFile); + } else { + bodyResult = emitSymPy(body, context, mapBindings, innerParams, sourceFile); + } + if (!bodyResult.ok) return bodyResult; + + // Compile the collection expression + const collectionResult = emitSymPy( + collection, + context, + outerBindings, + innerParams, + sourceFile, + ); + if (!collectionResult.ok) return collectionResult; + + return { + ok: true, + sympyCode: `[${bodyResult.sympyCode} for ${iterVar} in ${collectionResult.sympyCode}]`, + }; +} + +const MATH_FUNCTION_MAP: Record = { + cos: "sp.cos", + sin: "sp.sin", + tan: "sp.tan", + acos: "sp.acos", + asin: "sp.asin", + atan: "sp.atan", + atan2: "sp.atan2", + sqrt: "sp.sqrt", + log: "sp.log", + exp: "sp.exp", + abs: "sp.Abs", + floor: "sp.floor", + ceil: "sp.ceiling", + pow: "sp.Pow", + min: "sp.Min", + max: "sp.Max", +}; + +const MATH_CONSTANT_MAP: Record = { + PI: "sp.pi", + E: "sp.E", + Infinity: "sp.oo", +}; + +function emitSymPy( + node: ts.Node, + context: SymPyCompilationContext, + localBindings: Map, + innerParams: string[], + sourceFile: ts.SourceFile, +): SymPyResult { + // Numeric literal + if (ts.isNumericLiteral(node)) { + return { ok: true, sympyCode: node.text }; + } + + // String literal — not supported in symbolic math + if (ts.isStringLiteral(node)) { + return err( + "String literals are not supported in symbolic expressions", + node, + sourceFile, + ); + } + + // Boolean literals + if (node.kind === ts.SyntaxKind.TrueKeyword) { + return { ok: true, sympyCode: "True" }; + } + if (node.kind === ts.SyntaxKind.FalseKeyword) { + return { ok: true, sympyCode: "False" }; + } + + // Identifier + if (ts.isIdentifier(node)) { + const name = node.text; + if (name === "Infinity") return { ok: true, sympyCode: "sp.oo" }; + if (localBindings.has(name)) { + return { ok: true, sympyCode: localBindings.get(name)! }; + } + if (context.parameterNames.has(name)) { + return { ok: true, sympyCode: name }; + } + // Could be a destructured token field or function param + return { ok: true, sympyCode: name }; + } + + // Parenthesized expression + if (ts.isParenthesizedExpression(node)) { + const inner = emitSymPy( + node.expression, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!inner.ok) return inner; + return { ok: true, sympyCode: `(${inner.sympyCode})` }; + } + + // Prefix unary expression (-x, !x) + if (ts.isPrefixUnaryExpression(node)) { + const operand = emitSymPy( + node.operand, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!operand.ok) return operand; + + switch (node.operator) { + case ts.SyntaxKind.MinusToken: + return { ok: true, sympyCode: `-(${operand.sympyCode})` }; + case ts.SyntaxKind.ExclamationToken: + return { ok: true, sympyCode: `sp.Not(${operand.sympyCode})` }; + case ts.SyntaxKind.PlusToken: + return operand; + default: + return err( + `Unsupported prefix operator: ${ts.SyntaxKind[node.operator]}`, + node, + sourceFile, + ); + } + } + + // Binary expression + if (ts.isBinaryExpression(node)) { + const left = emitSymPy( + node.left, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!left.ok) return left; + const right = emitSymPy( + node.right, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!right.ok) return right; + + switch (node.operatorToken.kind) { + case ts.SyntaxKind.PlusToken: + return { + ok: true, + sympyCode: `${left.sympyCode} + ${right.sympyCode}`, + }; + case ts.SyntaxKind.MinusToken: + return { + ok: true, + sympyCode: `${left.sympyCode} - ${right.sympyCode}`, + }; + case ts.SyntaxKind.AsteriskToken: + return { + ok: true, + sympyCode: `${left.sympyCode} * ${right.sympyCode}`, + }; + case ts.SyntaxKind.SlashToken: + return { + ok: true, + sympyCode: `${left.sympyCode} / ${right.sympyCode}`, + }; + case ts.SyntaxKind.AsteriskAsteriskToken: + return { + ok: true, + sympyCode: `${left.sympyCode}**${right.sympyCode}`, + }; + case ts.SyntaxKind.PercentToken: + return { + ok: true, + sympyCode: `sp.Mod(${left.sympyCode}, ${right.sympyCode})`, + }; + case ts.SyntaxKind.LessThanToken: + return { + ok: true, + sympyCode: `${left.sympyCode} < ${right.sympyCode}`, + }; + case ts.SyntaxKind.LessThanEqualsToken: + return { + ok: true, + sympyCode: `${left.sympyCode} <= ${right.sympyCode}`, + }; + case ts.SyntaxKind.GreaterThanToken: + return { + ok: true, + sympyCode: `${left.sympyCode} > ${right.sympyCode}`, + }; + case ts.SyntaxKind.GreaterThanEqualsToken: + return { + ok: true, + sympyCode: `${left.sympyCode} >= ${right.sympyCode}`, + }; + case ts.SyntaxKind.EqualsEqualsToken: + case ts.SyntaxKind.EqualsEqualsEqualsToken: + return { + ok: true, + sympyCode: `sp.Eq(${left.sympyCode}, ${right.sympyCode})`, + }; + case ts.SyntaxKind.ExclamationEqualsToken: + case ts.SyntaxKind.ExclamationEqualsEqualsToken: + return { + ok: true, + sympyCode: `sp.Ne(${left.sympyCode}, ${right.sympyCode})`, + }; + case ts.SyntaxKind.AmpersandAmpersandToken: + return { + ok: true, + sympyCode: `sp.And(${left.sympyCode}, ${right.sympyCode})`, + }; + case ts.SyntaxKind.BarBarToken: + return { + ok: true, + sympyCode: `sp.Or(${left.sympyCode}, ${right.sympyCode})`, + }; + default: + return err( + `Unsupported binary operator: ${node.operatorToken.getText(sourceFile)}`, + node.operatorToken, + sourceFile, + ); + } + } + + // Conditional (ternary) expression + if (ts.isConditionalExpression(node)) { + const condition = emitSymPy( + node.condition, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!condition.ok) return condition; + const whenTrue = emitSymPy( + node.whenTrue, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!whenTrue.ok) return whenTrue; + const whenFalse = emitSymPy( + node.whenFalse, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!whenFalse.ok) return whenFalse; + return { + ok: true, + sympyCode: `sp.Piecewise((${whenTrue.sympyCode}, ${condition.sympyCode}), (${whenFalse.sympyCode}, True))`, + }; + } + + // Property access: parameters.x, tokens.Place[0].field, Math.PI + if (ts.isPropertyAccessExpression(node)) { + const propName = node.name.text; + + // Math constants: Math.PI, Math.E + if (ts.isIdentifier(node.expression) && node.expression.text === "Math") { + const constant = MATH_CONSTANT_MAP[propName]; + if (constant) return { ok: true, sympyCode: constant }; + // Math.method will be handled as part of a CallExpression + // Return a placeholder that the call expression handler will use + return { ok: true, sympyCode: `Math.${propName}` }; + } + + // parameters.x + if ( + ts.isIdentifier(node.expression) && + node.expression.text === "parameters" + ) { + return { ok: true, sympyCode: propName }; + } + + // tokens.Place[0].field — handle the chain + // First check: something.field where something is an element access + if (ts.isElementAccessExpression(node.expression)) { + // e.g., tokens.Space[0].x + const elemAccess = node.expression; + if (ts.isPropertyAccessExpression(elemAccess.expression)) { + const placePropAccess = elemAccess.expression; + if ( + ts.isIdentifier(placePropAccess.expression) && + placePropAccess.expression.text === "tokens" + ) { + const placeName = placePropAccess.name.text; + const indexExpr = elemAccess.argumentExpression; + const indexText = indexExpr.getText(sourceFile); + return { + ok: true, + sympyCode: `${placeName}_${indexText}_${propName}`, + }; + } + } + } + + // Generic property access — emit as dot access + const obj = emitSymPy( + node.expression, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!obj.ok) return obj; + return { ok: true, sympyCode: `${obj.sympyCode}_${propName}` }; + } + + // Element access: tokens.Place[0], arr[i] + if (ts.isElementAccessExpression(node)) { + const obj = emitSymPy( + node.expression, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!obj.ok) return obj; + const index = emitSymPy( + node.argumentExpression, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!index.ok) return index; + return { ok: true, sympyCode: `${obj.sympyCode}_${index.sympyCode}` }; + } + + // Call expression: Math.cos(x), Math.hypot(a, b), Distribution.Gaussian(m, s) + if (ts.isCallExpression(node)) { + const callee = node.expression; + + // Math.fn(...) + if ( + ts.isPropertyAccessExpression(callee) && + ts.isIdentifier(callee.expression) && + callee.expression.text === "Math" + ) { + const fnName = callee.name.text; + + // Special case: Math.hypot(a, b) -> sp.sqrt(a**2 + b**2) + if (fnName === "hypot") { + const args: string[] = []; + for (const a of node.arguments) { + const r = emitSymPy( + a, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!r.ok) return r; + args.push(r.sympyCode); + } + const sumOfSquares = args.map((a) => `(${a})**2`).join(" + "); + return { ok: true, sympyCode: `sp.sqrt(${sumOfSquares})` }; + } + + // Special case: Math.pow(a, b) -> a**b + if (fnName === "pow" && node.arguments.length === 2) { + const base = emitSymPy( + node.arguments[0]!, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!base.ok) return base; + const exp = emitSymPy( + node.arguments[1]!, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!exp.ok) return exp; + return { + ok: true, + sympyCode: `(${base.sympyCode})**(${exp.sympyCode})`, + }; + } + + const sympyFn = MATH_FUNCTION_MAP[fnName]; + if (!sympyFn) { + return err( + `Unsupported Math function: Math.${fnName}`, + callee, + sourceFile, + ); + } + + const args: string[] = []; + for (const a of node.arguments) { + const r = emitSymPy(a, context, localBindings, innerParams, sourceFile); + if (!r.ok) return r; + args.push(r.sympyCode); + } + return { ok: true, sympyCode: `${sympyFn}(${args.join(", ")})` }; + } + + // Distribution.Gaussian(m, s), Distribution.Uniform(a, b), Distribution.Lognormal(mu, sigma) + if ( + ts.isPropertyAccessExpression(callee) && + ts.isIdentifier(callee.expression) && + callee.expression.text === "Distribution" + ) { + const distName = callee.name.text; + const args: string[] = []; + for (const a of node.arguments) { + const r = emitSymPy(a, context, localBindings, innerParams, sourceFile); + if (!r.ok) return r; + args.push(r.sympyCode); + } + + switch (distName) { + case "Gaussian": + return { + ok: true, + sympyCode: `sp.stats.Normal('X', ${args.join(", ")})`, + }; + case "Uniform": + return { + ok: true, + sympyCode: `sp.stats.Uniform('X', ${args.join(", ")})`, + }; + case "Lognormal": + return { + ok: true, + sympyCode: `sp.stats.LogNormal('X', ${args.join(", ")})`, + }; + default: + return err( + `Unsupported distribution: Distribution.${distName}`, + callee, + sourceFile, + ); + } + } + + // Global built-in functions: Boolean(expr), Number(expr) + if (ts.isIdentifier(callee)) { + if (callee.text === "Boolean" && node.arguments.length === 1) { + const arg = emitSymPy( + node.arguments[0]!, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!arg.ok) return arg; + return { ok: true, sympyCode: `sp.Ne(${arg.sympyCode}, 0)` }; + } + + if (callee.text === "Number" && node.arguments.length === 1) { + return emitSymPy( + node.arguments[0]!, + context, + localBindings, + innerParams, + sourceFile, + ); + } + } + + // .map(callback) on arrays/tokens — emit as Python list comprehension + if ( + ts.isPropertyAccessExpression(callee) && + callee.name.text === "map" && + node.arguments.length === 1 + ) { + const callback = node.arguments[0]!; + if (ts.isArrowFunction(callback) || ts.isFunctionExpression(callback)) { + return compileMapCall( + callee.expression, + callback, + context, + localBindings, + innerParams, + sourceFile, + ); + } + } + + return err( + `Unsupported function call: ${callee.getText(sourceFile)}`, + node, + sourceFile, + ); + } + + // Array literal expression [a, b, c] + if (ts.isArrayLiteralExpression(node)) { + const elements: string[] = []; + for (const elem of node.elements) { + const result = emitSymPy( + elem, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!result.ok) return result; + elements.push(result.sympyCode); + } + return { ok: true, sympyCode: `[${elements.join(", ")}]` }; + } + + // Object literal expression { field: expr, ... } + if (ts.isObjectLiteralExpression(node)) { + const entries: string[] = []; + for (const prop of node.properties) { + if (!ts.isPropertyAssignment(prop)) { + return err( + `Unsupported object property kind: ${ts.SyntaxKind[prop.kind]}`, + prop, + sourceFile, + ); + } + const key = prop.name.getText(sourceFile); + const val = emitSymPy( + prop.initializer, + context, + localBindings, + innerParams, + sourceFile, + ); + if (!val.ok) return val; + entries.push(`'${key}': ${val.sympyCode}`); + } + return { ok: true, sympyCode: `{${entries.join(", ")}}` }; + } + + // Non-null assertion (x!) — just unwrap + if (ts.isNonNullExpression(node)) { + return emitSymPy( + node.expression, + context, + localBindings, + innerParams, + sourceFile, + ); + } + + // Type assertion (x as T) — just unwrap + if (ts.isAsExpression(node)) { + return emitSymPy( + node.expression, + context, + localBindings, + innerParams, + sourceFile, + ); + } + + return err( + `Unsupported syntax: ${ts.SyntaxKind[node.kind]}`, + node, + sourceFile, + ); +} diff --git a/libs/@hashintel/petrinaut/src/views/Editor/editor-view.tsx b/libs/@hashintel/petrinaut/src/views/Editor/editor-view.tsx index 0fe4dc934c0..c95a1bdaa88 100644 --- a/libs/@hashintel/petrinaut/src/views/Editor/editor-view.tsx +++ b/libs/@hashintel/petrinaut/src/views/Editor/editor-view.tsx @@ -18,6 +18,7 @@ import { SDCPNView } from "../SDCPN/sdcpn-view"; import { BottomBar } from "./components/BottomBar/bottom-bar"; import { TopBar } from "./components/TopBar/top-bar"; import { exportSDCPN } from "./lib/export-sdcpn"; +import { exportWithSymPy } from "./lib/export-sympy"; import { exportTikZ } from "./lib/export-tikz"; import { importSDCPN } from "./lib/import-sdcpn"; import { BottomPanel } from "./panels/BottomPanel/panel"; @@ -110,6 +111,10 @@ export const EditorView = ({ exportTikZ({ petriNetDefinition, title }); } + function handleExportWithSymPy() { + exportWithSymPy({ petriNetDefinition, title }); + } + function handleImport() { importSDCPN((loadedSDCPN) => { const convertedSdcpn = convertOldFormatToSDCPN(loadedSDCPN); @@ -163,6 +168,11 @@ export const EditorView = ({ label: "TikZ", onClick: handleExportTikZ, }, + { + id: "export-sympy", + label: "JSON with SymPy expressions", + onClick: handleExportWithSymPy, + }, ], }, { diff --git a/libs/@hashintel/petrinaut/src/views/Editor/lib/export-sympy.ts b/libs/@hashintel/petrinaut/src/views/Editor/lib/export-sympy.ts new file mode 100644 index 00000000000..281549292d3 --- /dev/null +++ b/libs/@hashintel/petrinaut/src/views/Editor/lib/export-sympy.ts @@ -0,0 +1,93 @@ +import type { SDCPN } from "../../../core/types/sdcpn"; +import { + buildContextForDifferentialEquation, + buildContextForTransition, + compileToSymPy, +} from "../../../simulation/simulator/compile-to-sympy"; + +type SymPyExpression = { + name: string; + type: string; + sympyCode: string | null; + error: string | null; +}; + +/** + * Converts all expressions in an SDCPN model to SymPy and produces a JSON + * export containing both the original model and the SymPy representations. + */ +export function exportWithSymPy({ + petriNetDefinition, + title, +}: { + petriNetDefinition: SDCPN; + title: string; +}): void { + const expressions: SymPyExpression[] = []; + + // Convert differential equation expressions + for (const de of petriNetDefinition.differentialEquations) { + const ctx = buildContextForDifferentialEquation( + petriNetDefinition, + de.colorId, + ); + const result = compileToSymPy(de.code, ctx); + expressions.push({ + name: de.name, + type: "differential-equation", + sympyCode: result.ok ? result.sympyCode : null, + error: result.ok ? null : result.error, + }); + } + + // Convert transition lambda and kernel expressions + for (const transition of petriNetDefinition.transitions) { + const lambdaCtx = buildContextForTransition( + petriNetDefinition, + transition, + "Lambda", + ); + const lambdaResult = compileToSymPy(transition.lambdaCode, lambdaCtx); + expressions.push({ + name: `${transition.name} (lambda)`, + type: "transition-lambda", + sympyCode: lambdaResult.ok ? lambdaResult.sympyCode : null, + error: lambdaResult.ok ? null : lambdaResult.error, + }); + + const kernelCtx = buildContextForTransition( + petriNetDefinition, + transition, + "TransitionKernel", + ); + const kernelResult = compileToSymPy( + transition.transitionKernelCode, + kernelCtx, + ); + expressions.push({ + name: `${transition.name} (kernel)`, + type: "transition-kernel", + sympyCode: kernelResult.ok ? kernelResult.sympyCode : null, + error: kernelResult.ok ? null : kernelResult.error, + }); + } + + const exportData = { + title, + sympy_expressions: expressions, + ...petriNetDefinition, + }; + + const jsonString = JSON.stringify(exportData, null, 2); + const blob = new Blob([jsonString], { type: "application/json" }); + const url = URL.createObjectURL(blob); + const link = document.createElement("a"); + link.href = url; + link.download = `${title.replace(/[^a-z0-9]/gi, "_").toLowerCase()}_sympy_${new Date().toISOString().replace(/:/g, "-")}.json`; + + document.body.appendChild(link); + link.click(); + + document.body.removeChild(link); + URL.revokeObjectURL(url); +}