Add support for __matmul__ in ExpressionNode#235
Conversation
pymbolic/mapper/evaluator.py
Outdated
| @override | ||
| def map_matmul(self, expr: p.Matmul, /) -> ResultT: | ||
| if not expr.children: | ||
| return cast("ResultT", 1) |
There was a problem hiding this comment.
Not sure this is a good idea? Mainly because ndarray @ 1 fails due to the shape mismatch (even if it's (1,)).
There was a problem hiding this comment.
Yeah, might be better to error in that case.
There was a problem hiding this comment.
pymbolic/mapper/flop_counter.py
Outdated
| map_product: Callable[[Self, p.Product], ArithmeticExpression] = map_sum | ||
| map_sum: Callable[[Self, p.Sum], ArithmeticExpression] = _count_children | ||
| map_product: Callable[[Self, p.Product], ArithmeticExpression] = _count_children | ||
| map_matmul: Callable[[Self, p.Matmul], ArithmeticExpression] = _count_children |
There was a problem hiding this comment.
Since this is a "matrix multiply", not sure this makes sense in the FlopCounter, but I guess subclasses will have to figure it out?
There was a problem hiding this comment.
This was already implemented in CombineMapper, so all this did was add the nops + 1 to the children cost. I switched that to NotImplementedError completely to avoid confusion.
pymbolic/mapper/stringifier.py
Outdated
| enclosing_prec: int, *args: P.args, **kwargs: P.kwargs | ||
| ) -> str: | ||
| return self.parenthesize_if_needed( | ||
| self.join_rec(" ", expr.children, PREC_PRODUCT, *args, **kwargs), |
There was a problem hiding this comment.
? Not sure I have any better idea of what to map it to in LaTeX.
There was a problem hiding this comment.
Added a \cdot like suggested below.
There was a problem hiding this comment.
Pull request overview
Adds a new Matmul expression primitive and plumbing so ExpressionNode can represent/use Python’s @ operator across core mappers (stringification, evaluation, traversal, etc.).
Changes:
- Introduces
prim.Matmuland implementsExpressionNode.__matmul__/__rmatmul__. - Adds
Matmulhandling across multiple mappers (stringifiers, evaluator, graphviz, flop counter, generic mapper base classes). - Adds a basic unit test and updates the basedpyright baseline.
Reviewed changes
Copilot reviewed 17 out of 17 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
pymbolic/primitives.py |
Adds Matmul node and @ operator overloads on ExpressionNode. |
pymbolic/mapper/__init__.py |
Adds map_matmul to base/utility mappers (default, combine, identity, walk, constant). |
pymbolic/mapper/stringifier.py |
Adds map_matmul for standard and LaTeX stringification. |
pymbolic/mapper/evaluator.py |
Adds evaluation logic for Matmul using operator.matmul/reduce; adjusts product evaluation implementation. |
pymbolic/mapper/graphviz.py |
Adds graphviz visualization for Matmul. |
pymbolic/mapper/flop_counter.py |
Counts Matmul similarly to Sum/Product (n-1 plus children). |
pymbolic/mapper/c_code.py |
Explicitly rejects Matmul in C codegen via NotImplementedError. |
pymbolic/mapper/cse_tagger.py |
Tags Matmul subexpressions for CSE handling. |
pymbolic/mapper/unifier.py |
Refactors imports to use pymbolic.primitives as p consistently. |
pymbolic/mapper/substitutor.py |
Minor import refactor for Variable. |
pymbolic/mapper/flattener.py |
Refactors to call p.flattened_sum/product directly. |
pymbolic/mapper/distributor.py |
Refactors Sum reference to p.Sum. |
pymbolic/mapper/constant_folder.py |
Refactors to use prim.flattened_sum/product directly. |
pymbolic/mapper/collector.py |
Import/type-expression refactors in TermCollector.split_term. |
pymbolic/mapper/coefficient.py |
Refactors to use p.* helpers/types directly. |
test/test_pymbolic.py |
Adds a basic test_matmul asserting node type and str() output. |
.basedpyright/baseline.json |
Adds new suppressions (unannotated class attribute / Any). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
pymbolic/mapper/stringifier.py
Outdated
| self, expr: p.Matmul, /, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs | ||
| ) -> str: | ||
| return self.parenthesize_if_needed( | ||
| self.join_rec(" @ ", expr.children, PREC_PRODUCT, *args, **kwargs), |
There was a problem hiding this comment.
StringifyMapper.map_product uses join_rec_with_parens_around_types to add parentheses around quotient-like children even though they share precedence, improving readability and avoiding ambiguous a / b * c-style output. map_matmul currently uses join_rec directly, so expressions like (a/b) @ c may stringify as a / b @ c without clarifying grouping. Consider using the same parenthesization helper/operand-type handling as map_product for consistency.
| self.join_rec(" @ ", expr.children, PREC_PRODUCT, *args, **kwargs), | |
| self.join_rec_with_parens_around_types( | |
| " @ ", | |
| expr.children, | |
| PREC_PRODUCT, | |
| self.multiplicative_primitives, | |
| *args, | |
| **kwargs, | |
| ), |
There was a problem hiding this comment.
|
Thx! |
This adds some support for the
@operator toExpressionNodewithMatmulnode (naming? the other nodes don't use short names)Matmul, mainlyStringifyMapperandCollectorand friends.