Use tritonBLAS Device Side APIs in iris ops #358
Conversation
Pull request overview
This PR integrates tritonBLAS's composable stage abstractions into iris.ops fused communication+compute kernels to reduce code complexity and improve maintainability.
Changes:
- Replaces custom GEMM implementations with tritonBLAS's `GemmContext`, `ScheduleContext`, and `make_tensor_view` APIs
- Updates `TensorView` to store dimensions/strides as tensors instead of constexpr values, with a new `make_tensor_view` factory function
- Simplifies `all_reduce_two_shot` to use interleaved distribution instead of parameterized `start_tile`/`stride`
- Removes unnecessary parentheses in a lambda function
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| iris/x/core.py | Adds make_tensor_view factory function and updates TensorView to store tensor fields instead of constexpr |
| iris/x/all_reduce.py | Simplifies all_reduce_two_shot by removing start_tile/stride parameters and using interleaved distribution |
| iris/x/init.py | Exports new make_tensor_view function |
| iris/ops/matmul_reduce_scatter.py | Replaces custom GEMM loop with tritonBLAS GemmContext and updates to use make_tensor_view |
| iris/ops/matmul_all_reduce.py | Replaces custom GEMM loop with tritonBLAS stages and restructures variant dispatch logic |
| iris/ops/matmul_all_gather.py | Replaces custom GEMM implementation with tritonBLAS GemmContext and ScheduleContext |
| iris/ops/all_gather_matmul.py | Replaces custom GEMM with tritonBLAS stages, adds NUM_K_BLOCKS_LOCAL parameter, and improves code organization |
| examples/common/utils.py | Removes unnecessary parentheses from lambda function |
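
For orientation, here is how the new factory is presumably invoked. This is an assumed signature, inferred only from the direct `TensorView` constructor call quoted later in this thread, not from tritonBLAS documentation:

```python
# Assumed signature, inferred from the old direct constructor call
# iris.x.TensorView(A_sharded, M, K_local, stride_am, stride_ak)
# quoted in the review thread below. Hypothetical, not verified.
src_view = iris.x.make_tensor_view(A_sharded, M, K_local, stride_am, stride_ak)
```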
```diff
  rm, rn = out_tile.indices()

- c = acc.to(C.dtype.element_ty)
+ c = acc.to(C.type.element_ty)
```
Copilot AI commented on Feb 4, 2026:
Attribute access changed from C.dtype.element_ty to C.type.element_ty. Verify that C.type.element_ty is the correct attribute for the pointer type in Triton, as this differs from the original code pattern.
Suggested change:

```diff
- c = acc.to(C.type.element_ty)
+ c = acc.to(C.dtype.element_ty)
```
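For context, `ptr.dtype.element_ty` is the conventional way to query the pointee dtype of a pointer argument inside a Triton kernel. A minimal sketch of the pattern (hypothetical device-side helper, not code from this PR):

```python
import triton
import triton.language as tl


@triton.jit
def store_output(C, acc, offs, mask):
    # Query the pointee dtype off the raw pointer argument and cast the
    # fp32 accumulator down to the output element type before storing.
    c = acc.to(C.dtype.element_ty)
    tl.store(C + offs, c, mask=mask)
```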
```diff
  # Convert to output dtype
- c = acc.to(C.dtype.element_ty)
+ c = acc.to(C.type.element_ty)
```
Copilot AI commented on Feb 4, 2026:
Attribute access changed from C.dtype.element_ty to C.type.element_ty. Verify that C.type.element_ty is the correct attribute for the pointer type in Triton, as this differs from the original code pattern.
Suggested change:

```diff
- c = acc.to(C.type.element_ty)
+ c = acc.to(C.dtype.element_ty)
```
```python
# Promote tile_k to tensor (TileView expects tl.tensor for pid_n)
tile_k = pid_m * 0 + k_offset // BLOCK_SIZE_K
```
Copilot AI commented on Feb 4, 2026:
The expression pid_m * 0 + k_offset // BLOCK_SIZE_K uses an unusual pattern to convert to tensor. Consider using a more explicit method or adding a comment explaining why this approach is necessary for TileView.
Suggested change:

```diff
- # Promote tile_k to tensor (TileView expects tl.tensor for pid_n)
- tile_k = pid_m * 0 + k_offset // BLOCK_SIZE_K
+ # Promote scalar tile_k to tensor (TileView expects tl.tensor for pid_n)
+ tile_k = tl.full_like(pid_m, k_offset // BLOCK_SIZE_K)
```
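For context, the original idiom works because arithmetic against an existing `tl.tensor` coerces the Python-int result into a tensor of the same type. A minimal standalone sketch (illustrative names, not code from this PR):

```python
import triton
import triton.language as tl


@triton.jit
def promote_scalar_demo(k_offset, BLOCK_SIZE_K: tl.constexpr):
    pid_m = tl.program_id(axis=0)  # already a tl.tensor (int32 scalar)
    # `pid_m * 0` contributes nothing numerically; it only forces the
    # result of the addition to be a tl.tensor rather than a Python int.
    tile_k = pid_m * 0 + k_offset // BLOCK_SIZE_K
    return tile_k
```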
```python
# Promote tile_k to tensor (TileView expects tl.tensor for pid_n)
tile_k = pid_m * 0 + k_offset // BLOCK_SIZE_K
```
Copilot AI commented on Feb 4, 2026:
The expression pid_m * 0 + k_offset // BLOCK_SIZE_K uses an unusual pattern to convert to tensor. Consider using a more explicit method or adding a comment explaining why this approach is necessary for TileView.
Suggested change:

```diff
- # Promote tile_k to tensor (TileView expects tl.tensor for pid_n)
- tile_k = pid_m * 0 + k_offset // BLOCK_SIZE_K
+ # Promote scalar K-tile index to tensor matching pid_m's shape for TileView
+ tile_k_scalar = k_offset // BLOCK_SIZE_K
+ tile_k = tl.full_like(pid_m, tile_k_scalar)
```
Thanks, Ryan! You will need to change the SHA for the tests to run. Please check this: #350
iris/ops/all_gather_matmul.py (outdated)
```diff
  # Create DeviceContext and TensorView for gather operations
- ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases)
  src_view = iris.x.TensorView(A_sharded, M, K_local, stride_am, stride_ak)
+ iris_ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases)
```
There will be a small conflict here once you sync with main:

iris/iris/ops/all_gather_matmul.py, line 96 in 383b35c:

```python
ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size)
```
mawad-amd left a comment:
This looks very neat. Thanks, Ryan!
Co-authored-by: Muhammad Awad <112003944+mawad-amd@users.noreply.github.com>
… ryaswann/use_tblas_stages
Motivation
This PR integrates tritonBLAS's composable stage abstractions into the iris.ops fused communication+compute kernels. The goal is to leverage tritonBLAS's device-side APIs to reduce kernel code complexity.
Technical Details
Changes to iris.ops Kernels
The following iris.ops kernels now use tritonBLAS stages:
- `matmul_all_gather.py`: `GemmContext`, `ScheduleContext`, `make_tensor_view`
- `matmul_all_reduce.py`: `GemmContext`, `make_tensor_view`, `Tile`
- `matmul_reduce_scatter.py`: `GemmContext`, `ScheduleContext`, `make_tensor_view`
- `all_gather_matmul.py`: `GemmContext`, `ScheduleContext`

Example Pattern
Before (custom GEMM implementation):
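
The original snippet was not captured on this page; what follows is a representative sketch of the hand-rolled K-loop being replaced, written from the standard Triton GEMM pattern rather than this repo's actual code:

```python
# Representative sketch (assumed): classic hand-rolled Triton GEMM K-loop.
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
    # Guard the tail of the K dimension with masks.
    a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
    b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
    acc = tl.dot(a, b, acc)
    # Advance the pointer blocks along K.
    a_ptrs += BLOCK_SIZE_K * stride_ak
    b_ptrs += BLOCK_SIZE_K * stride_bk
c = acc.to(C.dtype.element_ty)
```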
After (using tritonBLAS stages):
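
This snippet was also not captured. The exact tritonBLAS `GemmContext`/`ScheduleContext` signatures are not shown anywhere on this page, so every call below is an assumption about the API, kept only to convey the intended shape of the replacement:

```python
# Assumed API shape — illustrative only, not the real tritonBLAS signatures.
a_view = make_tensor_view(A, M, K, stride_am, stride_ak)
b_view = make_tensor_view(B, K, N, stride_bk, stride_bn)
gemm = GemmContext(a_view, b_view)    # hypothetical constructor
acc = gemm.compute_tile(out_tile)     # hypothetical stage call
c = acc.to(C.type.element_ty)
```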
Test Plan
Run all the tests in `tests/ops`.
Test Result
Tests passed
Submission Checklist