diff --git a/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean b/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean index dbf3a8518..3bbab2c16 100644 --- a/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean +++ b/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean @@ -126,7 +126,7 @@ instance : Inhabited Procedure where name := "" inputs := [] outputs := [] - precondition := mkStmtExprMdEmpty <| .LiteralBool true + preconditions := [] determinism := .deterministic none decreases := none body := .Transparent ⟨.LiteralBool true, #[]⟩ @@ -353,12 +353,14 @@ def parseProcedure (arg : Arg) : TransM Procedure := do | .option _ none => pure [] | _ => TransM.error s!"Expected returnParameters operation, got {repr returnParamsArg}" | _ => TransM.error s!"Expected optionalReturnType operation, got {repr returnTypeArg}" - -- Parse precondition (requires clause) - let precondition ← match requiresArg with + -- Parse preconditions (requires clause) + let preconditions ← match requiresArg with | .option _ (some (.op requiresOp)) => match requiresOp.name, requiresOp.args with - | q`Laurel.optionalRequires, #[exprArg] => translateStmtExpr exprArg + | q`Laurel.optionalRequires, #[exprArg] => do + let precond ← translateStmtExpr exprArg + pure [precond] | _, _ => TransM.error s!"Expected optionalRequires operation, got {repr requiresOp.name}" - | .option _ none => pure (mkStmtExprMdEmpty <| .LiteralBool true) + | .option _ none => pure [] | _ => TransM.error s!"Expected optionalRequires operation, got {repr requiresArg}" -- Parse postconditions (ensures clauses - zero or more) let postconditions ← translateEnsuresClauses ensuresArg @@ -380,7 +382,7 @@ def parseProcedure (arg : Arg) : TransM Procedure := do name := name inputs := parameters outputs := returnParameters - precondition := precondition + preconditions := preconditions determinism := .deterministic none decreases := none body := procBody diff --git a/Strata/Languages/Laurel/HeapParameterization.lean b/Strata/Languages/Laurel/HeapParameterization.lean index ac5bb1897..ef8041ca9 100644 --- a/Strata/Languages/Laurel/HeapParameterization.lean +++ b/Strata/Languages/Laurel/HeapParameterization.lean @@ -111,9 +111,9 @@ def analyzeProc (proc : Procedure) : AnalysisResult := { readsHeapDirectly := r1.readsHeapDirectly || r2.readsHeapDirectly, writesHeapDirectly := r1.writesHeapDirectly || r2.writesHeapDirectly, callees := r1.callees ++ r2.callees } - | .Abstract postcond => (collectExprMd postcond).run {} |>.2 - -- Also analyze precondition - let precondResult := (collectExprMd proc.precondition).run {} |>.2 + | .Abstract postconds => (postconds.forM collectExprMd).run {} |>.2 + -- Also analyze preconditions + let precondResult := (proc.preconditions.forM collectExprMd).run {} |>.2 { readsHeapDirectly := bodyResult.readsHeapDirectly || precondResult.readsHeapDirectly, writesHeapDirectly := bodyResult.writesHeapDirectly || precondResult.writesHeapDirectly, callees := bodyResult.callees ++ precondResult.callees } @@ -395,8 +395,8 @@ def heapTransformProcedure (proc : Procedure) : TransformM Procedure := do let inputs' := heapInParam :: proc.inputs let outputs' := heapOutParam :: proc.outputs - -- Precondition uses $heap_in (the input state) - let precondition' ← heapTransformExpr heapInName initEnv proc.precondition + -- Preconditions use $heap_in (the input state) + let preconditions' ← proc.preconditions.mapM (heapTransformExpr heapInName initEnv) let bodyValueIsUsed := !proc.outputs.isEmpty let body' ← match proc.body with @@ -416,14 +416,14 @@ def heapTransformProcedure (proc : Procedure) : TransformM Procedure := do | none => pure none let modif' ← modif.mapM (heapTransformExpr heapName initEnv ·) pure (.Opaque postconds' impl' modif') - | .Abstract postcond => - let postcond' ← heapTransformExpr heapName initEnv postcond - pure (.Abstract postcond') + | .Abstract postconds => + let postconds' ← postconds.mapM (heapTransformExpr heapName initEnv ·) + pure (.Abstract postconds') return { proc with inputs := inputs', outputs := outputs', - precondition := precondition', + preconditions := preconditions', body := body' } else if readsHeap then @@ -431,7 +431,7 @@ def heapTransformProcedure (proc : Procedure) : TransformM Procedure := do let heapParam : Parameter := { name := heapName, type := ⟨.THeap, #[]⟩ } let inputs' := heapParam :: proc.inputs - let precondition' ← heapTransformExpr heapName initEnv proc.precondition + let preconditions' ← proc.preconditions.mapM (heapTransformExpr heapName initEnv) let body' ← match proc.body with | .Transparent bodyExpr => @@ -442,13 +442,13 @@ def heapTransformProcedure (proc : Procedure) : TransformM Procedure := do let impl' ← impl.mapM (heapTransformExpr heapName initEnv ·) let modif' ← modif.mapM (heapTransformExpr heapName initEnv ·) pure (.Opaque postconds' impl' modif') - | .Abstract postcond => - let postcond' ← heapTransformExpr heapName initEnv postcond - pure (.Abstract postcond') + | .Abstract postconds => + let postconds' ← postconds.mapM (heapTransformExpr heapName initEnv ·) + pure (.Abstract postconds') return { proc with inputs := inputs', - precondition := precondition', + preconditions := preconditions', body := body' } else diff --git a/Strata/Languages/Laurel/Laurel.lean b/Strata/Languages/Laurel/Laurel.lean index 839528f74..1c1b84f2c 100644 --- a/Strata/Languages/Laurel/Laurel.lean +++ b/Strata/Languages/Laurel/Laurel.lean @@ -68,6 +68,9 @@ inductive Operation : Type where | StrConcat deriving Repr +-- Explicit instance needed for deriving Repr in the mutual block +instance : Repr (Imperative.MetaData Core.Expression) := inferInstance + /-- A wrapper that pairs a value with source-level metadata such as source locations and annotations. All Laurel AST nodes are wrapped in @@ -79,6 +82,7 @@ structure WithMetadata (t : Type) : Type where val : t /-- Source-level metadata (locations, annotations). -/ md : Imperative.MetaData Core.Expression + deriving Repr /-- The type system for Laurel programs. @@ -119,6 +123,7 @@ inductive HighType : Type where /-- Temporary construct meant to aid the migration of Python->Core to Python->Laurel. Type "passed through" from Core. Intended to allow translations to Laurel to refer directly to Core. -/ | TCore (s: String) + deriving Repr mutual @@ -134,8 +139,8 @@ structure Procedure : Type where inputs : List Parameter /-- Output parameters with their types. Multiple outputs are supported. -/ outputs : List Parameter - /-- The precondition that callers must satisfy. -/ - precondition : WithMetadata StmtExpr + /-- The preconditions that callers must satisfy. -/ + preconditions : List (WithMetadata StmtExpr) /-- Whether the procedure is deterministic or nondeterministic. -/ determinism : Determinism /-- Optional termination measure for recursive procedures. -/ @@ -145,6 +150,15 @@ structure Procedure : Type where /-- Source-level metadata. -/ md : Imperative.MetaData Core.Expression +/-- +A typed parameter for a procedure. +-/ +structure Parameter where + /-- The parameter name. -/ + name : Identifier + /-- The parameter type. -/ + type : WithMetadata HighType + /-- Specifies whether a procedure is deterministic or nondeterministic. @@ -157,15 +171,6 @@ inductive Determinism where /-- A nondeterministic procedure. They can read from the heap but there is no benefit from specifying a reads clause. -/ | nondeterministic -/-- -A typed parameter for a procedure. --/ -structure Parameter where - /-- The parameter name. -/ - name : Identifier - /-- The parameter type. -/ - type : WithMetadata HighType - /-- The body of a procedure. A body can be transparent (with a visible implementation), opaque (with a postcondition and optional implementation), @@ -176,11 +181,11 @@ inductive Body where | Transparent (body : WithMetadata StmtExpr) /-- An opaque body with a postcondition, optional implementation, and modifies clause. Without an implementation the postcondition is assumed. -/ | Opaque - (postcondition : List (WithMetadata StmtExpr)) + (postconditions : List (WithMetadata StmtExpr)) (implementation : Option (WithMetadata StmtExpr)) (modifies : List (WithMetadata StmtExpr)) /-- An abstract body that must be overridden in extending types. A type containing any members with abstract bodies cannot be instantiated. -/ - | Abstract (postcondition : WithMetadata StmtExpr) + | Abstract (postconditions : List (WithMetadata StmtExpr)) /-- The unified statement-expression type for Laurel programs. diff --git a/Strata/Languages/Laurel/LaurelEval.lean b/Strata/Languages/Laurel/LaurelEval.lean index 44b6a8704..46f4e0dbc 100644 --- a/Strata/Languages/Laurel/LaurelEval.lean +++ b/Strata/Languages/Laurel/LaurelEval.lean @@ -226,8 +226,8 @@ partial def eval (expr : StmtExpr) : Eval TypedValue := withResult <| EvalResult.TypeError s!"Static invocation of {callee} with wrong return type" else pure transparantResult - | Body.Opaque (postcondition: StmtExpr) _ _ => panic! "not implemented: opaque body" - | Body.Abstract (postcondition: StmtExpr) => panic! "not implemented: opaque body" + | Body.Opaque _ _ _ => panic! "not implemented: opaque body" + | Body.Abstract _ => panic! "not implemented: abstract body" popStack pure result diff --git a/Strata/Languages/Laurel/LaurelFormat.lean b/Strata/Languages/Laurel/LaurelFormat.lean index e52670be2..6a86c0210 100644 --- a/Strata/Languages/Laurel/LaurelFormat.lean +++ b/Strata/Languages/Laurel/LaurelFormat.lean @@ -146,11 +146,6 @@ end def formatParameter (p : Parameter) : Format := Format.text p.name ++ ": " ++ formatHighType p.type -def formatDeterminism : Determinism → Format - | .deterministic none => "deterministic" - | .deterministic (some reads) => "deterministic reads " ++ formatStmtExpr reads - | .nondeterministic => "nondeterministic" - def formatBody : Body → Format | .Transparent body => formatStmtExpr body | .Opaque postconds impl modif => @@ -160,13 +155,21 @@ def formatBody : Body → Format match impl with | none => Format.nil | some e => " := " ++ formatStmtExpr e - | .Abstract post => "abstract ensures " ++ formatStmtExpr post + | .Abstract posts => "abstract" ++ Format.join (posts.map (fun p => " ensures " ++ formatStmtExpr p)) + +def formatDeterminism : Determinism → Format + | .deterministic none => "deterministic" + | .deterministic (some reads) => "deterministic reads " ++ formatStmtExpr reads + | .nondeterministic => "nondeterministic" + +instance : Std.ToFormat Determinism where + format := formatDeterminism def formatProcedure (proc : Procedure) : Format := "procedure " ++ Format.text proc.name ++ "(" ++ Format.joinSep (proc.inputs.map formatParameter) ", " ++ ") returns " ++ Format.line ++ "(" ++ Format.joinSep (proc.outputs.map formatParameter) ", " ++ ")" ++ Format.line ++ - "requires " ++ formatStmtExpr proc.precondition ++ Format.line ++ + Format.join (proc.preconditions.map (fun p => "requires " ++ formatStmtExpr p ++ Format.line)) ++ formatDeterminism proc.determinism ++ Format.line ++ formatBody proc.body @@ -222,9 +225,6 @@ instance : Std.ToFormat StmtExpr where instance : Std.ToFormat Parameter where format := formatParameter -instance : Std.ToFormat Determinism where - format := formatDeterminism - instance : Std.ToFormat Body where format := formatBody diff --git a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean index 634765c84..767acd06b 100644 --- a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean +++ b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean @@ -323,13 +323,13 @@ def translateProcedure (fieldNames : List Identifier) (funcNames : FunctionNames } let initEnv : TypeEnv := proc.inputs.map (fun p => (p.name, p.type)) ++ proc.outputs.map (fun p => (p.name, p.type)) - -- Translate precondition if it's not just LiteralBool true + -- Translate preconditions let preconditions : ListMap Core.CoreLabel Core.Procedure.Check := - match proc.precondition with - | ⟨ .LiteralBool true, _ ⟩ => [] - | precond => + let (_, result) := proc.preconditions.foldl (fun (i, acc) precond => + let label := if proc.preconditions.length == 1 then "requires" else s!"requires_{i}" let check : Core.Procedure.Check := { expr := translateExpr fieldNames initEnv precond, md := precond.md } - [("requires", check)] + (i + 1, acc ++ [(label, check)])) (0, []) + result -- Translate postconditions for Opaque bodies let postconditions : ListMap Core.CoreLabel Core.Procedure.Check := match proc.body with @@ -438,7 +438,7 @@ def canBeBoogieFunction (proc : Procedure) : Bool := match proc.body with | .Transparent bodyExpr => isPureExpr bodyExpr && - (match proc.precondition.val with | .LiteralBool true => true | _ => false) && + proc.preconditions.isEmpty && proc.outputs.length == 1 | _ => false diff --git a/Strata/Languages/Laurel/LiftExpressionAssignments.lean b/Strata/Languages/Laurel/LiftExpressionAssignments.lean index 3b9f1f375..cfa6d4486 100644 --- a/Strata/Languages/Laurel/LiftExpressionAssignments.lean +++ b/Strata/Languages/Laurel/LiftExpressionAssignments.lean @@ -261,9 +261,12 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do match val with | .LocalVariable name ty _ => addToEnv name ty | _ => pure () - -- Process all-but-last right to left using transformExprDiscarded - for nonLastStatement in stmts.dropLast.reverse.attach do - transformExprDiscarded nonLastStatement + -- Process all-but-last as statements and prepend them in order + let mut blockStmts : List StmtExprMd := [] + for nonLastStatement in stmts.dropLast.attach do + have := List.dropLast_subset stmts nonLastStatement.property + blockStmts := blockStmts ++ (← transformStmt nonLastStatement) + for s in blockStmts.reverse do addPrepend s -- Last element is the expression value transformExpr last @@ -289,31 +292,6 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do termination_by (sizeOf expr, 0) decreasing_by all_goals (simp_all; try term_by_mem) - have := List.dropLast_subset stmts - have stmtInStmts : nonLastStatement.val ∈ stmts := by grind - -- term_by_mem gets a type error here, so we do it manually - have xSize := List.sizeOf_lt_of_mem stmtInStmts - omega - -/-- -Transform an expression whose result value is discarded (e.g. non-last elements in a block). All side-effects in Laurel are represented as assignments, so we only need to lift assignments, anything else can be forgotten. --/ -def transformExprDiscarded (expr2 : StmtExprMd) : LiftM Unit := do - match _hExpr: expr2 with - | WithMetadata.mk val md => - match _h: val with - | .Assign targets value => - -- Transform value to process nested assignments (side-effect only), - -- but use original value for the prepended assignment (no substitutions needed). - let _ ← transformExpr value - liftAssignExpr targets value md - | _ => - let result ← transformExpr expr2 - addPrepend result - termination_by (sizeOf expr2, 1) - decreasing_by - simp_all; omega - rw [<- _hExpr]; omega /-- Process a statement, handling any assignments in its sub-expressions. diff --git a/Strata/Languages/Laurel/TypeHierarchy.lean b/Strata/Languages/Laurel/TypeHierarchy.lean index afc7eced2..6710edabf 100644 --- a/Strata/Languages/Laurel/TypeHierarchy.lean +++ b/Strata/Languages/Laurel/TypeHierarchy.lean @@ -194,7 +194,7 @@ def validateDiamondFieldAccesses (uri : Uri) (program : Program) : Array Diagnos | some implExpr => validateDiamondFieldAccessesForStmtExpr uri program.types env implExpr | none => [] postErrors ++ implErrors - | .Abstract postcond => validateDiamondFieldAccessesForStmtExpr uri program.types env postcond + | .Abstract postconds => postconds.foldl (fun acc p => acc ++ validateDiamondFieldAccessesForStmtExpr uri program.types env p) [] acc ++ bodyErrors) [] errors.toArray @@ -296,7 +296,7 @@ def rewriteTypeHierarchyExpr (exprMd : StmtExprMd) : THM StmtExprMd := decreasing_by all_goals (simp_all; try term_by_mem) def rewriteTypeHierarchyProcedure (proc : Procedure) : THM Procedure := do - let precondition' ← rewriteTypeHierarchyExpr proc.precondition + let preconditions' ← proc.preconditions.mapM rewriteTypeHierarchyExpr let body' ← match proc.body with | .Transparent b => pure (.Transparent (← rewriteTypeHierarchyExpr b)) | .Opaque postconds impl modif => @@ -306,8 +306,8 @@ def rewriteTypeHierarchyProcedure (proc : Procedure) : THM Procedure := do | none => pure none let modif' ← modif.mapM rewriteTypeHierarchyExpr pure (.Opaque postconds' impl' modif') - | .Abstract postcond => pure (.Abstract (← rewriteTypeHierarchyExpr postcond)) - return { proc with precondition := precondition', body := body' } + | .Abstract postconds => pure (.Abstract (← postconds.mapM rewriteTypeHierarchyExpr)) + return { proc with preconditions := preconditions', body := body' } /-- Type hierarchy transformation pass (Laurel → Laurel). diff --git a/Strata/Languages/Python/PythonToLaurel.lean b/Strata/Languages/Python/PythonToLaurel.lean index 32719f541..ced934f9e 100644 --- a/Strata/Languages/Python/PythonToLaurel.lean +++ b/Strata/Languages/Python/PythonToLaurel.lean @@ -559,7 +559,7 @@ def translateFunction (ctx : TranslationContext) (f : Python.stmt SourceRange) name := funcName inputs := inputs outputs := outputs - precondition := mkStmtExprMd (StmtExpr.LiteralBool true) + preconditions := [] determinism := .deterministic none -- TODO: need to set reads decreases := none body := Body.Transparent bodyBlock @@ -642,9 +642,9 @@ def pythonToLaurel (prelude: Core.Program) name := "__main__", inputs := [], outputs := [], - precondition := mkStmtExprMd (StmtExpr.LiteralBool true), + preconditions := [], + determinism := .deterministic none, --TODO: need to set reads decreases := none, - determinism := .deterministic none --TODO: need to set reads body := .Transparent bodyBlock md := default } diff --git a/Strata/Languages/Python/Specs/ToLaurel.lean b/Strata/Languages/Python/Specs/ToLaurel.lean index 686d7dd81..5c22f6e74 100644 --- a/Strata/Languages/Python/Specs/ToLaurel.lean +++ b/Strata/Languages/Python/Specs/ToLaurel.lean @@ -268,7 +268,7 @@ def funcDeclToLaurel (procName : String) (func : FunctionDecl) name := procName inputs := inputs.toList outputs := outputs - precondition := ⟨.LiteralBool true, .empty⟩ + preconditions := [] determinism := .nondeterministic decreases := none body := .Opaque [] none [] diff --git a/StrataTest/Languages/Laurel/LiftExpressionAssignmentsTest.lean b/StrataTest/Languages/Laurel/LiftExpressionAssignmentsTest.lean new file mode 100644 index 000000000..ea03e304a --- /dev/null +++ b/StrataTest/Languages/Laurel/LiftExpressionAssignmentsTest.lean @@ -0,0 +1,68 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +/- +Tests that the expression lifter correctly handles statement constructs +(heap-updating assignments) in non-last positions of block expressions, +by comparing the lifted Laurel against expected output. +-/ + +import Strata.DDM.Elab +import Strata.DDM.BuiltinDialects.Init +import Strata.Languages.Laurel.Grammar.LaurelGrammar +import Strata.Languages.Laurel.Grammar.ConcreteToAbstractTreeTranslator +import Strata.Languages.Laurel.LaurelToCoreTranslator + +open Strata +open Strata.Elab (parseStrataProgramFromDialect) + +namespace Strata.Laurel + +def blockStmtLiftingProgram : String := r" +composite Box { + var value: int +} + +procedure heapUpdateInBlockExpr(b: Box) +{ + var x: int := { b#value := b#value + 1; b#value }; + assert x == b#value; +} + +procedure assertInBlockExpr() +{ + var x: int := 0; + var y: int := { assert x == 0; x := 1; x }; + assert y == 1; +} +" + +def parseLaurelAndLift (input : String) : IO Program := do + let inputCtx := Strata.Parser.stringInputContext "test" input + let dialects := Strata.Elab.LoadedDialects.ofDialects! #[initDialect, Laurel] + let strataProgram ← parseStrataProgramFromDialect dialects Laurel.name inputCtx + let uri := Strata.Uri.file "test" + match Laurel.TransM.run uri (Laurel.parseProgram strataProgram) with + | .error e => throw (IO.userError s!"Translation errors: {e}") + | .ok program => pure (liftExpressionAssignments program) + +/-- +info: procedure heapUpdateInBlockExpr(b: Box) returns +() +deterministic +{ b#value := b#value + 1; var x: int := b#value; assert x == b#value } +procedure assertInBlockExpr() returns +() +deterministic +{ var x: int := 0; assert x == 0; x := 1; var y: int := x; assert y == 1 } +-/ +#guard_msgs in +#eval! do + let program ← parseLaurelAndLift blockStmtLiftingProgram + for proc in program.staticProcedures do + IO.println (toString (Std.Format.pretty (Std.ToFormat.format proc))) + +end Laurel