Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Strata/DDM/Elab.lean
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,10 @@ partial def loadDialectFromIonFragment
(stk : Array DialectName)
(dialect : DialectName)
(frag : Ion.Fragment)

: BaseIO (Except String Dialect) := do
if dialect ∈ (←fm.loaded.get).dialects then
return .error s!"{dialect} already loaded"
-- Read dialect from Ion fragment
let d ←
match Dialect.fromIonFragment dialect frag with
Expand Down
490 changes: 359 additions & 131 deletions Strata/Languages/Python/Specs.lean

Large diffs are not rendered by default.

102 changes: 81 additions & 21 deletions Strata/Languages/Python/Specs/DDM.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dialect PythonSpecs;

category Int;
op natInt (x : Num) : Int => x;
op negSuccInt (x : Num) : Int => "-" x;
op negInt (x : Num) : Int => "-" x;

category SpecType;
category DictFieldDecl;
Expand All @@ -46,8 +46,8 @@ op mkDictFieldDecl(name : Ident, fieldType : SpecType, isRequired : Bool) : Dict
name " : " fieldType " [required=" isRequired "]";

category ClassFieldDecl;
op mkClassFieldDecl(name : Ident, fieldType : SpecType) : ClassFieldDecl =>
name " : " fieldType "\n";
op mkClassFieldDecl(name : Ident, fieldType : SpecType, constValue : Option Str) : ClassFieldDecl =>
name " : " fieldType constValue "\n";

category ClassVarDecl;
op mkClassVarDecl(name : Ident, value : Ident) : ClassVarDecl =>
Expand All @@ -65,22 +65,44 @@ category SpecExprDecl;
op placeholderExpr() : SpecExprDecl => "placeholder";
op varExpr(name : Ident) : SpecExprDecl => name;
op getIndexExpr(subject : SpecExprDecl, field : Ident) : SpecExprDecl =>
subject "[" field "]";
@[prec(50)] subject "[" field "]";
op isInstanceOfExpr(subject : SpecExprDecl, typeName : Str) : SpecExprDecl =>
"isinstance(" subject ", " typeName ")";
"isinstance" "(" subject ", " typeName ")";
op lenExpr(subject : SpecExprDecl) : SpecExprDecl =>
"len(" subject ")";
"len" "(" subject ")";
op intExpr(value : Int) : SpecExprDecl => value;
op intGeExpr(subject : SpecExprDecl, bound : SpecExprDecl) : SpecExprDecl =>
subject " >= " bound;
@[prec(15)] subject " >=_int " bound;
op intLeExpr(subject : SpecExprDecl, bound : SpecExprDecl) : SpecExprDecl =>
subject " <= " bound;
@[prec(15)] subject " <=_int " bound;
op floatExpr(value : Str) : SpecExprDecl => value;
op floatGeExpr(subject : SpecExprDecl, bound : SpecExprDecl) : SpecExprDecl =>
@[prec(15)] subject " >=_float " bound;
op floatLeExpr(subject : SpecExprDecl, bound : SpecExprDecl) : SpecExprDecl =>
@[prec(15)] subject " <=_float " bound;
op enumMemberExpr(subject : SpecExprDecl, values : Seq Str) : SpecExprDecl =>
"enum(" subject ", [" values "])";
"enum" "(" subject ", [" values "]" ")";
op regexMatchExpr(subject : SpecExprDecl, pattern : Str) : SpecExprDecl =>
"regex" "(" subject ", " pattern ")";
op containsKeyExpr(container : SpecExprDecl, key : Ident) : SpecExprDecl =>
@[prec(15)] key " in " container;
op impliesExpr(condition : SpecExprDecl, body : SpecExprDecl) : SpecExprDecl =>
@[prec(10), rightassoc] condition " => " body;
op notExpr(e : SpecExprDecl) : SpecExprDecl =>
"not" "(" e ")";
op forallListExpr(list : SpecExprDecl, varName : Ident, body : SpecExprDecl) : SpecExprDecl =>
"forall" "(" list ", " varName ", " body ")";
op forallDictExpr(dict : SpecExprDecl, keyVar : Ident,
valVar : Ident, body : SpecExprDecl) : SpecExprDecl =>
"forallDict" "(" dict ", " keyVar ", " valVar ", " body ")";

category MessagePart;
op strMessagePart(s : Str) : MessagePart => s;
op exprMessagePart(e : SpecExprDecl) : MessagePart => "{" e "}";

category Assertion;
op mkAssertion(formula : SpecExprDecl, message : Str) : Assertion =>
"ensure(" formula ", " message ")\n";
op mkAssertion(formula : SpecExprDecl, message : Seq MessagePart) : Assertion =>
"ensure" "(" formula ", " message ")\n";

category PostconditionEntry;
op mkPostconditionEntry(expr : SpecExprDecl) : PostconditionEntry =>
Expand Down Expand Up @@ -155,18 +177,20 @@ private def PythonIdent.toDDM (d : PythonIdent) : Ann String SourceRange :=
⟨.none, toString d⟩

/-- Converts a Lean `Int` to the DDM representation which separates natural and negative cases. -/
private def toDDMInt {α} (ann : α) (i : Int) : DDM.Int α :=
def toDDMInt {α} (ann : α) (i : Int) : DDM.Int α :=
match i with
| .ofNat n => .natInt ann ⟨ann, n⟩
| .negSucc n => .negSuccInt ann ⟨ann, n
| .negSucc n => .negInt ann ⟨ann, (n+1)

private def DDM.Int.ofDDM : DDM.Int α → _root_.Int
def DDM.Int.ofDDM : DDM.Int α → _root_.Int
| .natInt _ ⟨_, n⟩ => .ofNat n
| .negSuccInt _ ⟨_, n⟩ => .negSucc n
| .negInt _ ⟨_, 0⟩ => 0
| .negInt _ ⟨_, n+1⟩ => .negSucc n

mutual

private def SpecAtomType.toDDM (d : SpecAtomType) (loc : SourceRange := .none) : DDM.SpecType SourceRange :=
private def SpecAtomType.toDDM (d : SpecAtomType)
(loc : SourceRange := .none) : DDM.SpecType SourceRange :=
match d with
| .ident nm args =>
if args.isEmpty then
Expand Down Expand Up @@ -216,12 +240,31 @@ private def SpecExpr.toDDM (e : SpecExpr) : DDM.SpecExprDecl SourceRange :=
| .intLit v => .intExpr .none (toDDMInt .none v)
| .intGe subj bound => .intGeExpr .none subj.toDDM bound.toDDM
| .intLe subj bound => .intLeExpr .none subj.toDDM bound.toDDM
| .floatLit v => .floatExpr .none ⟨.none, v⟩
| .floatGe subj bound => .floatGeExpr .none subj.toDDM bound.toDDM
| .floatLe subj bound => .floatLeExpr .none subj.toDDM bound.toDDM
| .enumMember subj values =>
.enumMemberExpr .none subj.toDDM
⟨.none, values.map (⟨.none, ·⟩)⟩
| .regexMatch subj pattern =>
.regexMatchExpr .none subj.toDDM ⟨.none, pattern⟩
| .containsKey container key =>
.containsKeyExpr .none container.toDDM ⟨.none, key⟩
| .implies cond body =>
.impliesExpr .none cond.toDDM body.toDDM
| .not e => .notExpr .none e.toDDM
| .forallList list varName body =>
.forallListExpr .none list.toDDM ⟨.none, varName⟩ body.toDDM
| .forallDict dict keyVar valVar body =>
.forallDictExpr .none dict.toDDM ⟨.none, keyVar⟩ ⟨.none, valVar⟩ body.toDDM

private def MessagePart.toDDM (p : MessagePart) : DDM.MessagePart SourceRange :=
match p with
| .str s => .strMessagePart .none ⟨.none, s⟩
| .expr e => .exprMessagePart .none e.toDDM

private def Assertion.toDDM (a : Assertion) : DDM.Assertion SourceRange :=
.mkAssertion .none a.formula.toDDM ⟨.none, a.message⟩
.mkAssertion .none a.formula.toDDM ⟨.none, a.message.map (·.toDDM)

private def FunctionDecl.toDDM (d : FunctionDecl) : DDM.FunDecl SourceRange :=
.mkFunDecl
Expand All @@ -247,7 +290,8 @@ private partial def ClassDef.toDDMDecl (d : ClassDef) : DDM.ClassDecl SourceRang
.mkClassDecl d.loc (.mk .none d.name)
⟨.none, d.bases.map (·.toDDM)⟩
⟨.none, d.fields.map fun f =>
.mkClassFieldDecl .none ⟨.none, f.name⟩ f.type.toDDM⟩
.mkClassFieldDecl .none ⟨.none, f.name⟩ f.type.toDDM
⟨.none, f.constValue.map (⟨.none, ·⟩)⟩⟩
⟨.none, d.classVars.map (·.toDDM)⟩
⟨.none, d.subclasses.map (·.toDDMDecl)⟩
⟨.none, d.methods.map (·.toDDM)⟩
Expand Down Expand Up @@ -323,11 +367,27 @@ private def DDM.SpecExprDecl.fromDDM (d : DDM.SpecExprDecl SourceRange) : Specs.
| .intExpr _ i => .intLit i.ofDDM
| .intGeExpr _ subj bound => .intGe subj.fromDDM bound.fromDDM
| .intLeExpr _ subj bound => .intLe subj.fromDDM bound.fromDDM
| .floatExpr _ ⟨_, v⟩ => .floatLit v
| .floatGeExpr _ subj bound => .floatGe subj.fromDDM bound.fromDDM
| .floatLeExpr _ subj bound => .floatLe subj.fromDDM bound.fromDDM
| .enumMemberExpr _ subj ⟨_, values⟩ => .enumMember subj.fromDDM (values.map (·.2))
| .regexMatchExpr _ subj ⟨_, pattern⟩ => .regexMatch subj.fromDDM pattern
| .containsKeyExpr _ container ⟨_, key⟩ => .containsKey container.fromDDM key
| .impliesExpr _ cond body => .implies cond.fromDDM body.fromDDM
| .notExpr _ e => .not e.fromDDM
| .forallListExpr _ list ⟨_, varName⟩ body =>
.forallList list.fromDDM varName body.fromDDM
| .forallDictExpr _ dict ⟨_, keyVar⟩ ⟨_, valVar⟩ body =>
.forallDict dict.fromDDM keyVar valVar body.fromDDM

private def DDM.MessagePart.fromDDM (d : DDM.MessagePart SourceRange) : Specs.MessagePart :=
match d with
| .strMessagePart _ ⟨_, s⟩ => .str s
| .exprMessagePart _ e => .expr e.fromDDM

private def DDM.Assertion.fromDDM (d : DDM.Assertion SourceRange) : Specs.Assertion :=
let .mkAssertion _ formula ⟨_, message⟩ := d
{ message := message, formula := formula.fromDDM }
{ message := message.map (·.fromDDM), formula := formula.fromDDM }

private def DDM.FunDecl.fromDDM (d : DDM.FunDecl SourceRange) : Specs.FunctionDecl :=
let .mkFunDecl loc ⟨nameLoc, name⟩ ⟨_, args⟩ ⟨_, kwonly⟩
Expand Down Expand Up @@ -363,8 +423,8 @@ private def DDM.ClassDecl.fromDDM (d : DDM.ClassDecl SourceRange) : Specs.ClassD
match PythonIdent.ofString s with
| some id => id
| none => panic! s!"Bad base class identifier: '{s}'"
fields := fields.map fun (.mkClassFieldDecl _ ⟨_, n⟩ tp) =>
{ name := n, type := tp.fromDDM : ClassField }
fields := fields.map fun (.mkClassFieldDecl _ ⟨_, n⟩ tp ⟨_, cv⟩) =>
{ name := n, type := tp.fromDDM, constValue := cv.map (·.2) : ClassField }
classVars := classVars.map fun (.mkClassVarDecl _ ⟨_, n⟩ ⟨_, v⟩) =>
{ name := n, value := v : ClassVariable }
subclasses := subclasses.map (·.fromDDM)
Expand Down
33 changes: 32 additions & 1 deletion Strata/Languages/Python/Specs/Decls.lean
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,40 @@ inductive SpecExpr where
| intLit (value : Int)
| intGe (subject : SpecExpr) (bound : SpecExpr)
| intLe (subject : SpecExpr) (bound : SpecExpr)
/-- A floating-point literal, stored as a string to preserve precision. -/
| floatLit (value : String)
| floatGe (subject : SpecExpr) (bound : SpecExpr)
| floatLe (subject : SpecExpr) (bound : SpecExpr)
| enumMember (subject : SpecExpr) (values : Array String)
/-- `regexMatch subject pattern` asserts that `subject` matches the regular
expression `pattern`. Corresponds to `compile(pattern).search(subject) is not None`
in the Python source. -/
| regexMatch (subject : SpecExpr) (pattern : String)
/-- `containsKey container key` asserts that `key` is present in `container`.
Corresponds to `"key" in container` in the Python source. -/
| containsKey (container : SpecExpr) (key : String)
/-- `implies condition body` asserts that if `condition` holds then `body` holds.
Used to represent conditional assertions like `if "field" in kwargs: assert ...`. -/
| implies (condition : SpecExpr) (body : SpecExpr)
/-- Logical negation. Used for else-branch conditions. -/
| not (e : SpecExpr)
/-- `forallList list varName body` asserts that `body` holds for every element
of `list`, with `varName` bound to each element in turn. Only `body` may
refer to `varName`. Corresponds to `for varName in list: assert body`. -/
| forallList (list : SpecExpr) (varName : String) (body : SpecExpr)
/-- `forallDict dict keyVar valVar body` asserts that `body` holds for every
key-value pair in `dict`. Both `keyVar` and `valVar` are bound in `body`.
Corresponds to `for keyVar, valVar in dict.items(): assert body`. -/
| forallDict (dict : SpecExpr) (keyVar : String) (valVar : String) (body : SpecExpr)
deriving Inhabited

inductive MessagePart where
| str (s : String)
| expr (e : SpecExpr)
deriving Inhabited

structure Assertion where
message : String
message : Array MessagePart
formula : SpecExpr
deriving Inhabited

Expand All @@ -356,6 +385,8 @@ deriving Inhabited
structure ClassField where
name : String
type : SpecType
/-- An optional constant value for the field (e.g., from `self.x = expr` in `__init__`). -/
constValue : Option String := none
deriving Inhabited

structure ClassVariable where
Expand Down
4 changes: 4 additions & 0 deletions Strata/Util/IO.lean
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def readStrataText (fm : Strata.DialectFileMap) (path : System.FilePath) (bytes
| .ok program => pure (.program program)
| .error errors => throw (IO.userError (← mkErrorReport path errors))
| .dialect stx dialect =>
if dialect ∈ (←fm.loaded.get).dialects then
throw <| IO.userError s!"{dialect} already loaded"
let (d, s) ←
Strata.Elab.elabDialectRest fm inputContext stx dialect (startPos := startPos)
if s.errors.size > 0 then
Expand All @@ -112,6 +114,8 @@ def readStrataIon (fm : Strata.DialectFileMap)
pure p
match hdr with
| .dialect dialect =>
if dialect ∈ (←fm.loaded.get).dialects then
throw <| IO.userError s!"{dialect} already loaded"
match ← Strata.Elab.loadDialectFromIonFragment fm #[] dialect frag with
| .error msg =>
throw (IO.userError (fileReadErrorMsg path msg))
Expand Down
5 changes: 5 additions & 0 deletions StrataMain.lean
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ def printCommand : Command where
help := "Pretty-print a Strata file (text or Ion) to stdout."
callback := fun v pflags => do
let searchPath ← pflags.buildDialectFileMap
-- Special case for already loaded dialects.
let ld ← searchPath.getLoaded
if mem : v[0] ∈ ld.dialects then
IO.print <| ld.dialects.format v[0] mem
return
let pd ← Strata.readStrataFile searchPath v[0]
match pd with
| .dialect d =>
Expand Down
56 changes: 55 additions & 1 deletion StrataTest/Languages/Python/Specs/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from re import compile
from basetypes import BaseClass
from typing import Any, Dict, List, Sequence
from typing import Any, Dict, List, Mapping, NotRequired, Sequence, TypedDict, Unpack

def dict_function(x : Dict[int, Any]):
pass
Expand All @@ -24,3 +24,57 @@ def main_function(x : MainClass):
def kwargs_function(**kw: int) -> Any:
assert isinstance(kw["name"], str), 'Expected name to be str'
assert kw["count"] >= 1, 'Expected count >= 1'

# Test f-string messages, regex, nested subscripts, and for-loops
TestRequest = TypedDict('TestRequest', {
'Name': str,
'Items': NotRequired[List[str]],
'Tags': NotRequired[Mapping[str, str]],
})

def fstring_and_regex(**params: Unpack[TestRequest]) -> None:
assert len(params["Name"]) >= 1, f'Expected len(params["Name"]) >= 1, got {len(params["Name"])}'
assert len(params["Name"]) <= 100, f'Expected len(params["Name"]) <= 100, got {len(params["Name"])}'
assert compile("^[a-zA-Z]+$").search(params["Name"]) is not None, f'params["Name"] did not match pattern'
if "Items" in params:
for item in params["Items"]:
assert len(item) >= 1, f'Expected len(item) >= 1, got {len(item)}'
assert len(item) <= 50, f'Expected len(item) <= 50, got {len(item)}'
if "Tags" in params:
for tag_key, tag_val in params["Tags"].items():
assert len(tag_key) >= 1, f'Expected len(tag_key) >= 1, got {len(tag_key)}'

# Test float comparisons, negative int bounds, and __init__ class fields
FloatRequest = TypedDict('FloatRequest', {
'SampleSize': NotRequired[float],
'Score': NotRequired[float],
'Count': NotRequired[int],
})

def float_and_negative_bounds(**fp: Unpack[FloatRequest]) -> None:
# Float field with float literal bound
if "Score" in fp:
assert fp["Score"] >= 0.0, f'Expected Score >= 0.0'
assert fp["Score"] <= 1.0, f'Expected Score <= 1.0'
else:
assert fp["SampleSize"] >= 0, f'Expected SampleSize >= 0 when no Score'
# Float field with integer literal bound (the SampleSize pattern)
if "SampleSize" in fp:
assert fp["SampleSize"] >= 0, f'Expected SampleSize >= 0'
# Float field with negative float bound
if "Score" in fp:
assert fp["Score"] >= -0.5, f'Expected Score >= -0.5'
# Int field with negative bound
if "Count" in fp:
assert fp["Count"] >= -1, f'Expected Count >= -1'

class InnerHelper:
pass

class ClassWithInit:
def __init__(self):
self.helper = self._InnerHelper()

class _InnerHelper(InnerHelper):
def do_work(self) -> None:
pass
32 changes: 32 additions & 0 deletions StrataTest/Languages/Python/Specs/warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Any, Dict, List, NotRequired, TypedDict, Unpack

# Unsupported assert pattern: equality comparison
def unsupported_assert(**kw: int) -> None:
assert kw["x"] == 1, 'x must be 1'

# Unsupported __init__ assignment value (not self._ClassName() pattern)
class BadInit:
def __init__(self):
self.name = "hello"

# Skipped Assign in function body
def skipped_assign(**kw: int) -> None:
x = kw["a"]
assert x >= 1, 'x >= 1'

# For loop with unsupported target (attribute, not simple Name)
LoopRequest = TypedDict('LoopRequest', {
'Items': NotRequired[List[str]],
'Data': NotRequired[Dict[str, str]],
})

# For loop with unsupported orelse (for/else)
def for_else_loop(**kw: Unpack[LoopRequest]) -> None:
for item in kw["Items"]:
assert len(item) >= 1, f'Expected len >= 1'
else:
pass

# Skipped Expr in function body (non-ellipsis expression statement)
def skipped_expr(**kw: int) -> None:
kw["a"]
Loading
Loading