Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 257 additions & 0 deletions Strata/Languages/Laurel/ConstrainedTypeElim.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
/-
Copyright Strata Contributors

SPDX-License-Identifier: Apache-2.0 OR MIT
-/

import Strata.Languages.Laurel.Laurel
import Strata.Languages.Laurel.Resolution

/-!
# Constrained Type Elimination

A Laurel-to-Laurel pass that eliminates constrained types by:
1. Generating a constraint function per constrained type (e.g. `nat$constraint(x: int): bool`)
2. Adding `requires constraintFunc(param)` for constrained-typed inputs
3. Adding `ensures constraintFunc(result)` for constrained-typed outputs
- Skipped for `isFunctional` procedures since the Laurel translator does not yet support
function postconditions. Constrained return types on functions are not checked.
4. Inserting `assert constraintFunc(var)` for local variable init and reassignment
5. Assuming the constraint for uninitialized constrained-typed variables (havoc + assume)
6. Adding a synthetic witness-validation procedure per constrained type
7. Injecting constraint function calls into quantifier bodies (`forall` → `implies`, `exists` → `and`)
8. Resolving all constrained type references to their base types
-/

namespace Strata.Laurel

open Strata

abbrev ConstrainedTypeMap := Std.HashMap String ConstrainedType
/-- Map from variable name to its constrained HighType (e.g. UserDefined "nat") -/
abbrev PredVarMap := Std.HashMap String HighType

def buildConstrainedTypeMap (types : List TypeDefinition) : ConstrainedTypeMap :=
types.foldl (init := {}) fun m td =>
match td with | .Constrained ct => m.insert ct.name.text ct | _ => m

partial def resolveBaseType (ptMap : ConstrainedTypeMap) (ty : HighType) : HighType :=
match ty with
| .UserDefined name => match ptMap.get? name.text with
| some ct => resolveBaseType ptMap ct.base.val | none => ty
| .Applied ctor args =>
.Applied ctor (args.map fun a => ⟨resolveBaseType ptMap a.val, a.md⟩)
| _ => ty

def resolveType (ptMap : ConstrainedTypeMap) (ty : HighTypeMd) : HighTypeMd :=
⟨resolveBaseType ptMap ty.val, ty.md⟩

def isConstrainedType (ptMap : ConstrainedTypeMap) (ty : HighType) : Bool :=
match ty with | .UserDefined name => ptMap.contains name.text | _ => false

/-- Build a call to the constraint function for a constrained type, or `none` if not constrained -/
def constraintCallFor (ptMap : ConstrainedTypeMap) (ty : HighType)
(varName : Identifier) (md : Imperative.MetaData Core.Expression) : Option StmtExprMd :=
match ty with
| .UserDefined name => if ptMap.contains name.text then
some ⟨.StaticCall (mkId s!"{name.text}$constraint") [⟨.Identifier varName, md⟩], md⟩
else none
| _ => none

/-- Generate a constraint function for a constrained type.
For nested types, the function calls the parent's constraint function. -/
def mkConstraintFunc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Procedure :=
let md : Imperative.MetaData Core.Expression := #[]
let baseType := resolveType ptMap ct.base
let bodyExpr := match ct.base.val with
| .UserDefined parent =>
if ptMap.contains parent.text then
let parentCall : StmtExprMd :=
⟨.StaticCall (mkId s!"{parent.text}$constraint") [⟨.Identifier { ct.valueName with uniqueId := none }, md⟩], md⟩
⟨.PrimitiveOp .And [ct.constraint, parentCall], md⟩
else ct.constraint
| _ => ct.constraint
{ name := mkId s!"{ct.name.text}$constraint"
inputs := [{ name := ct.valueName, type := { baseType with md := #[] } }]
outputs := [{ name := mkId "result", type := ⟨.TBool, #[]⟩ }]
body := .Transparent ⟨.Block [bodyExpr] none, #[]⟩
isFunctional := true
determinism := .deterministic none
decreases := none
preconditions := []
md := #[] }

private def wrap (stmts : List StmtExprMd) (md : Imperative.MetaData Core.Expression)
: StmtExprMd :=
match stmts with | [s] => s | ss => ⟨.Block ss none, md⟩

/-- Resolve constrained types in all type positions of an expression,
and inject constraint function calls into quantifier bodies -/
def resolveExpr (ptMap : ConstrainedTypeMap) : StmtExprMd → StmtExprMd
| ⟨.LocalVariable n ty (some init), md⟩ =>
⟨.LocalVariable n (resolveType ptMap ty) (some (resolveExpr ptMap init)), md⟩
| ⟨.LocalVariable n ty none, md⟩ =>
⟨.LocalVariable n (resolveType ptMap ty) none, md⟩
| ⟨.Forall param body, md⟩ =>
let body' := resolveExpr ptMap body
let param' := { param with type := resolveType ptMap param.type }
let injected := match constraintCallFor ptMap param.type.val param.name md with
| some c => ⟨.PrimitiveOp .Implies [c, body'], md⟩
| none => body'
⟨.Forall param' injected, md⟩
| ⟨.Exists param body, md⟩ =>
let body' := resolveExpr ptMap body
let param' := { param with type := resolveType ptMap param.type }
let injected := match constraintCallFor ptMap param.type.val param.name md with
| some c => ⟨.PrimitiveOp .And [c, body'], md⟩
| none => body'
⟨.Exists param' injected, md⟩
| ⟨.AsType t ty, md⟩ => ⟨.AsType (resolveExpr ptMap t) (resolveType ptMap ty), md⟩
| ⟨.IsType t ty, md⟩ => ⟨.IsType (resolveExpr ptMap t) (resolveType ptMap ty), md⟩
| ⟨.PrimitiveOp op args, md⟩ =>
⟨.PrimitiveOp op (args.attach.map fun ⟨a, _⟩ => resolveExpr ptMap a), md⟩
| ⟨.StaticCall c args, md⟩ =>
⟨.StaticCall c (args.attach.map fun ⟨a, _⟩ => resolveExpr ptMap a), md⟩
| ⟨.Block ss sep, md⟩ =>
⟨.Block (ss.attach.map fun ⟨s, _⟩ => resolveExpr ptMap s) sep, md⟩
| ⟨.IfThenElse c t (some el), md⟩ =>
⟨.IfThenElse (resolveExpr ptMap c) (resolveExpr ptMap t) (some (resolveExpr ptMap el)), md⟩
| ⟨.IfThenElse c t none, md⟩ =>
⟨.IfThenElse (resolveExpr ptMap c) (resolveExpr ptMap t) none, md⟩
| ⟨.While c inv dec body, md⟩ =>
⟨.While (resolveExpr ptMap c) (inv.attach.map fun ⟨i, _⟩ => resolveExpr ptMap i)
dec (resolveExpr ptMap body), md⟩
| ⟨.Assign ts v, md⟩ =>
⟨.Assign (ts.attach.map fun ⟨t, _⟩ => resolveExpr ptMap t) (resolveExpr ptMap v), md⟩
| ⟨.Return (some v), md⟩ => ⟨.Return (some (resolveExpr ptMap v)), md⟩
| ⟨.Return none, md⟩ => ⟨.Return none, md⟩
| ⟨.Assert c, md⟩ => ⟨.Assert (resolveExpr ptMap c), md⟩
| ⟨.Assume c, md⟩ => ⟨.Assume (resolveExpr ptMap c), md⟩
| e => e
termination_by e => sizeOf e
decreasing_by all_goals (have := WithMetadata.sizeOf_val_lt ‹_›; term_by_mem)

abbrev ElimM := StateM PredVarMap

private def inScope (action : ElimM α) : ElimM α := do
let saved ← get
let result ← action
set saved
return result

def elimStmt (ptMap : ConstrainedTypeMap)
(stmt : StmtExprMd) : ElimM (List StmtExprMd) := do
let md := stmt.md
match _h : stmt.val with
| .LocalVariable name ty init =>
let callOpt := constraintCallFor ptMap ty.val name md
if callOpt.isSome then modify fun pv => pv.insert name.text ty.val
let (init', check) : Option StmtExprMd × List StmtExprMd := match init with
| none => match callOpt with
| some c => (none, [⟨.Assume c, md⟩])
| none => (none, [])
| some _ => (init, callOpt.toList.map fun c => ⟨.Assert c, md⟩)
pure ([⟨.LocalVariable name ty init', md⟩] ++ check)

| .Assign [target] _ => match target.val with
| .Identifier name => do
match (← get).get? name.text with
| some ty =>
let assert := (constraintCallFor ptMap ty name md).toList.map
fun c => ⟨.Assert c, md⟩
pure ([stmt] ++ assert)
| none => pure [stmt]
| _ => pure [stmt]

| .Block stmts sep =>
let stmtss ← inScope (stmts.mapM (elimStmt ptMap))
pure [⟨.Block stmtss.flatten sep, md⟩]

| .IfThenElse cond thenBr (some elseBr) =>
let thenSs ← inScope (elimStmt ptMap thenBr)
let elseSs ← inScope (elimStmt ptMap elseBr)
pure [⟨.IfThenElse cond (wrap thenSs md) (some (wrap elseSs md)), md⟩]
| .IfThenElse cond thenBr none =>
let thenSs ← inScope (elimStmt ptMap thenBr)
pure [⟨.IfThenElse cond (wrap thenSs md) none, md⟩]

| .While cond inv dec body =>
let bodySs ← inScope (elimStmt ptMap body)
pure [⟨.While cond inv dec (wrap bodySs md), md⟩]

| _ => pure [stmt]
termination_by sizeOf stmt
decreasing_by
all_goals simp_wf
all_goals (try have := WithMetadata.sizeOf_val_lt stmt)
all_goals (try term_by_mem)
all_goals omega

def elimProc (ptMap : ConstrainedTypeMap) (proc : Procedure) : Procedure :=
let inputRequires := proc.inputs.filterMap fun p =>
constraintCallFor ptMap p.type.val p.name p.type.md
let outputEnsures := if proc.isFunctional then [] else proc.outputs.filterMap fun p =>
(constraintCallFor ptMap p.type.val p.name p.type.md).map
fun c => ⟨c.val, p.type.md⟩
let initVars : PredVarMap := proc.inputs.foldl (init := {}) fun s p =>
if isConstrainedType ptMap p.type.val then s.insert p.name.text p.type.val else s
let body' := match proc.body with
| .Transparent bodyExpr =>
let (stmts, _) := (elimStmt ptMap bodyExpr).run initVars
let body := wrap stmts bodyExpr.md
if outputEnsures.isEmpty then .Transparent body
else
let retBody := if proc.isFunctional then ⟨.Return (some body), bodyExpr.md⟩ else body
.Opaque outputEnsures (some retBody) []
| .Opaque postconds impl modif =>
let impl' := impl.map fun b => wrap ((elimStmt ptMap b).run initVars).1 b.md
.Opaque (postconds ++ outputEnsures) impl' modif
| .Abstract postconds => .Abstract (postconds ++ outputEnsures)
| .External => .External
let resolve := resolveExpr ptMap
let resolveBody : Body → Body := fun body => match body with
| .Transparent b => .Transparent (resolve b)
| .Opaque ps impl modif => .Opaque (ps.map resolve) (impl.map resolve) (modif.map resolve)
| .Abstract ps => .Abstract (ps.map resolve)
| .External => .External
{ proc with
body := resolveBody body'
inputs := proc.inputs.map fun p => { p with type := resolveType ptMap p.type }
outputs := proc.outputs.map fun p => { p with type := resolveType ptMap p.type }
preconditions := (proc.preconditions ++ inputRequires).map resolve }

private def mkWitnessProc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Procedure :=
let md := ct.witness.md
let witnessId : Identifier := mkId "$witness"
let witnessInit : StmtExprMd :=
⟨.LocalVariable witnessId (resolveType ptMap ct.base) (some ct.witness), md⟩
let assert : StmtExprMd :=
⟨.Assert (constraintCallFor ptMap (.UserDefined ct.name) witnessId md).get!, md⟩
{ name := mkId s!"$witness_{ct.name.text}"
inputs := []
outputs := []
body := .Transparent ⟨.Block [witnessInit, assert] none, md⟩
preconditions := []
isFunctional := false
determinism := .deterministic none
decreases := none
md := md }

def constrainedTypeElim (_model : SemanticModel) (program : Program) : Program × Array DiagnosticModel :=
let ptMap := buildConstrainedTypeMap program.types
if ptMap.isEmpty then (program, #[]) else
let constraintFuncs := program.types.filterMap fun
| .Constrained ct => some (mkConstraintFunc ptMap ct) | _ => none
let witnessProcedures := program.types.filterMap fun
| .Constrained ct => some (mkWitnessProc ptMap ct) | _ => none
let funcDiags := program.staticProcedures.foldl (init := #[]) fun acc proc =>
if proc.isFunctional && proc.outputs.any (fun p => isConstrainedType ptMap p.type.val) then
acc.push (proc.md.toDiagnostic "constrained return types on functions are not yet supported")
else acc
({ program with
staticProcedures := constraintFuncs ++ program.staticProcedures.map (elimProc ptMap)
++ witnessProcedures
types := program.types.filter fun | .Constrained _ => false | _ => true },
funcDiags)

end Strata.Laurel
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,20 @@ def parseDatatype (arg : Arg) : TransM TypeDefinition := do
| _, _ =>
TransM.error s!"parseDatatype expects datatype, got {repr op.name}"

def parseConstrainedType (arg : Arg) : TransM ConstrainedType := do
let .op op := arg
| TransM.error s!"parseConstrainedType expects operation"
match op.name, op.args with
| q`Laurel.constrainedType, #[nameArg, valueNameArg, baseArg, constraintArg, witnessArg] =>
let name ← translateIdent nameArg
let valueName ← translateIdent valueNameArg
let base ← translateHighType baseArg
let constraint ← translateStmtExpr constraintArg
let witness ← translateStmtExpr witnessArg
return { name, base, valueName, constraint, witness }
| _, _ =>
TransM.error s!"parseConstrainedType expects constrainedType, got {repr op.name}"

def parseTopLevel (arg : Arg) : TransM (Option Procedure × Option TypeDefinition) := do
let .op op := arg
| TransM.error s!"parseTopLevel expects operation"
Expand All @@ -509,8 +523,11 @@ def parseTopLevel (arg : Arg) : TransM (Option Procedure × Option TypeDefinitio
| q`Laurel.topLevelDatatype, #[datatypeArg] =>
let typeDef ← parseDatatype datatypeArg
return (none, some typeDef)
| q`Laurel.topLevelConstrainedType, #[ctArg] =>
let ct ← parseConstrainedType ctArg
return (none, some (.Constrained ct))
| _, _ =>
TransM.error s!"parseTopLevel expects topLevelProcedure, topLevelComposite, or topLevelDatatype, got {repr op.name}"
TransM.error s!"parseTopLevel expects topLevelProcedure, topLevelComposite, topLevelDatatype, or topLevelConstrainedType, got {repr op.name}"

/--
Translate concrete Laurel syntax into abstract Laurel syntax
Expand Down
1 change: 0 additions & 1 deletion Strata/Languages/Laurel/Grammar/LaurelGrammar.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
-- Laurel dialect definition, loaded from LaurelGrammar.st
-- NOTE: Changes to LaurelGrammar.st are not automatically tracked by the build system.
-- Update this file (e.g. this comment) to trigger a recompile after modifying LaurelGrammar.st.
-- Last grammar change: require semicolon after procedure/function definitions.
import Strata.DDM.Integration.Lean

namespace Strata.Laurel
Expand Down
6 changes: 6 additions & 0 deletions Strata/Languages/Laurel/Grammar/LaurelGrammar.st
Original file line number Diff line number Diff line change
Expand Up @@ -166,4 +166,10 @@ op topLevelComposite(composite: Composite): TopLevel => composite;
op topLevelProcedure(procedure: Procedure): TopLevel => procedure;
op topLevelDatatype(datatype: Datatype): TopLevel => datatype;

category ConstrainedType;
op constrainedType (name: Ident, valueName: Ident, base: LaurelType,
constraint: StmtExpr, witness: StmtExpr): ConstrainedType
=> "constrained " name " = " valueName ": " base " where " constraint:0 " witness " witness:0;
op topLevelConstrainedType(ct: ConstrainedType): TopLevel => ct;

op program (items: Seq TopLevel): Command => items;
8 changes: 7 additions & 1 deletion Strata/Languages/Laurel/LaurelToCoreTranslator.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import Strata.DL.Imperative.Stmt
import Strata.DL.Imperative.MetaData
import Strata.DL.Lambda.LExpr
import Strata.Languages.Laurel.LaurelFormat
import Strata.Languages.Laurel.ConstrainedTypeElim
import Strata.Util.Tactics

open Core (VCResult VCResults VerifyOptions)
Expand Down Expand Up @@ -619,6 +620,11 @@ def translate (program : Program): Except (Array DiagnosticModel) (Core.Program
let (program, model) := (result.program, result.model)
_resolutionDiags := _resolutionDiags ++ result.errors

let (program, constrainedTypeDiags) := constrainedTypeElim model program
let result := resolve program (some model)
let (program, model) := (result.program, result.model)
_resolutionDiags := _resolutionDiags ++ result.errors

-- Procedures marked isFunctional are translated to Core functions; all others become Core procedures.
-- External procedures are completely ignored (not translated to Core).
let nonExternal := program.staticProcedures.filter (fun p => !p.body.isExternal)
Expand Down Expand Up @@ -667,7 +673,7 @@ def translate (program : Program): Except (Array DiagnosticModel) (Core.Program
-- dbg_trace "=== Generated Strata Core Program ==="
-- dbg_trace (toString (Std.Format.pretty (Strata.Core.formatProgram program) 100))
-- dbg_trace "================================="
pure (program, diamondErrors ++ modifiesDiags)
pure (program, diamondErrors ++ modifiesDiags ++ constrainedTypeDiags.toList)

/--
Verify a Laurel program using an SMT solver
Expand Down
Loading
Loading