Skip to content

Comments

Add support for __matmul__ in ExpressionNode#235

Merged
inducer merged 5 commits intoinducer:mainfrom
alexfikl:add-matmul
Feb 12, 2026
Merged

Add support for __matmul__ in ExpressionNode#235
inducer merged 5 commits intoinducer:mainfrom
alexfikl:add-matmul

Conversation

@alexfikl
Copy link
Collaborator

This adds some support for the @ operator to ExpressionNode with

  • A new Matmul node (naming? the other nodes don't use short names)
  • Support in some mappers for Matmul, mainly StringifyMapper and Collector and friends.

Comment on lines 157 to 160
@override
def map_matmul(self, expr: p.Matmul, /) -> ResultT:
if not expr.children:
return cast("ResultT", 1)
Copy link
Collaborator Author

@alexfikl alexfikl Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure this is a good idea? Mainly because ndarray @ 1 fails due to the shape mismatch (even if it's (1,)).

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, might be better to error in that case.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

map_product: Callable[[Self, p.Product], ArithmeticExpression] = map_sum
map_sum: Callable[[Self, p.Sum], ArithmeticExpression] = _count_children
map_product: Callable[[Self, p.Product], ArithmeticExpression] = _count_children
map_matmul: Callable[[Self, p.Matmul], ArithmeticExpression] = _count_children
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is a "matrix multiply", not sure this makes sense in the FlopCounter, but I guess subclasses will have to figure it out?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, leave unimplemented.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was already implemented in CombineMapper, so all this did was add the nops + 1 to the children cost. I switched that to NotImplementedError completely to avoid confusion.

https://github.com/inducer/pymbolic/compare/b3b9070e6a5ed0d9ab1e65dde84a0a6a08432e74..7f9907d104dff38ce50756f92f994532e49a47ad

enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
) -> str:
return self.parenthesize_if_needed(
self.join_rec(" ", expr.children, PREC_PRODUCT, *args, **kwargs),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

? Not sure I have any better idea of what to map it to in LaTeX.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's OK.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a new Matmul expression primitive and plumbing so ExpressionNode can represent/use Python’s @ operator across core mappers (stringification, evaluation, traversal, etc.).

Changes:

  • Introduces prim.Matmul and implements ExpressionNode.__matmul__ / __rmatmul__.
  • Adds Matmul handling across multiple mappers (stringifiers, evaluator, graphviz, flop counter, generic mapper base classes).
  • Adds a basic unit test and updates the basedpyright baseline.

Reviewed changes

Copilot reviewed 17 out of 17 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
pymbolic/primitives.py Adds Matmul node and @ operator overloads on ExpressionNode.
pymbolic/mapper/__init__.py Adds map_matmul to base/utility mappers (default, combine, identity, walk, constant).
pymbolic/mapper/stringifier.py Adds map_matmul for standard and LaTeX stringification.
pymbolic/mapper/evaluator.py Adds evaluation logic for Matmul using operator.matmul/reduce; adjusts product evaluation implementation.
pymbolic/mapper/graphviz.py Adds graphviz visualization for Matmul.
pymbolic/mapper/flop_counter.py Counts Matmul similarly to Sum/Product (n-1 plus children).
pymbolic/mapper/c_code.py Explicitly rejects Matmul in C codegen via NotImplementedError.
pymbolic/mapper/cse_tagger.py Tags Matmul subexpressions for CSE handling.
pymbolic/mapper/unifier.py Refactors imports to use pymbolic.primitives as p consistently.
pymbolic/mapper/substitutor.py Minor import refactor for Variable.
pymbolic/mapper/flattener.py Refactors to call p.flattened_sum/product directly.
pymbolic/mapper/distributor.py Refactors Sum reference to p.Sum.
pymbolic/mapper/constant_folder.py Refactors to use prim.flattened_sum/product directly.
pymbolic/mapper/collector.py Import/type-expression refactors in TermCollector.split_term.
pymbolic/mapper/coefficient.py Refactors to use p.* helpers/types directly.
test/test_pymbolic.py Adds a basic test_matmul asserting node type and str() output.
.basedpyright/baseline.json Adds new suppressions (unannotated class attribute / Any).

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

self, expr: p.Matmul, /, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs
) -> str:
return self.parenthesize_if_needed(
self.join_rec(" @ ", expr.children, PREC_PRODUCT, *args, **kwargs),
Copy link

Copilot AI Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

StringifyMapper.map_product uses join_rec_with_parens_around_types to add parentheses around quotient-like children even though they share precedence, improving readability and avoiding ambiguous a / b * c-style output. map_matmul currently uses join_rec directly, so expressions like (a/b) @ c may stringify as a / b @ c without clarifying grouping. Consider using the same parenthesization helper/operand-type handling as map_product for consistency.

Suggested change
self.join_rec(" @ ", expr.children, PREC_PRODUCT, *args, **kwargs),
self.join_rec_with_parens_around_types(
" @ ",
expr.children,
PREC_PRODUCT,
self.multiplicative_primitives,
*args,
**kwargs,
),

Copilot uses AI. Check for mistakes.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@inducer inducer enabled auto-merge (rebase) February 12, 2026 17:39
@inducer
Copy link
Owner

inducer commented Feb 12, 2026

Thx!

@inducer inducer merged commit c5483ac into inducer:main Feb 12, 2026
10 checks passed
@alexfikl alexfikl deleted the add-matmul branch February 12, 2026 18:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants