From 4f13d8a90b04bb6f37aa14625db44822e9edad33 Mon Sep 17 00:00:00 2001 From: Fabio Madge Date: Mon, 9 Mar 2026 05:24:47 +0100 Subject: [PATCH 01/11] Laurel: add constrained type support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A Laurel-to-Laurel elimination pass (ConstrainedTypeElim.lean) that: - Adds requires for constrained-typed inputs - Adds ensures for constrained-typed outputs - Clears isFunctional when adding ensures (function postconditions not yet supported) - Inserts assert for local variable init and reassignment - Uses witness as default initializer for uninitialized constrained variables - Validates witnesses via synthetic procedures - Injects constraints into quantifier bodies (forall → implies, exists → and) - Resolves all constrained type references to base types - Handles capture avoidance in identifier substitution Core's call elimination handles caller-side argument asserts and return value assumes automatically via requires/ensures. Grammar: constrained type syntax Parser: parseConstrainedType + topLevelConstrainedType Test: T09_ConstrainedTypes — 25 test procedures --- .../Languages/Laurel/ConstrainedTypeElim.lean | 275 ++++++++++++++++++ .../ConcreteToAbstractTreeTranslator.lean | 19 +- .../Languages/Laurel/Grammar/LaurelGrammar.st | 6 + .../Laurel/LaurelToCoreTranslator.lean | 8 +- .../Laurel/ConstrainedTypeElimTest.lean | 61 ++++ .../Fundamentals/T09_ConstrainedTypes.lean | 168 +++++++++++ 6 files changed, 535 insertions(+), 2 deletions(-) create mode 100644 Strata/Languages/Laurel/ConstrainedTypeElim.lean create mode 100644 StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean create mode 100644 StrataTest/Languages/Laurel/Examples/Fundamentals/T09_ConstrainedTypes.lean diff --git a/Strata/Languages/Laurel/ConstrainedTypeElim.lean b/Strata/Languages/Laurel/ConstrainedTypeElim.lean new file mode 100644 index 000000000..6c00f6644 --- /dev/null +++ b/Strata/Languages/Laurel/ConstrainedTypeElim.lean @@ -0,0 +1,275 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +import Strata.Languages.Laurel.Laurel +import Strata.Languages.Laurel.Resolution + +/-! +# Constrained Type Elimination + +A Laurel-to-Laurel pass that eliminates constrained types by: +1. Adding `requires` for constrained-typed inputs (Core handles caller asserts and body assumes) +2. Adding `ensures` for constrained-typed outputs (Core handles body checks and caller assumes) + - Skipped for `isFunctional` procedures since the Laurel translator does not yet support + function postconditions. Constrained return types on functions are not checked. +3. Inserting `assert` for local variable init and reassignment of constrained-typed variables +4. Using the witness as default initializer for uninitialized constrained-typed variables +5. Adding a synthetic witness-validation procedure per constrained type +6. Injecting constraints into quantifier bodies (`forall` → `implies`, `exists` → `and`) +7. Resolving all constrained type references to their base types +-/ + +namespace Strata.Laurel + +open Strata + +abbrev ConstrainedTypeMap := Std.HashMap String ConstrainedType +/-- Map from variable name to its constrained HighType (e.g. UserDefined "nat") -/ +abbrev PredVarMap := Std.HashMap String HighType + +def buildConstrainedTypeMap (types : List TypeDefinition) : ConstrainedTypeMap := + types.foldl (init := {}) fun m td => + match td with | .Constrained ct => m.insert ct.name.text ct | _ => m + +partial def resolveBaseType (ptMap : ConstrainedTypeMap) (ty : HighType) : HighType := + match ty with + | .UserDefined name => match ptMap.get? name.text with + | some ct => resolveBaseType ptMap ct.base.val | none => ty + | .Applied ctor args => + .Applied ctor (args.map fun a => ⟨resolveBaseType ptMap a.val, a.md⟩) + | _ => ty + +def resolveType (ptMap : ConstrainedTypeMap) (ty : HighTypeMd) : HighTypeMd := + ⟨resolveBaseType ptMap ty.val, ty.md⟩ + +def isConstrainedType (ptMap : ConstrainedTypeMap) (ty : HighType) : Bool := + match ty with | .UserDefined name => ptMap.contains name.text | _ => false + +/-- All predicates for a type transitively (e.g. evenpos → [(x, x > 0), (x, x % 2 == 0)]) -/ +partial def getAllConstraints (ptMap : ConstrainedTypeMap) (ty : HighType) + : List (Identifier × StmtExprMd) := + match ty with + | .UserDefined name => match ptMap.get? name.text with + | some ct => (ct.valueName, ct.constraint) :: getAllConstraints ptMap ct.base.val + | none => [] + | _ => [] + +/-- Substitute `Identifier old` with `Identifier new` in a constraint expression -/ +partial def substId (old new : Identifier) : StmtExprMd → StmtExprMd + | ⟨.Identifier n, md⟩ => ⟨if n == old then .Identifier new else .Identifier n, md⟩ + | ⟨.PrimitiveOp op args, md⟩ => + ⟨.PrimitiveOp op (args.map fun a => substId old new a), md⟩ + | ⟨.StaticCall c args, md⟩ => + ⟨.StaticCall c (args.map fun a => substId old new a), md⟩ + | ⟨.IfThenElse c t (some el), md⟩ => + ⟨.IfThenElse (substId old new c) (substId old new t) (some (substId old new el)), md⟩ + | ⟨.IfThenElse c t none, md⟩ => + ⟨.IfThenElse (substId old new c) (substId old new t) none, md⟩ + | ⟨.Block ss sep, md⟩ => + ⟨.Block (ss.map fun s => substId old new s) sep, md⟩ + | ⟨.Forall param body, md⟩ => + if param.name == old then ⟨.Forall param body, md⟩ + else if param.name == new then + let fresh : Identifier := mkId (param.name.text ++ "$") + ⟨.Forall { param with name := fresh } (substId old new (substId param.name fresh body)), md⟩ + else ⟨.Forall param (substId old new body), md⟩ + | ⟨.Exists param body, md⟩ => + if param.name == old then ⟨.Exists param body, md⟩ + else if param.name == new then + let fresh : Identifier := mkId (param.name.text ++ "$") + ⟨.Exists { param with name := fresh } (substId old new (substId param.name fresh body)), md⟩ + else ⟨.Exists param (substId old new body), md⟩ + | e => e + +def mkAsserts (ptMap : ConstrainedTypeMap) (ty : HighType) (varName : Identifier) + (md : Imperative.MetaData Core.Expression) : List StmtExprMd := + (getAllConstraints ptMap ty).map fun (valueName, pred) => + ⟨.Assert (substId valueName varName pred), md⟩ + +private def wrap (stmts : List StmtExprMd) (md : Imperative.MetaData Core.Expression) + : StmtExprMd := + match stmts with | [s] => s | ss => ⟨.Block ss none, md⟩ + +/-- Inject constraints into a quantifier body for a constrained type -/ +private def injectQuantifierConstraint (ptMap : ConstrainedTypeMap) (ty : HighType) + (varName : Identifier) (body : StmtExprMd) (isForall : Bool) : StmtExprMd := + let constraints := getAllConstraints ptMap ty + match constraints with + | [] => body + | _ => + let preds := constraints.map fun (vn, pred) => substId vn varName pred + let conj := preds.tail.foldl (init := preds.head!) fun acc p => + ⟨.PrimitiveOp .And [acc, p], body.md⟩ + if isForall then ⟨.PrimitiveOp .Implies [conj, body], body.md⟩ + else ⟨.PrimitiveOp .And [conj, body], body.md⟩ + +/-- Resolve constrained types in all type positions of an expression -/ +def resolveExpr (ptMap : ConstrainedTypeMap) : StmtExprMd → StmtExprMd + | ⟨.LocalVariable n ty (some init), md⟩ => + ⟨.LocalVariable n (resolveType ptMap ty) (some (resolveExpr ptMap init)), md⟩ + | ⟨.LocalVariable n ty none, md⟩ => + ⟨.LocalVariable n (resolveType ptMap ty) none, md⟩ + | ⟨.Forall param body, md⟩ => + let body' := resolveExpr ptMap body + let param' := { param with type := resolveType ptMap param.type } + ⟨.Forall param' (injectQuantifierConstraint ptMap param.type.val param.name body' true), md⟩ + | ⟨.Exists param body, md⟩ => + let body' := resolveExpr ptMap body + let param' := { param with type := resolveType ptMap param.type } + ⟨.Exists param' (injectQuantifierConstraint ptMap param.type.val param.name body' false), md⟩ + | ⟨.AsType t ty, md⟩ => ⟨.AsType (resolveExpr ptMap t) (resolveType ptMap ty), md⟩ + | ⟨.IsType t ty, md⟩ => ⟨.IsType (resolveExpr ptMap t) (resolveType ptMap ty), md⟩ + | ⟨.PrimitiveOp op args, md⟩ => + ⟨.PrimitiveOp op (args.attach.map fun ⟨a, _⟩ => resolveExpr ptMap a), md⟩ + | ⟨.StaticCall c args, md⟩ => + ⟨.StaticCall c (args.attach.map fun ⟨a, _⟩ => resolveExpr ptMap a), md⟩ + | ⟨.Block ss sep, md⟩ => + ⟨.Block (ss.attach.map fun ⟨s, _⟩ => resolveExpr ptMap s) sep, md⟩ + | ⟨.IfThenElse c t (some el), md⟩ => + ⟨.IfThenElse (resolveExpr ptMap c) (resolveExpr ptMap t) (some (resolveExpr ptMap el)), md⟩ + | ⟨.IfThenElse c t none, md⟩ => + ⟨.IfThenElse (resolveExpr ptMap c) (resolveExpr ptMap t) none, md⟩ + | ⟨.While c inv dec body, md⟩ => + ⟨.While (resolveExpr ptMap c) (inv.attach.map fun ⟨i, _⟩ => resolveExpr ptMap i) + dec (resolveExpr ptMap body), md⟩ + | ⟨.Assign ts v, md⟩ => + ⟨.Assign (ts.attach.map fun ⟨t, _⟩ => resolveExpr ptMap t) (resolveExpr ptMap v), md⟩ + | ⟨.Return (some v), md⟩ => ⟨.Return (some (resolveExpr ptMap v)), md⟩ + | ⟨.Return none, md⟩ => ⟨.Return none, md⟩ + | ⟨.Assert c, md⟩ => ⟨.Assert (resolveExpr ptMap c), md⟩ + | ⟨.Assume c, md⟩ => ⟨.Assume (resolveExpr ptMap c), md⟩ + | e => e +termination_by e => sizeOf e +decreasing_by all_goals (have := WithMetadata.sizeOf_val_lt ‹_›; term_by_mem) + +/-- Insert asserts for constrained-typed variable init and reassignment -/ +abbrev ElimM := StateM PredVarMap + +def elimStmt (ptMap : ConstrainedTypeMap) + (stmt : StmtExprMd) : ElimM (List StmtExprMd) := do + let md := stmt.md + match _h : stmt.val with + | .LocalVariable name ty init => + let isPred := isConstrainedType ptMap ty.val + if isPred then modify fun pv => pv.insert name.text ty.val + let asserts := if isPred then mkAsserts ptMap ty.val name md else [] + -- Use witness as default initializer for uninitialized constrained variables + let init' := match init with + | none => match ty.val with + | .UserDefined n => (ptMap.get? n.text).map (·.witness) + | _ => none + | some _ => init + pure ([⟨.LocalVariable name ty init', md⟩] ++ asserts) + + -- Single-target only; multi-target assignments are not supported by the Laurel grammar + | .Assign [target] _ => match target.val with + | .Identifier name => do + match (← get).get? name.text with + | some ty => pure ([stmt] ++ mkAsserts ptMap ty name md) + | none => pure [stmt] + | _ => pure [stmt] + + | .Block stmts sep => + let stmtss ← stmts.mapM (elimStmt ptMap) + pure [⟨.Block stmtss.flatten sep, md⟩] + + | .IfThenElse cond thenBr (some elseBr) => + let thenSs ← elimStmt ptMap thenBr + let elseSs ← elimStmt ptMap elseBr + pure [⟨.IfThenElse cond (wrap thenSs md) (some (wrap elseSs md)), md⟩] + | .IfThenElse cond thenBr none => + let thenSs ← elimStmt ptMap thenBr + pure [⟨.IfThenElse cond (wrap thenSs md) none, md⟩] + + | .While cond inv dec body => + let bodySs ← elimStmt ptMap body + pure [⟨.While cond inv dec (wrap bodySs md), md⟩] + + | _ => pure [stmt] +termination_by sizeOf stmt +decreasing_by + all_goals simp_wf + all_goals (try have := WithMetadata.sizeOf_val_lt stmt) + all_goals (try term_by_mem) + all_goals omega + +def elimProc (ptMap : ConstrainedTypeMap) (proc : Procedure) : Procedure := + -- Add requires for constrained-typed inputs + let inputRequires := proc.inputs.flatMap fun p => + (getAllConstraints ptMap p.type.val).map fun (vn, pred) => + ⟨(substId vn p.name pred).val, p.type.md⟩ + -- Add ensures for constrained-typed outputs (skip for isFunctional — not yet supported) + let outputEnsures := if proc.isFunctional then [] else proc.outputs.flatMap fun p => + (getAllConstraints ptMap p.type.val).map fun (vn, pred) => + ⟨(substId vn p.name pred).val, p.type.md⟩ + -- Transform body: insert asserts for local variable init/reassignment + let initVars : PredVarMap := proc.inputs.foldl (init := {}) fun s p => + if isConstrainedType ptMap p.type.val then s.insert p.name.text p.type.val else s + let body' := match proc.body with + | .Transparent bodyExpr => + let (stmts, _) := (elimStmt ptMap bodyExpr).run initVars + let body := wrap stmts bodyExpr.md + if outputEnsures.isEmpty then .Transparent body + else + -- Wrap expression body in a Return so it translates correctly as a procedure + let retBody := if proc.isFunctional then ⟨.Return (some body), bodyExpr.md⟩ else body + .Opaque outputEnsures (some retBody) [] + | .Opaque postconds impl modif => + let impl' := impl.map fun b => wrap ((elimStmt ptMap b).run initVars).1 b.md + .Opaque (postconds ++ outputEnsures) impl' modif + | .Abstract postconds => .Abstract (postconds ++ outputEnsures) + | .External => .External + -- Resolve all constrained types to base types + let resolve := resolveExpr ptMap + let resolveBody : Body → Body := fun body => match body with + | .Transparent b => .Transparent (resolve b) + | .Opaque ps impl modif => .Opaque (ps.map resolve) (impl.map resolve) (modif.map resolve) + | .Abstract ps => .Abstract (ps.map resolve) + | .External => .External + { proc with + body := resolveBody body' + inputs := proc.inputs.map fun p => { p with type := resolveType ptMap p.type } + outputs := proc.outputs.map fun p => { p with type := resolveType ptMap p.type } + preconditions := (proc.preconditions ++ inputRequires).map resolve } + +/-- Create a synthetic procedure that asserts the witness satisfies all constraints -/ +private def mkWitnessProc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Procedure := + let md := ct.witness.md + let witnessId : Identifier := mkId "$witness" + let witnessInit : StmtExprMd := + ⟨.LocalVariable witnessId (resolveType ptMap ct.base) (some ct.witness), md⟩ + let asserts := (getAllConstraints ptMap (.UserDefined ct.name)).map fun (vn, pred) => + ⟨.Assert (substId vn witnessId pred), md⟩ + { name := mkId s!"$witness_{ct.name.text}" + inputs := [] + outputs := [] + body := .Transparent ⟨.Block ([witnessInit] ++ asserts) none, md⟩ + preconditions := [] + isFunctional := false + determinism := .deterministic none + decreases := none + md := md } + +/-- Eliminate constrained types from a Laurel program. + The `witness` field is used as the default initializer for uninitialized + constrained-typed variables, and is validated via synthetic procedures. -/ +def constrainedTypeElim (_model : SemanticModel) (program : Program) : Program × Array DiagnosticModel := + let ptMap := buildConstrainedTypeMap program.types + if ptMap.isEmpty then (program, #[]) else + -- Report unsupported: isFunctional procedures with constrained return types + let funcDiags := program.staticProcedures.foldl (init := #[]) fun acc proc => + if proc.isFunctional && proc.outputs.any (fun p => isConstrainedType ptMap p.type.val) then + acc.push (proc.md.toDiagnostic "constrained return types on functions are not yet supported") + else acc + let witnessProcedures := program.types.filterMap fun + | .Constrained ct => some (mkWitnessProc ptMap ct) + | _ => none + ({ program with + staticProcedures := program.staticProcedures.map (elimProc ptMap) ++ witnessProcedures + types := program.types.filter fun | .Constrained _ => false | _ => true }, + funcDiags) + +end Strata.Laurel diff --git a/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean b/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean index 2ada05bab..75eaf50ca 100644 --- a/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean +++ b/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean @@ -495,6 +495,20 @@ def parseDatatype (arg : Arg) : TransM TypeDefinition := do | _, _ => TransM.error s!"parseDatatype expects datatype, got {repr op.name}" +def parseConstrainedType (arg : Arg) : TransM ConstrainedType := do + let .op op := arg + | TransM.error s!"parseConstrainedType expects operation" + match op.name, op.args with + | q`Laurel.constrainedType, #[nameArg, valueNameArg, baseArg, constraintArg, witnessArg] => + let name ← translateIdent nameArg + let valueName ← translateIdent valueNameArg + let base ← translateHighType baseArg + let constraint ← translateStmtExpr constraintArg + let witness ← translateStmtExpr witnessArg + return { name, base, valueName, constraint, witness } + | _, _ => + TransM.error s!"parseConstrainedType expects constrainedType, got {repr op.name}" + def parseTopLevel (arg : Arg) : TransM (Option Procedure × Option TypeDefinition) := do let .op op := arg | TransM.error s!"parseTopLevel expects operation" @@ -509,8 +523,11 @@ def parseTopLevel (arg : Arg) : TransM (Option Procedure × Option TypeDefinitio | q`Laurel.topLevelDatatype, #[datatypeArg] => let typeDef ← parseDatatype datatypeArg return (none, some typeDef) + | q`Laurel.topLevelConstrainedType, #[ctArg] => + let ct ← parseConstrainedType ctArg + return (none, some (.Constrained ct)) | _, _ => - TransM.error s!"parseTopLevel expects topLevelProcedure, topLevelComposite, or topLevelDatatype, got {repr op.name}" + TransM.error s!"parseTopLevel expects topLevelProcedure, topLevelComposite, topLevelDatatype, or topLevelConstrainedType, got {repr op.name}" /-- Translate concrete Laurel syntax into abstract Laurel syntax diff --git a/Strata/Languages/Laurel/Grammar/LaurelGrammar.st b/Strata/Languages/Laurel/Grammar/LaurelGrammar.st index 720d64b86..406e704bb 100644 --- a/Strata/Languages/Laurel/Grammar/LaurelGrammar.st +++ b/Strata/Languages/Laurel/Grammar/LaurelGrammar.st @@ -166,4 +166,10 @@ op topLevelComposite(composite: Composite): TopLevel => composite; op topLevelProcedure(procedure: Procedure): TopLevel => procedure; op topLevelDatatype(datatype: Datatype): TopLevel => datatype; +category ConstrainedType; +op constrainedType (name: Ident, valueName: Ident, base: LaurelType, + constraint: StmtExpr, witness: StmtExpr): ConstrainedType + => "constrained " name " = " valueName ": " base " where " constraint:0 " witness " witness:0; +op topLevelConstrainedType(ct: ConstrainedType): TopLevel => ct; + op program (items: Seq TopLevel): Command => items; \ No newline at end of file diff --git a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean index 5f9babbe1..7308d70dc 100644 --- a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean +++ b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean @@ -21,6 +21,7 @@ import Strata.DL.Imperative.Stmt import Strata.DL.Imperative.MetaData import Strata.DL.Lambda.LExpr import Strata.Languages.Laurel.LaurelFormat +import Strata.Languages.Laurel.ConstrainedTypeElim import Strata.Util.Tactics open Core (VCResult VCResults VerifyOptions) @@ -619,6 +620,11 @@ def translate (program : Program): Except (Array DiagnosticModel) (Core.Program let (program, model) := (result.program, result.model) _resolutionDiags := _resolutionDiags ++ result.errors + let (program, constrainedTypeDiags) := constrainedTypeElim model program + let result := resolve program (some model) + let (program, model) := (result.program, result.model) + _resolutionDiags := _resolutionDiags ++ result.errors + -- Procedures marked isFunctional are translated to Core functions; all others become Core procedures. -- External procedures are completely ignored (not translated to Core). let nonExternal := program.staticProcedures.filter (fun p => !p.body.isExternal) @@ -667,7 +673,7 @@ def translate (program : Program): Except (Array DiagnosticModel) (Core.Program -- dbg_trace "=== Generated Strata Core Program ===" -- dbg_trace (toString (Std.Format.pretty (Strata.Core.formatProgram program) 100)) -- dbg_trace "=================================" - pure (program, diamondErrors ++ modifiesDiags) + pure (program, diamondErrors ++ modifiesDiags ++ constrainedTypeDiags.toList) /-- Verify a Laurel program using an SMT solver diff --git a/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean new file mode 100644 index 000000000..7509bce29 --- /dev/null +++ b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean @@ -0,0 +1,61 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +/- +Tests that the constrained type elimination pass correctly transforms +Laurel programs by comparing the output against expected results. +-/ + +import Strata.DDM.Elab +import Strata.DDM.BuiltinDialects.Init +import Strata.Languages.Laurel.Grammar.LaurelGrammar +import Strata.Languages.Laurel.Grammar.ConcreteToAbstractTreeTranslator +import Strata.Languages.Laurel.ConstrainedTypeElim +import Strata.Languages.Laurel.Resolution + +open Strata +open Strata.Elab (parseStrataProgramFromDialect) + +namespace Strata.Laurel + +def testProgram : String := r" +constrained nat = x: int where x >= 0 witness 0 +procedure test(n: nat) returns (r: nat) { + var y: nat := n; + return y; +} +" + +def parseLaurelAndElim (input : String) : IO Program := do + let inputCtx := Strata.Parser.stringInputContext "test" input + let dialects := Strata.Elab.LoadedDialects.ofDialects! #[initDialect, Laurel] + let strataProgram ← parseStrataProgramFromDialect dialects Laurel.name inputCtx + let uri := Strata.Uri.file "test" + match Laurel.TransM.run uri (Laurel.parseProgram strataProgram) with + | .error e => throw (IO.userError s!"Translation errors: {e}") + | .ok program => + let result := resolve program + let (program, model) := (result.program, result.model) + pure (constrainedTypeElim model program).1 + +/-- +info: procedure test(n: int) returns ⏎ +(r: int) +requires n >= 0 +deterministic + ensures r >= 0 := { var y: int := n; assert y >= 0; return y } +procedure $witness_nat() returns ⏎ +() +deterministic +{ var $witness: int := 0; assert $witness >= 0 } +-/ +#guard_msgs in +#eval! do + let program ← parseLaurelAndElim testProgram + for proc in program.staticProcedures do + IO.println (toString (Std.Format.pretty (Std.ToFormat.format proc))) + +end Laurel diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T09_ConstrainedTypes.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T09_ConstrainedTypes.lean new file mode 100644 index 000000000..cbd9d7461 --- /dev/null +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T09_ConstrainedTypes.lean @@ -0,0 +1,168 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +import StrataTest.Util.TestDiagnostics +import StrataTest.Languages.Laurel.TestExamples + +open StrataTest.Util + +namespace Strata +namespace Laurel + +def program := r" +constrained nat = x: int where x >= 0 witness 0 +constrained posnat = x: nat where x > 0 witness 1 + +// Input constraint becomes requires — body can rely on it +procedure inputAssumed(n: nat) { + assert n >= 0; +} + +// Output constraint — valid return passes +procedure outputValid(): nat { + return 3; +} + +// Output constraint — invalid return fails +procedure outputInvalid(): nat { +// ^^^ error: assertion does not hold + return -1; +} + +// Return value of constrained type — caller gets ensures via call elimination +procedure opaqueNat(): nat +procedure callerAssumes() returns (r: int) { + var x: int := opaqueNat(); + assert x >= 0; + return x; +} + +// Assignment to constrained-typed variable — valid +procedure assignValid() { + var y: nat := 5; +} + +// Assignment to constrained-typed variable — invalid +procedure assignInvalid() { + var y: nat := -1; +// ^^ error: assertion does not hold +} + +// Reassignment to constrained-typed variable — invalid +procedure reassignInvalid() { + var y: nat := 5; + y := -1; +//^^^^^^^^ error: assertion does not hold +} + +// Argument to constrained-typed parameter — valid +procedure takesNat(n: nat) returns (r: int) { return n; } +// ^^^ error: assertion does not hold +procedure argValid() returns (r: int) { + var x: int := takesNat(3); + return x; +} + +// Argument to constrained-typed parameter — invalid (requires violation) +procedure argInvalid() returns (r: int) { + var x: int := takesNat(-1); + return x; +} + +// Nested constrained type — independent constraints require transitive collection +constrained even = x: int where x % 2 == 0 witness 0 +constrained evenpos = x: even where x > 0 witness 2 +procedure nestedInput(x: evenpos) { + assert x > 0; + assert x % 2 == 0; +} + +// Multiple constrained-typed parameters +procedure multiParam(a: nat, b: nat) { + assert a >= 0; + assert b >= 0; +} + +// Two calls to same procedure — no temp var collision +procedure twoCalls() returns (r: int) { + var a: int := takesNat(1); + var b: int := takesNat(2); + return a + b; +} + +// Constrained type in expression position must be resolved +procedure constrainedInExpr() { + var b: bool := forall(n: nat) => n + 1 > n; + assert b; +} + +// Invalid witness — witness -1 does not satisfy x > 0 +constrained bad = x: int where x > 0 witness -1 +// ^^ error: assertion does not hold + +// Uninitialized constrained variable — witness used as default +procedure uninitNat() { + var y: nat; + assert y >= 0; +} + +// Uninitialized nested constrained variable — outermost witness used +procedure uninitPosnat() { + var y: posnat; + assert y > 0; + assert y >= 0; +} + +// Function with valid constrained return — constraint not checked (not yet supported) +function goodFunc(): nat { 3 } +// ^^^^^^^^ error: constrained return types on functions are not yet supported + +// Function with invalid constrained return — constraint not checked (not yet supported) +function badFunc(): nat { -1 } +// ^^^^^^^ error: constrained return types on functions are not yet supported + +// Caller of constrained function — body is inlined, caller sees actual value +procedure callerGood() { + var x: int := goodFunc(); + assert x >= 0; +} + +// Quantifier constraint injection — forall +// n + 1 > 0 is only provable with n >= 0 injected; false for all int +procedure forallNat() { + var b: bool := forall(n: nat) => n + 1 > 0; + assert b; +} + +// Quantifier constraint injection — exists +// n == -1 is satisfiable for int, but not when n >= 0 is required +// n == 42 works because 42 >= 0 +procedure existsNat() { + var b: bool := exists(n: nat) => n == 42; + assert b; +} + +// Quantifier constraint injection — nested constrained type +// n - 1 >= 0 is only provable with n > 0 injected +procedure forallPosnat() { + var b: bool := forall(n: posnat) => n - 1 >= 0; + assert b; +} + +// Capture avoidance — bound var y in constraint must not collide with parameter y +// Without capture avoidance, requires becomes exists(y) => y > y (false), making body vacuously true +constrained haslarger = x: int where (exists(y: int) => y > x) witness 0 +procedure captureTest(y: haslarger) { + assert false; +//^^^^^^^^^^^^^ error: assertion does not hold +} +" + +#guard_msgs(drop info, error) in +#eval testInputWithOffset "ConstrainedTypes" program 14 processLaurelFile + +end Laurel +end Strata From 46feec73554f29b10a4cf40e0e618932d3aaeb3f Mon Sep 17 00:00:00 2001 From: Fabio Madge Date: Tue, 10 Mar 2026 18:21:08 +0100 Subject: [PATCH 02/11] Fix test annotation after main merge --- .../Laurel/Examples/Fundamentals/T09_ConstrainedTypes.lean | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T09_ConstrainedTypes.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T09_ConstrainedTypes.lean index cbd9d7461..bc76e547b 100644 --- a/StrataTest/Languages/Laurel/Examples/Fundamentals/T09_ConstrainedTypes.lean +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T09_ConstrainedTypes.lean @@ -48,7 +48,7 @@ procedure assignValid() { // Assignment to constrained-typed variable — invalid procedure assignInvalid() { var y: nat := -1; -// ^^ error: assertion does not hold +//^^^^^^^^^^^^^^^^^ error: assertion does not hold } // Reassignment to constrained-typed variable — invalid From af4e7973fa863ab52c43941786977e9334dc944c Mon Sep 17 00:00:00 2001 From: Fabio Madge Date: Tue, 10 Mar 2026 23:57:43 +0100 Subject: [PATCH 03/11] Replace head! with pattern match; fix grammar and test semicolons --- .../Languages/Laurel/ConstrainedTypeElim.lean | 11 +- .../Laurel/Grammar/LaurelGrammar.lean | 1 - .../Laurel/ConstrainedTypeElimTest.lean | 2 +- .../Fundamentals/T09_ConstrainedTypes.lean | 168 ------------------ .../Fundamentals/T10_ConstrainedTypes.lean | 158 ++++++++++++++-- 5 files changed, 156 insertions(+), 184 deletions(-) delete mode 100644 StrataTest/Languages/Laurel/Examples/Fundamentals/T09_ConstrainedTypes.lean diff --git a/Strata/Languages/Laurel/ConstrainedTypeElim.lean b/Strata/Languages/Laurel/ConstrainedTypeElim.lean index 6c00f6644..5eaab12ee 100644 --- a/Strata/Languages/Laurel/ConstrainedTypeElim.lean +++ b/Strata/Languages/Laurel/ConstrainedTypeElim.lean @@ -101,10 +101,13 @@ private def injectQuantifierConstraint (ptMap : ConstrainedTypeMap) (ty : HighTy | [] => body | _ => let preds := constraints.map fun (vn, pred) => substId vn varName pred - let conj := preds.tail.foldl (init := preds.head!) fun acc p => - ⟨.PrimitiveOp .And [acc, p], body.md⟩ - if isForall then ⟨.PrimitiveOp .Implies [conj, body], body.md⟩ - else ⟨.PrimitiveOp .And [conj, body], body.md⟩ + match preds with + | [] => body -- unreachable + | first :: rest => + let conj := rest.foldl (init := first) fun acc p => + ⟨.PrimitiveOp .And [acc, p], body.md⟩ + if isForall then ⟨.PrimitiveOp .Implies [conj, body], body.md⟩ + else ⟨.PrimitiveOp .And [conj, body], body.md⟩ /-- Resolve constrained types in all type positions of an expression -/ def resolveExpr (ptMap : ConstrainedTypeMap) : StmtExprMd → StmtExprMd diff --git a/Strata/Languages/Laurel/Grammar/LaurelGrammar.lean b/Strata/Languages/Laurel/Grammar/LaurelGrammar.lean index 282254daa..632e0a69b 100644 --- a/Strata/Languages/Laurel/Grammar/LaurelGrammar.lean +++ b/Strata/Languages/Laurel/Grammar/LaurelGrammar.lean @@ -7,7 +7,6 @@ -- Laurel dialect definition, loaded from LaurelGrammar.st -- NOTE: Changes to LaurelGrammar.st are not automatically tracked by the build system. -- Update this file (e.g. this comment) to trigger a recompile after modifying LaurelGrammar.st. --- Last grammar change: require semicolon after procedure/function definitions. import Strata.DDM.Integration.Lean namespace Strata.Laurel diff --git a/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean index 7509bce29..f3ce4399e 100644 --- a/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean +++ b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean @@ -26,7 +26,7 @@ constrained nat = x: int where x >= 0 witness 0 procedure test(n: nat) returns (r: nat) { var y: nat := n; return y; -} +}; " def parseLaurelAndElim (input : String) : IO Program := do diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T09_ConstrainedTypes.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T09_ConstrainedTypes.lean deleted file mode 100644 index bc76e547b..000000000 --- a/StrataTest/Languages/Laurel/Examples/Fundamentals/T09_ConstrainedTypes.lean +++ /dev/null @@ -1,168 +0,0 @@ -/- - Copyright Strata Contributors - - SPDX-License-Identifier: Apache-2.0 OR MIT --/ - -import StrataTest.Util.TestDiagnostics -import StrataTest.Languages.Laurel.TestExamples - -open StrataTest.Util - -namespace Strata -namespace Laurel - -def program := r" -constrained nat = x: int where x >= 0 witness 0 -constrained posnat = x: nat where x > 0 witness 1 - -// Input constraint becomes requires — body can rely on it -procedure inputAssumed(n: nat) { - assert n >= 0; -} - -// Output constraint — valid return passes -procedure outputValid(): nat { - return 3; -} - -// Output constraint — invalid return fails -procedure outputInvalid(): nat { -// ^^^ error: assertion does not hold - return -1; -} - -// Return value of constrained type — caller gets ensures via call elimination -procedure opaqueNat(): nat -procedure callerAssumes() returns (r: int) { - var x: int := opaqueNat(); - assert x >= 0; - return x; -} - -// Assignment to constrained-typed variable — valid -procedure assignValid() { - var y: nat := 5; -} - -// Assignment to constrained-typed variable — invalid -procedure assignInvalid() { - var y: nat := -1; -//^^^^^^^^^^^^^^^^^ error: assertion does not hold -} - -// Reassignment to constrained-typed variable — invalid -procedure reassignInvalid() { - var y: nat := 5; - y := -1; -//^^^^^^^^ error: assertion does not hold -} - -// Argument to constrained-typed parameter — valid -procedure takesNat(n: nat) returns (r: int) { return n; } -// ^^^ error: assertion does not hold -procedure argValid() returns (r: int) { - var x: int := takesNat(3); - return x; -} - -// Argument to constrained-typed parameter — invalid (requires violation) -procedure argInvalid() returns (r: int) { - var x: int := takesNat(-1); - return x; -} - -// Nested constrained type — independent constraints require transitive collection -constrained even = x: int where x % 2 == 0 witness 0 -constrained evenpos = x: even where x > 0 witness 2 -procedure nestedInput(x: evenpos) { - assert x > 0; - assert x % 2 == 0; -} - -// Multiple constrained-typed parameters -procedure multiParam(a: nat, b: nat) { - assert a >= 0; - assert b >= 0; -} - -// Two calls to same procedure — no temp var collision -procedure twoCalls() returns (r: int) { - var a: int := takesNat(1); - var b: int := takesNat(2); - return a + b; -} - -// Constrained type in expression position must be resolved -procedure constrainedInExpr() { - var b: bool := forall(n: nat) => n + 1 > n; - assert b; -} - -// Invalid witness — witness -1 does not satisfy x > 0 -constrained bad = x: int where x > 0 witness -1 -// ^^ error: assertion does not hold - -// Uninitialized constrained variable — witness used as default -procedure uninitNat() { - var y: nat; - assert y >= 0; -} - -// Uninitialized nested constrained variable — outermost witness used -procedure uninitPosnat() { - var y: posnat; - assert y > 0; - assert y >= 0; -} - -// Function with valid constrained return — constraint not checked (not yet supported) -function goodFunc(): nat { 3 } -// ^^^^^^^^ error: constrained return types on functions are not yet supported - -// Function with invalid constrained return — constraint not checked (not yet supported) -function badFunc(): nat { -1 } -// ^^^^^^^ error: constrained return types on functions are not yet supported - -// Caller of constrained function — body is inlined, caller sees actual value -procedure callerGood() { - var x: int := goodFunc(); - assert x >= 0; -} - -// Quantifier constraint injection — forall -// n + 1 > 0 is only provable with n >= 0 injected; false for all int -procedure forallNat() { - var b: bool := forall(n: nat) => n + 1 > 0; - assert b; -} - -// Quantifier constraint injection — exists -// n == -1 is satisfiable for int, but not when n >= 0 is required -// n == 42 works because 42 >= 0 -procedure existsNat() { - var b: bool := exists(n: nat) => n == 42; - assert b; -} - -// Quantifier constraint injection — nested constrained type -// n - 1 >= 0 is only provable with n > 0 injected -procedure forallPosnat() { - var b: bool := forall(n: posnat) => n - 1 >= 0; - assert b; -} - -// Capture avoidance — bound var y in constraint must not collide with parameter y -// Without capture avoidance, requires becomes exists(y) => y > y (false), making body vacuously true -constrained haslarger = x: int where (exists(y: int) => y > x) witness 0 -procedure captureTest(y: haslarger) { - assert false; -//^^^^^^^^^^^^^ error: assertion does not hold -} -" - -#guard_msgs(drop info, error) in -#eval testInputWithOffset "ConstrainedTypes" program 14 processLaurelFile - -end Laurel -end Strata diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean index 33fcb29b0..a8c36fe87 100644 --- a/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean @@ -8,23 +8,161 @@ import StrataTest.Util.TestDiagnostics import StrataTest.Languages.Laurel.TestExamples open StrataTest.Util -open Strata +namespace Strata namespace Laurel def program := r" constrained nat = x: int where x >= 0 witness 0 +constrained posnat = x: nat where x > 0 witness 1 -composite Option {} -composite Some extends Option { - value: int -} -composite None extends Option -constrained SealedOption = x: Option where x is Some || x is None witness None +// Input constraint becomes requires — body can rely on it +procedure inputAssumed(n: nat) { + assert n >= 0; +}; + +// Output constraint — valid return passes +procedure outputValid(): nat { + return 3; +}; + +// Output constraint — invalid return fails +procedure outputInvalid(): nat { +// ^^^ error: assertion does not hold + return -1; +}; + +// Return value of constrained type — caller gets ensures via call elimination +procedure opaqueNat(): nat; +procedure callerAssumes() returns (r: int) { + var x: int := opaqueNat(); + assert x >= 0; + return x; +}; + +// Assignment to constrained-typed variable — valid +procedure assignValid() { + var y: nat := 5; +}; + +// Assignment to constrained-typed variable — invalid +procedure assignInvalid() { + var y: nat := -1; +//^^^^^^^^^^^^^^^^^ error: assertion does not hold +}; + +// Reassignment to constrained-typed variable — invalid +procedure reassignInvalid() { + var y: nat := 5; + y := -1; +//^^^^^^^^ error: assertion does not hold +}; + +// Argument to constrained-typed parameter — valid +procedure takesNat(n: nat) returns (r: int) { return n; }; +// ^^^ error: assertion does not hold +procedure argValid() returns (r: int) { + var x: int := takesNat(3); + return x; +}; + +// Argument to constrained-typed parameter — invalid (requires violation) +procedure argInvalid() returns (r: int) { + var x: int := takesNat(-1); + return x; +}; + +// Nested constrained type — independent constraints require transitive collection +constrained even = x: int where x % 2 == 0 witness 0 +constrained evenpos = x: even where x > 0 witness 2 +procedure nestedInput(x: evenpos) { + assert x > 0; + assert x % 2 == 0; +}; + +// Multiple constrained-typed parameters +procedure multiParam(a: nat, b: nat) { + assert a >= 0; + assert b >= 0; +}; -procedure foo() returns (r: nat) { +// Two calls to same procedure — no temp var collision +procedure twoCalls() returns (r: int) { + var a: int := takesNat(1); + var b: int := takesNat(2); + return a + b; +}; + +// Constrained type in expression position must be resolved +procedure constrainedInExpr() { + var b: bool := forall(n: nat) => n + 1 > n; + assert b; +}; + +// Invalid witness — witness -1 does not satisfy x > 0 +constrained bad = x: int where x > 0 witness -1 +// ^^ error: assertion does not hold + +// Uninitialized constrained variable — witness used as default +procedure uninitNat() { + var y: nat; + assert y >= 0; +}; + +// Uninitialized nested constrained variable — outermost witness used +procedure uninitPosnat() { + var y: posnat; + assert y > 0; + assert y >= 0; +}; + +// Function with valid constrained return — constraint not checked (not yet supported) +function goodFunc(): nat { 3 }; +// ^^^^^^^^ error: constrained return types on functions are not yet supported + +// Function with invalid constrained return — constraint not checked (not yet supported) +function badFunc(): nat { -1 }; +// ^^^^^^^ error: constrained return types on functions are not yet supported + +// Caller of constrained function — body is inlined, caller sees actual value +procedure callerGood() { + var x: int := goodFunc(); + assert x >= 0; +}; + +// Quantifier constraint injection — forall +// n + 1 > 0 is only provable with n >= 0 injected; false for all int +procedure forallNat() { + var b: bool := forall(n: nat) => n + 1 > 0; + assert b; +}; + +// Quantifier constraint injection — exists +// n == -1 is satisfiable for int, but not when n >= 0 is required +// n == 42 works because 42 >= 0 +procedure existsNat() { + var b: bool := exists(n: nat) => n == 42; + assert b; +}; + +// Quantifier constraint injection — nested constrained type +// n - 1 >= 0 is only provable with n > 0 injected +procedure forallPosnat() { + var b: bool := forall(n: posnat) => n - 1 >= 0; + assert b; +}; + +// Capture avoidance — bound var y in constraint must not collide with parameter y +// Without capture avoidance, requires becomes exists(y) => y > y (false), making body vacuously true +constrained haslarger = x: int where (exists(y: int) => y > x) witness 0 +procedure captureTest(y: haslarger) { + assert false; +//^^^^^^^^^^^^^ error: assertion does not hold }; " --- Not working yet --- #eval! testInput "ConstrainedTypes" program processLaurelFile +#guard_msgs(drop info, error) in +#eval testInputWithOffset "ConstrainedTypes" program 14 processLaurelFile + +end Laurel +end Strata From 6c2b49bda37d1900d15ece5b7dd50d40d46874f7 Mon Sep 17 00:00:00 2001 From: Fabio Madge Date: Wed, 11 Mar 2026 03:15:03 +0100 Subject: [PATCH 04/11] Fix scope management bug in elimStmt; add regression test Save/restore PredVarMap state around Block, IfThenElse, and While via inScope helper to prevent constrained variable entries from leaking across sibling scopes. Reported by shigoel: a constrained variable declared in an if-branch would cause spurious asserts on same-named variables in sibling blocks. --- .../Languages/Laurel/ConstrainedTypeElim.lean | 16 ++++++---- .../Laurel/ConstrainedTypeElimTest.lean | 30 +++++++++++++++++++ 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/Strata/Languages/Laurel/ConstrainedTypeElim.lean b/Strata/Languages/Laurel/ConstrainedTypeElim.lean index 5eaab12ee..0d6264991 100644 --- a/Strata/Languages/Laurel/ConstrainedTypeElim.lean +++ b/Strata/Languages/Laurel/ConstrainedTypeElim.lean @@ -151,6 +151,12 @@ decreasing_by all_goals (have := WithMetadata.sizeOf_val_lt ‹_›; term_by_mem /-- Insert asserts for constrained-typed variable init and reassignment -/ abbrev ElimM := StateM PredVarMap +private def inScope (action : ElimM α) : ElimM α := do + let saved ← get + let result ← action + set saved + return result + def elimStmt (ptMap : ConstrainedTypeMap) (stmt : StmtExprMd) : ElimM (List StmtExprMd) := do let md := stmt.md @@ -176,19 +182,19 @@ def elimStmt (ptMap : ConstrainedTypeMap) | _ => pure [stmt] | .Block stmts sep => - let stmtss ← stmts.mapM (elimStmt ptMap) + let stmtss ← inScope (stmts.mapM (elimStmt ptMap)) pure [⟨.Block stmtss.flatten sep, md⟩] | .IfThenElse cond thenBr (some elseBr) => - let thenSs ← elimStmt ptMap thenBr - let elseSs ← elimStmt ptMap elseBr + let thenSs ← inScope (elimStmt ptMap thenBr) + let elseSs ← inScope (elimStmt ptMap elseBr) pure [⟨.IfThenElse cond (wrap thenSs md) (some (wrap elseSs md)), md⟩] | .IfThenElse cond thenBr none => - let thenSs ← elimStmt ptMap thenBr + let thenSs ← inScope (elimStmt ptMap thenBr) pure [⟨.IfThenElse cond (wrap thenSs md) none, md⟩] | .While cond inv dec body => - let bodySs ← elimStmt ptMap body + let bodySs ← inScope (elimStmt ptMap body) pure [⟨.While cond inv dec (wrap bodySs md), md⟩] | _ => pure [stmt] diff --git a/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean index f3ce4399e..54841c473 100644 --- a/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean +++ b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean @@ -58,4 +58,34 @@ deterministic for proc in program.staticProcedures do IO.println (toString (Std.Format.pretty (Std.ToFormat.format proc))) +-- Scope management: constrained variable in if-branch must not leak into sibling block +def scopeProgram : String := r" +constrained pos = v: int where v > 0 witness 1 +procedure test(b: bool) { + if (b) { + var x: pos := 1; + } + { + var x: int := -5; + x := -10; + } +}; +" + +/-- +info: procedure test(b: bool) returns ⏎ +() +deterministic +{ if b then { var x: int := 1; assert x > 0 }; { var x: int := -5; x := -10 } } +procedure $witness_pos() returns ⏎ +() +deterministic +{ var $witness: int := 1; assert $witness > 0 } +-/ +#guard_msgs in +#eval! do + let program ← parseLaurelAndElim scopeProgram + for proc in program.staticProcedures do + IO.println (toString (Std.Format.pretty (Std.ToFormat.format proc))) + end Laurel From 64dfbf70830b8fbbe1e35945d097e2e411d1ea7a Mon Sep 17 00:00:00 2001 From: Fabio Madge Date: Wed, 11 Mar 2026 20:37:18 +0100 Subject: [PATCH 05/11] Replace substId with constraint functions Generate named constraint functions (e.g. nat$constraint) instead of inlining substituted constraint expressions. This eliminates substId (with its partial annotation and capture avoidance complexity) and produces clearer Core output with named function calls. Constraint functions are placed before user procedures to ensure resolution processes them first (workaround for resolution ID assignment order dependency). --- .../Languages/Laurel/ConstrainedTypeElim.lean | 174 ++++++++---------- .../Laurel/ConstrainedTypeElimTest.lean | 22 ++- 2 files changed, 95 insertions(+), 101 deletions(-) diff --git a/Strata/Languages/Laurel/ConstrainedTypeElim.lean b/Strata/Languages/Laurel/ConstrainedTypeElim.lean index 0d6264991..d45cc0b43 100644 --- a/Strata/Languages/Laurel/ConstrainedTypeElim.lean +++ b/Strata/Languages/Laurel/ConstrainedTypeElim.lean @@ -11,15 +11,16 @@ import Strata.Languages.Laurel.Resolution # Constrained Type Elimination A Laurel-to-Laurel pass that eliminates constrained types by: -1. Adding `requires` for constrained-typed inputs (Core handles caller asserts and body assumes) -2. Adding `ensures` for constrained-typed outputs (Core handles body checks and caller assumes) +1. Generating a constraint function per constrained type (e.g. `nat$constraint(x: int): bool`) +2. Adding `requires constraintFunc(param)` for constrained-typed inputs +3. Adding `ensures constraintFunc(result)` for constrained-typed outputs - Skipped for `isFunctional` procedures since the Laurel translator does not yet support function postconditions. Constrained return types on functions are not checked. -3. Inserting `assert` for local variable init and reassignment of constrained-typed variables -4. Using the witness as default initializer for uninitialized constrained-typed variables -5. Adding a synthetic witness-validation procedure per constrained type -6. Injecting constraints into quantifier bodies (`forall` → `implies`, `exists` → `and`) -7. Resolving all constrained type references to their base types +4. Inserting `assert constraintFunc(var)` for local variable init and reassignment +5. Using the witness as default initializer for uninitialized constrained-typed variables +6. Adding a synthetic witness-validation procedure per constrained type +7. Injecting constraint function calls into quantifier bodies (`forall` → `implies`, `exists` → `and`) +8. Resolving all constrained type references to their base types -/ namespace Strata.Laurel @@ -48,68 +49,56 @@ def resolveType (ptMap : ConstrainedTypeMap) (ty : HighTypeMd) : HighTypeMd := def isConstrainedType (ptMap : ConstrainedTypeMap) (ty : HighType) : Bool := match ty with | .UserDefined name => ptMap.contains name.text | _ => false -/-- All predicates for a type transitively (e.g. evenpos → [(x, x > 0), (x, x % 2 == 0)]) -/ -partial def getAllConstraints (ptMap : ConstrainedTypeMap) (ty : HighType) - : List (Identifier × StmtExprMd) := +/-- Build a call to the constraint function for a constrained type, or `none` if not constrained -/ +def constraintCallFor (ptMap : ConstrainedTypeMap) (ty : HighType) + (arg : StmtExprMd) (md : Imperative.MetaData Core.Expression) : Option StmtExprMd := match ty with - | .UserDefined name => match ptMap.get? name.text with - | some ct => (ct.valueName, ct.constraint) :: getAllConstraints ptMap ct.base.val - | none => [] - | _ => [] + | .UserDefined name => if ptMap.contains name.text then + some ⟨.StaticCall (mkId s!"{name.text}$constraint") [arg], md⟩ + else none + | _ => none -/-- Substitute `Identifier old` with `Identifier new` in a constraint expression -/ -partial def substId (old new : Identifier) : StmtExprMd → StmtExprMd - | ⟨.Identifier n, md⟩ => ⟨if n == old then .Identifier new else .Identifier n, md⟩ - | ⟨.PrimitiveOp op args, md⟩ => - ⟨.PrimitiveOp op (args.map fun a => substId old new a), md⟩ - | ⟨.StaticCall c args, md⟩ => - ⟨.StaticCall c (args.map fun a => substId old new a), md⟩ - | ⟨.IfThenElse c t (some el), md⟩ => - ⟨.IfThenElse (substId old new c) (substId old new t) (some (substId old new el)), md⟩ - | ⟨.IfThenElse c t none, md⟩ => - ⟨.IfThenElse (substId old new c) (substId old new t) none, md⟩ - | ⟨.Block ss sep, md⟩ => - ⟨.Block (ss.map fun s => substId old new s) sep, md⟩ - | ⟨.Forall param body, md⟩ => - if param.name == old then ⟨.Forall param body, md⟩ - else if param.name == new then - let fresh : Identifier := mkId (param.name.text ++ "$") - ⟨.Forall { param with name := fresh } (substId old new (substId param.name fresh body)), md⟩ - else ⟨.Forall param (substId old new body), md⟩ - | ⟨.Exists param body, md⟩ => - if param.name == old then ⟨.Exists param body, md⟩ - else if param.name == new then - let fresh : Identifier := mkId (param.name.text ++ "$") - ⟨.Exists { param with name := fresh } (substId old new (substId param.name fresh body)), md⟩ - else ⟨.Exists param (substId old new body), md⟩ +/-- Clear all uniqueIds from identifiers in an expression so resolution can re-assign them -/ +partial def clearIds : StmtExprMd → StmtExprMd + | ⟨.Identifier n, md⟩ => ⟨.Identifier { n with uniqueId := none }, md⟩ + | ⟨.PrimitiveOp op args, md⟩ => ⟨.PrimitiveOp op (args.map clearIds), md⟩ + | ⟨.StaticCall c args, md⟩ => ⟨.StaticCall { c with uniqueId := none } (args.map clearIds), md⟩ + | ⟨.Block ss sep, md⟩ => ⟨.Block (ss.map clearIds) sep, md⟩ + | ⟨.IfThenElse c t (some e), md⟩ => ⟨.IfThenElse (clearIds c) (clearIds t) (some (clearIds e)), md⟩ + | ⟨.IfThenElse c t none, md⟩ => ⟨.IfThenElse (clearIds c) (clearIds t) none, md⟩ + | ⟨.Forall p body, md⟩ => ⟨.Forall { p with name := { p.name with uniqueId := none } } (clearIds body), md⟩ + | ⟨.Exists p body, md⟩ => ⟨.Exists { p with name := { p.name with uniqueId := none } } (clearIds body), md⟩ | e => e -def mkAsserts (ptMap : ConstrainedTypeMap) (ty : HighType) (varName : Identifier) - (md : Imperative.MetaData Core.Expression) : List StmtExprMd := - (getAllConstraints ptMap ty).map fun (valueName, pred) => - ⟨.Assert (substId valueName varName pred), md⟩ +/-- Generate a constraint function for a constrained type. + For nested types, the function calls the parent's constraint function. -/ +def mkConstraintFunc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Procedure := + let md : Imperative.MetaData Core.Expression := #[] + let baseType := resolveType ptMap ct.base + let bodyExpr := match ct.base.val with + | .UserDefined parent => + if ptMap.contains parent.text then + let parentCall : StmtExprMd := + ⟨.StaticCall (mkId s!"{parent.text}$constraint") [⟨.Identifier { ct.valueName with uniqueId := none }, md⟩], md⟩ + ⟨.PrimitiveOp .And [clearIds ct.constraint, parentCall], md⟩ + else clearIds ct.constraint + | _ => clearIds ct.constraint + { name := mkId s!"{ct.name.text}$constraint" + inputs := [{ name := { ct.valueName with uniqueId := none }, type := { baseType with md := #[] } }] + outputs := [{ name := mkId "result", type := ⟨.TBool, #[]⟩ }] + body := .Transparent ⟨.Block [bodyExpr] none, #[]⟩ + isFunctional := true + determinism := .deterministic none + decreases := none + preconditions := [] + md := #[] } private def wrap (stmts : List StmtExprMd) (md : Imperative.MetaData Core.Expression) : StmtExprMd := match stmts with | [s] => s | ss => ⟨.Block ss none, md⟩ -/-- Inject constraints into a quantifier body for a constrained type -/ -private def injectQuantifierConstraint (ptMap : ConstrainedTypeMap) (ty : HighType) - (varName : Identifier) (body : StmtExprMd) (isForall : Bool) : StmtExprMd := - let constraints := getAllConstraints ptMap ty - match constraints with - | [] => body - | _ => - let preds := constraints.map fun (vn, pred) => substId vn varName pred - match preds with - | [] => body -- unreachable - | first :: rest => - let conj := rest.foldl (init := first) fun acc p => - ⟨.PrimitiveOp .And [acc, p], body.md⟩ - if isForall then ⟨.PrimitiveOp .Implies [conj, body], body.md⟩ - else ⟨.PrimitiveOp .And [conj, body], body.md⟩ - -/-- Resolve constrained types in all type positions of an expression -/ +/-- Resolve constrained types in all type positions of an expression, + and inject constraint function calls into quantifier bodies -/ def resolveExpr (ptMap : ConstrainedTypeMap) : StmtExprMd → StmtExprMd | ⟨.LocalVariable n ty (some init), md⟩ => ⟨.LocalVariable n (resolveType ptMap ty) (some (resolveExpr ptMap init)), md⟩ @@ -118,11 +107,17 @@ def resolveExpr (ptMap : ConstrainedTypeMap) : StmtExprMd → StmtExprMd | ⟨.Forall param body, md⟩ => let body' := resolveExpr ptMap body let param' := { param with type := resolveType ptMap param.type } - ⟨.Forall param' (injectQuantifierConstraint ptMap param.type.val param.name body' true), md⟩ + let injected := match constraintCallFor ptMap param.type.val ⟨.Identifier param.name, md⟩ md with + | some c => ⟨.PrimitiveOp .Implies [c, body'], md⟩ + | none => body' + ⟨.Forall param' injected, md⟩ | ⟨.Exists param body, md⟩ => let body' := resolveExpr ptMap body let param' := { param with type := resolveType ptMap param.type } - ⟨.Exists param' (injectQuantifierConstraint ptMap param.type.val param.name body' false), md⟩ + let injected := match constraintCallFor ptMap param.type.val ⟨.Identifier param.name, md⟩ md with + | some c => ⟨.PrimitiveOp .And [c, body'], md⟩ + | none => body' + ⟨.Exists param' injected, md⟩ | ⟨.AsType t ty, md⟩ => ⟨.AsType (resolveExpr ptMap t) (resolveType ptMap ty), md⟩ | ⟨.IsType t ty, md⟩ => ⟨.IsType (resolveExpr ptMap t) (resolveType ptMap ty), md⟩ | ⟨.PrimitiveOp op args, md⟩ => @@ -148,7 +143,6 @@ def resolveExpr (ptMap : ConstrainedTypeMap) : StmtExprMd → StmtExprMd termination_by e => sizeOf e decreasing_by all_goals (have := WithMetadata.sizeOf_val_lt ‹_›; term_by_mem) -/-- Insert asserts for constrained-typed variable init and reassignment -/ abbrev ElimM := StateM PredVarMap private def inScope (action : ElimM α) : ElimM α := do @@ -162,22 +156,23 @@ def elimStmt (ptMap : ConstrainedTypeMap) let md := stmt.md match _h : stmt.val with | .LocalVariable name ty init => - let isPred := isConstrainedType ptMap ty.val - if isPred then modify fun pv => pv.insert name.text ty.val - let asserts := if isPred then mkAsserts ptMap ty.val name md else [] - -- Use witness as default initializer for uninitialized constrained variables + let callOpt := constraintCallFor ptMap ty.val ⟨.Identifier name, md⟩ md + if callOpt.isSome then modify fun pv => pv.insert name.text ty.val + let assert := callOpt.toList.map fun c => ⟨.Assert c, md⟩ let init' := match init with | none => match ty.val with | .UserDefined n => (ptMap.get? n.text).map (·.witness) | _ => none | some _ => init - pure ([⟨.LocalVariable name ty init', md⟩] ++ asserts) + pure ([⟨.LocalVariable name ty init', md⟩] ++ assert) - -- Single-target only; multi-target assignments are not supported by the Laurel grammar | .Assign [target] _ => match target.val with | .Identifier name => do match (← get).get? name.text with - | some ty => pure ([stmt] ++ mkAsserts ptMap ty name md) + | some ty => + let assert := (constraintCallFor ptMap ty ⟨.Identifier name, md⟩ md).toList.map + fun c => ⟨.Assert c, md⟩ + pure ([stmt] ++ assert) | none => pure [stmt] | _ => pure [stmt] @@ -206,15 +201,11 @@ decreasing_by all_goals omega def elimProc (ptMap : ConstrainedTypeMap) (proc : Procedure) : Procedure := - -- Add requires for constrained-typed inputs - let inputRequires := proc.inputs.flatMap fun p => - (getAllConstraints ptMap p.type.val).map fun (vn, pred) => - ⟨(substId vn p.name pred).val, p.type.md⟩ - -- Add ensures for constrained-typed outputs (skip for isFunctional — not yet supported) - let outputEnsures := if proc.isFunctional then [] else proc.outputs.flatMap fun p => - (getAllConstraints ptMap p.type.val).map fun (vn, pred) => - ⟨(substId vn p.name pred).val, p.type.md⟩ - -- Transform body: insert asserts for local variable init/reassignment + let inputRequires := proc.inputs.filterMap fun p => + constraintCallFor ptMap p.type.val ⟨.Identifier p.name, p.type.md⟩ p.type.md + let outputEnsures := if proc.isFunctional then [] else proc.outputs.filterMap fun p => + (constraintCallFor ptMap p.type.val ⟨.Identifier p.name, p.type.md⟩ p.type.md).map + fun c => ⟨c.val, p.type.md⟩ let initVars : PredVarMap := proc.inputs.foldl (init := {}) fun s p => if isConstrainedType ptMap p.type.val then s.insert p.name.text p.type.val else s let body' := match proc.body with @@ -223,7 +214,6 @@ def elimProc (ptMap : ConstrainedTypeMap) (proc : Procedure) : Procedure := let body := wrap stmts bodyExpr.md if outputEnsures.isEmpty then .Transparent body else - -- Wrap expression body in a Return so it translates correctly as a procedure let retBody := if proc.isFunctional then ⟨.Return (some body), bodyExpr.md⟩ else body .Opaque outputEnsures (some retBody) [] | .Opaque postconds impl modif => @@ -231,7 +221,6 @@ def elimProc (ptMap : ConstrainedTypeMap) (proc : Procedure) : Procedure := .Opaque (postconds ++ outputEnsures) impl' modif | .Abstract postconds => .Abstract (postconds ++ outputEnsures) | .External => .External - -- Resolve all constrained types to base types let resolve := resolveExpr ptMap let resolveBody : Body → Body := fun body => match body with | .Transparent b => .Transparent (resolve b) @@ -244,40 +233,37 @@ def elimProc (ptMap : ConstrainedTypeMap) (proc : Procedure) : Procedure := outputs := proc.outputs.map fun p => { p with type := resolveType ptMap p.type } preconditions := (proc.preconditions ++ inputRequires).map resolve } -/-- Create a synthetic procedure that asserts the witness satisfies all constraints -/ private def mkWitnessProc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Procedure := let md := ct.witness.md let witnessId : Identifier := mkId "$witness" let witnessInit : StmtExprMd := - ⟨.LocalVariable witnessId (resolveType ptMap ct.base) (some ct.witness), md⟩ - let asserts := (getAllConstraints ptMap (.UserDefined ct.name)).map fun (vn, pred) => - ⟨.Assert (substId vn witnessId pred), md⟩ + ⟨.LocalVariable witnessId (resolveType ptMap ct.base) (some (clearIds ct.witness)), md⟩ + let assert : StmtExprMd := + ⟨.Assert (constraintCallFor ptMap (.UserDefined ct.name) ⟨.Identifier witnessId, md⟩ md).get!, md⟩ { name := mkId s!"$witness_{ct.name.text}" inputs := [] outputs := [] - body := .Transparent ⟨.Block ([witnessInit] ++ asserts) none, md⟩ + body := .Transparent ⟨.Block [witnessInit, assert] none, md⟩ preconditions := [] isFunctional := false determinism := .deterministic none decreases := none md := md } -/-- Eliminate constrained types from a Laurel program. - The `witness` field is used as the default initializer for uninitialized - constrained-typed variables, and is validated via synthetic procedures. -/ def constrainedTypeElim (_model : SemanticModel) (program : Program) : Program × Array DiagnosticModel := let ptMap := buildConstrainedTypeMap program.types if ptMap.isEmpty then (program, #[]) else - -- Report unsupported: isFunctional procedures with constrained return types + let constraintFuncs := program.types.filterMap fun + | .Constrained ct => some (mkConstraintFunc ptMap ct) | _ => none + let witnessProcedures := program.types.filterMap fun + | .Constrained ct => some (mkWitnessProc ptMap ct) | _ => none let funcDiags := program.staticProcedures.foldl (init := #[]) fun acc proc => if proc.isFunctional && proc.outputs.any (fun p => isConstrainedType ptMap p.type.val) then acc.push (proc.md.toDiagnostic "constrained return types on functions are not yet supported") else acc - let witnessProcedures := program.types.filterMap fun - | .Constrained ct => some (mkWitnessProc ptMap ct) - | _ => none ({ program with - staticProcedures := program.staticProcedures.map (elimProc ptMap) ++ witnessProcedures + staticProcedures := constraintFuncs ++ program.staticProcedures.map (elimProc ptMap) + ++ witnessProcedures types := program.types.filter fun | .Constrained _ => false | _ => true }, funcDiags) diff --git a/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean index 54841c473..accadce66 100644 --- a/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean +++ b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean @@ -42,15 +42,19 @@ def parseLaurelAndElim (input : String) : IO Program := do pure (constrainedTypeElim model program).1 /-- -info: procedure test(n: int) returns ⏎ +info: function nat$constraint(x: int) returns ⏎ +(result: bool) +deterministic +{ x >= 0 } +procedure test(n: int) returns ⏎ (r: int) -requires n >= 0 +requires nat$constraint(n) deterministic - ensures r >= 0 := { var y: int := n; assert y >= 0; return y } + ensures nat$constraint(r) := { var y: int := n; assert nat$constraint(y); return y } procedure $witness_nat() returns ⏎ () deterministic -{ var $witness: int := 0; assert $witness >= 0 } +{ var $witness: int := 0; assert nat$constraint($witness) } -/ #guard_msgs in #eval! do @@ -73,14 +77,18 @@ procedure test(b: bool) { " /-- -info: procedure test(b: bool) returns ⏎ +info: function pos$constraint(v: int) returns ⏎ +(result: bool) +deterministic +{ v > 0 } +procedure test(b: bool) returns ⏎ () deterministic -{ if b then { var x: int := 1; assert x > 0 }; { var x: int := -5; x := -10 } } +{ if b then { var x: int := 1; assert pos$constraint(x) }; { var x: int := -5; x := -10 } } procedure $witness_pos() returns ⏎ () deterministic -{ var $witness: int := 1; assert $witness > 0 } +{ var $witness: int := 1; assert pos$constraint($witness) } -/ #guard_msgs in #eval! do From 78244a36f410181afeb1e270174acbe44b2bbabe Mon Sep 17 00:00:00 2001 From: Fabio Madge Date: Wed, 11 Mar 2026 20:46:06 +0100 Subject: [PATCH 06/11] Replace substId with constraint functions Generate named constraint functions (e.g. nat$constraint) instead of inlining substituted constraint expressions. This eliminates substId (with its partial annotation and capture avoidance complexity) and produces clearer Core output with named function calls. Constraint functions are placed before user procedures so that resolution assigns their IDs before resolving references in user procedure bodies. Also changes posnat constraint to x != 0 per review suggestion. --- .../Languages/Laurel/ConstrainedTypeElim.lean | 40 +++++++------------ .../Fundamentals/T10_ConstrainedTypes.lean | 4 +- 2 files changed, 16 insertions(+), 28 deletions(-) diff --git a/Strata/Languages/Laurel/ConstrainedTypeElim.lean b/Strata/Languages/Laurel/ConstrainedTypeElim.lean index d45cc0b43..be15161fd 100644 --- a/Strata/Languages/Laurel/ConstrainedTypeElim.lean +++ b/Strata/Languages/Laurel/ConstrainedTypeElim.lean @@ -51,25 +51,13 @@ def isConstrainedType (ptMap : ConstrainedTypeMap) (ty : HighType) : Bool := /-- Build a call to the constraint function for a constrained type, or `none` if not constrained -/ def constraintCallFor (ptMap : ConstrainedTypeMap) (ty : HighType) - (arg : StmtExprMd) (md : Imperative.MetaData Core.Expression) : Option StmtExprMd := + (varName : Identifier) (md : Imperative.MetaData Core.Expression) : Option StmtExprMd := match ty with | .UserDefined name => if ptMap.contains name.text then - some ⟨.StaticCall (mkId s!"{name.text}$constraint") [arg], md⟩ + some ⟨.StaticCall (mkId s!"{name.text}$constraint") [⟨.Identifier varName, md⟩], md⟩ else none | _ => none -/-- Clear all uniqueIds from identifiers in an expression so resolution can re-assign them -/ -partial def clearIds : StmtExprMd → StmtExprMd - | ⟨.Identifier n, md⟩ => ⟨.Identifier { n with uniqueId := none }, md⟩ - | ⟨.PrimitiveOp op args, md⟩ => ⟨.PrimitiveOp op (args.map clearIds), md⟩ - | ⟨.StaticCall c args, md⟩ => ⟨.StaticCall { c with uniqueId := none } (args.map clearIds), md⟩ - | ⟨.Block ss sep, md⟩ => ⟨.Block (ss.map clearIds) sep, md⟩ - | ⟨.IfThenElse c t (some e), md⟩ => ⟨.IfThenElse (clearIds c) (clearIds t) (some (clearIds e)), md⟩ - | ⟨.IfThenElse c t none, md⟩ => ⟨.IfThenElse (clearIds c) (clearIds t) none, md⟩ - | ⟨.Forall p body, md⟩ => ⟨.Forall { p with name := { p.name with uniqueId := none } } (clearIds body), md⟩ - | ⟨.Exists p body, md⟩ => ⟨.Exists { p with name := { p.name with uniqueId := none } } (clearIds body), md⟩ - | e => e - /-- Generate a constraint function for a constrained type. For nested types, the function calls the parent's constraint function. -/ def mkConstraintFunc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Procedure := @@ -80,11 +68,11 @@ def mkConstraintFunc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Proce if ptMap.contains parent.text then let parentCall : StmtExprMd := ⟨.StaticCall (mkId s!"{parent.text}$constraint") [⟨.Identifier { ct.valueName with uniqueId := none }, md⟩], md⟩ - ⟨.PrimitiveOp .And [clearIds ct.constraint, parentCall], md⟩ - else clearIds ct.constraint - | _ => clearIds ct.constraint + ⟨.PrimitiveOp .And [ct.constraint, parentCall], md⟩ + else ct.constraint + | _ => ct.constraint { name := mkId s!"{ct.name.text}$constraint" - inputs := [{ name := { ct.valueName with uniqueId := none }, type := { baseType with md := #[] } }] + inputs := [{ name := ct.valueName, type := { baseType with md := #[] } }] outputs := [{ name := mkId "result", type := ⟨.TBool, #[]⟩ }] body := .Transparent ⟨.Block [bodyExpr] none, #[]⟩ isFunctional := true @@ -107,14 +95,14 @@ def resolveExpr (ptMap : ConstrainedTypeMap) : StmtExprMd → StmtExprMd | ⟨.Forall param body, md⟩ => let body' := resolveExpr ptMap body let param' := { param with type := resolveType ptMap param.type } - let injected := match constraintCallFor ptMap param.type.val ⟨.Identifier param.name, md⟩ md with + let injected := match constraintCallFor ptMap param.type.val param.name md with | some c => ⟨.PrimitiveOp .Implies [c, body'], md⟩ | none => body' ⟨.Forall param' injected, md⟩ | ⟨.Exists param body, md⟩ => let body' := resolveExpr ptMap body let param' := { param with type := resolveType ptMap param.type } - let injected := match constraintCallFor ptMap param.type.val ⟨.Identifier param.name, md⟩ md with + let injected := match constraintCallFor ptMap param.type.val param.name md with | some c => ⟨.PrimitiveOp .And [c, body'], md⟩ | none => body' ⟨.Exists param' injected, md⟩ @@ -156,7 +144,7 @@ def elimStmt (ptMap : ConstrainedTypeMap) let md := stmt.md match _h : stmt.val with | .LocalVariable name ty init => - let callOpt := constraintCallFor ptMap ty.val ⟨.Identifier name, md⟩ md + let callOpt := constraintCallFor ptMap ty.val name md if callOpt.isSome then modify fun pv => pv.insert name.text ty.val let assert := callOpt.toList.map fun c => ⟨.Assert c, md⟩ let init' := match init with @@ -170,7 +158,7 @@ def elimStmt (ptMap : ConstrainedTypeMap) | .Identifier name => do match (← get).get? name.text with | some ty => - let assert := (constraintCallFor ptMap ty ⟨.Identifier name, md⟩ md).toList.map + let assert := (constraintCallFor ptMap ty name md).toList.map fun c => ⟨.Assert c, md⟩ pure ([stmt] ++ assert) | none => pure [stmt] @@ -202,9 +190,9 @@ decreasing_by def elimProc (ptMap : ConstrainedTypeMap) (proc : Procedure) : Procedure := let inputRequires := proc.inputs.filterMap fun p => - constraintCallFor ptMap p.type.val ⟨.Identifier p.name, p.type.md⟩ p.type.md + constraintCallFor ptMap p.type.val p.name p.type.md let outputEnsures := if proc.isFunctional then [] else proc.outputs.filterMap fun p => - (constraintCallFor ptMap p.type.val ⟨.Identifier p.name, p.type.md⟩ p.type.md).map + (constraintCallFor ptMap p.type.val p.name p.type.md).map fun c => ⟨c.val, p.type.md⟩ let initVars : PredVarMap := proc.inputs.foldl (init := {}) fun s p => if isConstrainedType ptMap p.type.val then s.insert p.name.text p.type.val else s @@ -237,9 +225,9 @@ private def mkWitnessProc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : let md := ct.witness.md let witnessId : Identifier := mkId "$witness" let witnessInit : StmtExprMd := - ⟨.LocalVariable witnessId (resolveType ptMap ct.base) (some (clearIds ct.witness)), md⟩ + ⟨.LocalVariable witnessId (resolveType ptMap ct.base) (some ct.witness), md⟩ let assert : StmtExprMd := - ⟨.Assert (constraintCallFor ptMap (.UserDefined ct.name) ⟨.Identifier witnessId, md⟩ md).get!, md⟩ + ⟨.Assert (constraintCallFor ptMap (.UserDefined ct.name) witnessId md).get!, md⟩ { name := mkId s!"$witness_{ct.name.text}" inputs := [] outputs := [] diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean index a8c36fe87..8db7b12ef 100644 --- a/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean @@ -14,7 +14,7 @@ namespace Laurel def program := r" constrained nat = x: int where x >= 0 witness 0 -constrained posnat = x: nat where x > 0 witness 1 +constrained posnat = x: nat where x != 0 witness 1 // Input constraint becomes requires — body can rely on it procedure inputAssumed(n: nat) { @@ -112,7 +112,7 @@ procedure uninitNat() { // Uninitialized nested constrained variable — outermost witness used procedure uninitPosnat() { var y: posnat; - assert y > 0; + assert y != 0; assert y >= 0; }; From 860eaca0248f202855ccf8a8394b47aa50691ddd Mon Sep 17 00:00:00 2001 From: Fabio Madge Date: Wed, 11 Mar 2026 21:27:00 +0100 Subject: [PATCH 07/11] Add output variable assert test per review suggestion --- StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean index accadce66..ca870f21c 100644 --- a/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean +++ b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean @@ -24,6 +24,7 @@ namespace Strata.Laurel def testProgram : String := r" constrained nat = x: int where x >= 0 witness 0 procedure test(n: nat) returns (r: nat) { + assert r >= 0; var y: nat := n; return y; }; @@ -50,7 +51,7 @@ procedure test(n: int) returns ⏎ (r: int) requires nat$constraint(n) deterministic - ensures nat$constraint(r) := { var y: int := n; assert nat$constraint(y); return y } + ensures nat$constraint(r) := { assert r >= 0; var y: int := n; assert nat$constraint(y); return y } procedure $witness_nat() returns ⏎ () deterministic From 9f9d86ea65786620f11753c23e1132c5286857aa Mon Sep 17 00:00:00 2001 From: Fabio Madge Date: Wed, 11 Mar 2026 22:06:28 +0100 Subject: [PATCH 08/11] Add TODO for witness injection follow-up --- Strata/Languages/Laurel/ConstrainedTypeElim.lean | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Strata/Languages/Laurel/ConstrainedTypeElim.lean b/Strata/Languages/Laurel/ConstrainedTypeElim.lean index be15161fd..02381bb44 100644 --- a/Strata/Languages/Laurel/ConstrainedTypeElim.lean +++ b/Strata/Languages/Laurel/ConstrainedTypeElim.lean @@ -147,6 +147,9 @@ def elimStmt (ptMap : ConstrainedTypeMap) let callOpt := constraintCallFor ptMap ty.val name md if callOpt.isSome then modify fun pv => pv.insert name.text ty.val let assert := callOpt.toList.map fun c => ⟨.Assert c, md⟩ + -- TODO: Once the translator emits `init` without RHS (havoc) for uninitialized variables, + -- switch from witness injection + assert to assume. Currently the translator initializes + -- uninitialized variables to 0 (defaultExprForType), making assume unsound. let init' := match init with | none => match ty.val with | .UserDefined n => (ptMap.get? n.text).map (·.witness) From 673995aebcf322020134218ffb679ff2189f7cc4 Mon Sep 17 00:00:00 2001 From: Fabio Madge Date: Wed, 11 Mar 2026 23:16:55 +0100 Subject: [PATCH 09/11] Add uninitialized constrained variable benchmark per review --- .../Laurel/ConstrainedTypeElimTest.lean | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean index ca870f21c..86b243eb3 100644 --- a/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean +++ b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean @@ -97,4 +97,35 @@ deterministic for proc in program.staticProcedures do IO.println (toString (Std.Format.pretty (Std.ToFormat.format proc))) +-- Uninitialized constrained variable: currently uses witness as default. +-- TODO: Once the translator emits havoc for uninitialized variables (#550), +-- switch to assume instead of witness injection + assert. +def uninitProgram : String := r" +constrained posint = x: int where x > 0 witness 1 +procedure f() { + var x: posint; + assert x == 1; +}; +" + +/-- +info: function posint$constraint(x: int) returns ⏎ +(result: bool) +deterministic +{ x > 0 } +procedure f() returns ⏎ +() +deterministic +{ var x: int := 1; assert posint$constraint(x); assert x == 1 } +procedure $witness_posint() returns ⏎ +() +deterministic +{ var $witness: int := 1; assert posint$constraint($witness) } +-/ +#guard_msgs in +#eval! do + let program ← parseLaurelAndElim uninitProgram + for proc in program.staticProcedures do + IO.println (toString (Std.Format.pretty (Std.ToFormat.format proc))) + end Laurel From 7526c9fa61a43d2dfb5c351b04306bd7c29561d6 Mon Sep 17 00:00:00 2001 From: Fabio Madge Date: Thu, 12 Mar 2026 02:27:20 +0100 Subject: [PATCH 10/11] Replace witness injection with havoc + assume for uninitialized constrained variables MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Include translator fix: emit init without RHS (havoc) for uninitialized variables instead of defaultExprForType. For uninitialized constrained variables, emit assume instead of assert: var x: posint; → var x: int; assume posint$constraint(x); The witness is now only used in witness validation procedures. --- .../Languages/Laurel/ConstrainedTypeElim.lean | 18 +++++++----------- .../Laurel/ConstrainedTypeElimTest.lean | 9 ++++----- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/Strata/Languages/Laurel/ConstrainedTypeElim.lean b/Strata/Languages/Laurel/ConstrainedTypeElim.lean index 02381bb44..86751ccf3 100644 --- a/Strata/Languages/Laurel/ConstrainedTypeElim.lean +++ b/Strata/Languages/Laurel/ConstrainedTypeElim.lean @@ -17,7 +17,7 @@ A Laurel-to-Laurel pass that eliminates constrained types by: - Skipped for `isFunctional` procedures since the Laurel translator does not yet support function postconditions. Constrained return types on functions are not checked. 4. Inserting `assert constraintFunc(var)` for local variable init and reassignment -5. Using the witness as default initializer for uninitialized constrained-typed variables +5. Assuming the constraint for uninitialized constrained-typed variables (havoc + assume) 6. Adding a synthetic witness-validation procedure per constrained type 7. Injecting constraint function calls into quantifier bodies (`forall` → `implies`, `exists` → `and`) 8. Resolving all constrained type references to their base types @@ -146,16 +146,12 @@ def elimStmt (ptMap : ConstrainedTypeMap) | .LocalVariable name ty init => let callOpt := constraintCallFor ptMap ty.val name md if callOpt.isSome then modify fun pv => pv.insert name.text ty.val - let assert := callOpt.toList.map fun c => ⟨.Assert c, md⟩ - -- TODO: Once the translator emits `init` without RHS (havoc) for uninitialized variables, - -- switch from witness injection + assert to assume. Currently the translator initializes - -- uninitialized variables to 0 (defaultExprForType), making assume unsound. - let init' := match init with - | none => match ty.val with - | .UserDefined n => (ptMap.get? n.text).map (·.witness) - | _ => none - | some _ => init - pure ([⟨.LocalVariable name ty init', md⟩] ++ assert) + let (init', check) : Option StmtExprMd × List StmtExprMd := match init with + | none => match callOpt with + | some c => (none, [⟨.Assume c, md⟩]) + | none => (none, []) + | some _ => (init, callOpt.toList.map fun c => ⟨.Assert c, md⟩) + pure ([⟨.LocalVariable name ty init', md⟩] ++ check) | .Assign [target] _ => match target.val with | .Identifier name => do diff --git a/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean index 86b243eb3..5b86f46ca 100644 --- a/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean +++ b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean @@ -97,14 +97,13 @@ deterministic for proc in program.staticProcedures do IO.println (toString (Std.Format.pretty (Std.ToFormat.format proc))) --- Uninitialized constrained variable: currently uses witness as default. --- TODO: Once the translator emits havoc for uninitialized variables (#550), --- switch to assume instead of witness injection + assert. +-- Uninitialized constrained variable: havoc + assume constraint. +-- The variable has no known value, only the type constraint is assumed. def uninitProgram : String := r" constrained posint = x: int where x > 0 witness 1 procedure f() { var x: posint; - assert x == 1; + assert x > 0; }; " @@ -116,7 +115,7 @@ deterministic procedure f() returns ⏎ () deterministic -{ var x: int := 1; assert posint$constraint(x); assert x == 1 } +{ var x: int; assume posint$constraint(x); assert x > 0 } procedure $witness_posint() returns ⏎ () deterministic From f75d8b806d79a04df5408cafd4ddad50cfbf244b Mon Sep 17 00:00:00 2001 From: Fabio Madge Date: Thu, 12 Mar 2026 14:37:20 +0100 Subject: [PATCH 11/11] Keep assert x == 1 in uninit test; add verification test that witness is not provable --- .../Languages/Laurel/ConstrainedTypeElimTest.lean | 4 ++-- .../Examples/Fundamentals/T10_ConstrainedTypes.lean | 11 +++++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean index 5b86f46ca..a3a686b4a 100644 --- a/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean +++ b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean @@ -103,7 +103,7 @@ def uninitProgram : String := r" constrained posint = x: int where x > 0 witness 1 procedure f() { var x: posint; - assert x > 0; + assert x == 1; }; " @@ -115,7 +115,7 @@ deterministic procedure f() returns ⏎ () deterministic -{ var x: int; assume posint$constraint(x); assert x > 0 } +{ var x: int; assume posint$constraint(x); assert x == 1 } procedure $witness_posint() returns ⏎ () deterministic diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean index 8db7b12ef..2ae7e692e 100644 --- a/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean @@ -103,19 +103,26 @@ procedure constrainedInExpr() { constrained bad = x: int where x > 0 witness -1 // ^^ error: assertion does not hold -// Uninitialized constrained variable — witness used as default +// Uninitialized constrained variable — havoc + assume constraint procedure uninitNat() { var y: nat; assert y >= 0; }; -// Uninitialized nested constrained variable — outermost witness used +// Uninitialized nested constrained variable — havoc + assume constraint procedure uninitPosnat() { var y: posnat; assert y != 0; assert y >= 0; }; +// Uninitialized constrained variable — witness value is not provable +procedure uninitNotWitness() { + var y: posnat; + assert y == 1; +//^^^^^^^^^^^^^^ error: assertion does not hold +}; + // Function with valid constrained return — constraint not checked (not yet supported) function goodFunc(): nat { 3 }; // ^^^^^^^^ error: constrained return types on functions are not yet supported