@ryanswann-amd (Collaborator)
Motivation

This PR integrates tritonBLAS's composable stage abstractions into the iris.ops fused communication+compute kernels. The goal is to leverage tritonBLAS's device-side APIs to reduce kernel code complexity.

Technical Details

Changes to iris.ops Kernels

The following iris.ops kernels now use tritonBLAS stages:

| Kernel | tritonBLAS components used |
| --- | --- |
| `matmul_all_gather.py` | `GemmContext`, `ScheduleContext`, `make_tensor_view` |
| `matmul_all_reduce.py` | `GemmContext`, `make_tensor_view`, `Tile` |
| `matmul_reduce_scatter.py` | `GemmContext`, `ScheduleContext`, `make_tensor_view` |
| `all_gather_matmul.py` | `GemmContext`, `ScheduleContext` |

Example Pattern

Before (custom GEMM implementation):

```python
# Manual accumulator init, dot products, and K-loop
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
    a = tl.load(A_ptr + ...)
    b = tl.load(B_ptr + ...)
    acc = tl.dot(a, b, acc)
```

After (using tritonBLAS stages):

```python
# Create views and context
tensorA = make_tensor_view(A, M, K, stride_am, stride_ak)
tensorB = make_tensor_view(B, K, N, stride_bk, stride_bn)
gemm_ctx = GemmContext(BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, ...)

# Single call handles accumulator init, K-loop, and reduction
acc = gemm_ctx.reduce_axis(tensorA, tensorB, out_tile)
```
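For intuition, the reduction that `reduce_axis` encapsulates can be sketched in plain NumPy. This is an illustrative analogue only, not tritonBLAS's implementation: `blocked_matmul` is a hypothetical helper, and tile shapes and the real `reduce_axis` signature differ.

```python
import numpy as np

def blocked_matmul(A, B, block_k):
    """K-blocked GEMM: accumulate partial products over K in fixed-size
    chunks, mirroring the accumulator-init + K-loop that
    gemm_ctx.reduce_axis hides inside the kernel (sketch only)."""
    M, K = A.shape
    _, N = B.shape
    acc = np.zeros((M, N), dtype=np.float32)  # accumulator init
    for k0 in range(0, K, block_k):           # K-loop over blocks
        acc += A[:, k0:k0 + block_k] @ B[k0:k0 + block_k, :]
    return acc

rng = np.random.default_rng(0)
A = rng.standard_normal((8, 12)).astype(np.float32)
B = rng.standard_normal((12, 6)).astype(np.float32)
assert np.allclose(blocked_matmul(A, B, block_k=4), A @ B, atol=1e-4)
```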

Test Plan

Run all the tests in `tests/ops`.

Test Result

Tests passed

Submission Checklist

Copilot AI review requested due to automatic review settings February 4, 2026 23:57
@github-actions github-actions bot added in-progress We are working on it iris Iris project issue labels Feb 4, 2026

Copilot AI left a comment


Pull request overview

This PR integrates tritonBLAS's composable stage abstractions into iris.ops fused communication+compute kernels to reduce code complexity and improve maintainability.

Changes:

- Replaces custom GEMM implementations with tritonBLAS's `GemmContext`, `ScheduleContext`, and `make_tensor_view` APIs
- Updates `TensorView` to store dimensions/strides as tensors instead of `constexpr` values, with a new `make_tensor_view` factory function
- Simplifies `all_reduce_two_shot` to use interleaved distribution instead of parameterized `start_tile`/`stride`
- Removes unnecessary parentheses in a lambda function
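For context on the interleaved distribution mentioned above, here is a minimal sketch of how such a tile assignment typically works; `interleaved_tiles` is a hypothetical helper, not iris's actual code. Each rank takes every `world_size`-th tile starting at its own rank index, so no explicit `start_tile`/`stride` parameters are needed.

```python
def interleaved_tiles(rank: int, world_size: int, num_tiles: int) -> list[int]:
    # Rank r owns tiles r, r + world_size, r + 2 * world_size, ...
    # (illustrative sketch of interleaved distribution, not iris's code)
    return list(range(rank, num_tiles, world_size))

# Every tile is owned by exactly one rank, with no start/stride parameters:
owned = sorted(t for r in range(4) for t in interleaved_tiles(r, 4, 10))
assert owned == list(range(10))
assert interleaved_tiles(1, 4, 10) == [1, 5, 9]
```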

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| `iris/x/core.py` | Adds `make_tensor_view` factory function and updates `TensorView` to store tensor fields instead of `constexpr` |
| `iris/x/all_reduce.py` | Simplifies `all_reduce_two_shot` by removing `start_tile`/`stride` parameters and using interleaved distribution |
| `iris/x/__init__.py` | Exports the new `make_tensor_view` function |
| `iris/ops/matmul_reduce_scatter.py` | Replaces custom GEMM loop with tritonBLAS `GemmContext` and updates to use `make_tensor_view` |
| `iris/ops/matmul_all_reduce.py` | Replaces custom GEMM loop with tritonBLAS stages and restructures variant dispatch logic |
| `iris/ops/matmul_all_gather.py` | Replaces custom GEMM implementation with tritonBLAS `GemmContext` and `ScheduleContext` |
| `iris/ops/all_gather_matmul.py` | Replaces custom GEMM with tritonBLAS stages, adds `NUM_K_BLOCKS_LOCAL` parameter, and improves code organization |
| `examples/common/utils.py` | Removes unnecessary parentheses from a lambda function |

```diff
  rm, rn = out_tile.indices()

- c = acc.to(C.dtype.element_ty)
+ c = acc.to(C.type.element_ty)
```

Copilot AI Feb 4, 2026


Attribute access changed from C.dtype.element_ty to C.type.element_ty. Verify that C.type.element_ty is the correct attribute for the pointer type in Triton, as this differs from the original code pattern.

Suggested change:

```diff
- c = acc.to(C.type.element_ty)
+ c = acc.to(C.dtype.element_ty)
```


```diff
  # Convert to output dtype
- c = acc.to(C.dtype.element_ty)
+ c = acc.to(C.type.element_ty)
```

Copilot AI Feb 4, 2026


Attribute access changed from C.dtype.element_ty to C.type.element_ty. Verify that C.type.element_ty is the correct attribute for the pointer type in Triton, as this differs from the original code pattern.

Suggested change:

```diff
- c = acc.to(C.type.element_ty)
+ c = acc.to(C.dtype.element_ty)
```

Comment on lines +102 to +103
```python
# Promote tile_k to tensor (TileView expects tl.tensor for pid_n)
tile_k = pid_m * 0 + k_offset // BLOCK_SIZE_K
```

Copilot AI Feb 4, 2026


The expression pid_m * 0 + k_offset // BLOCK_SIZE_K uses an unusual pattern to convert to tensor. Consider using a more explicit method or adding a comment explaining why this approach is necessary for TileView.

Suggested change:

```diff
- # Promote tile_k to tensor (TileView expects tl.tensor for pid_n)
- tile_k = pid_m * 0 + k_offset // BLOCK_SIZE_K
+ # Promote scalar tile_k to tensor (TileView expects tl.tensor for pid_n)
+ tile_k = tl.full_like(pid_m, k_offset // BLOCK_SIZE_K)
```

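The equivalence behind this suggestion is easy to check outside Triton. In a NumPy analogue (NumPy stands in for `tl` here; all names and values are made up for illustration), multiplying an existing tensor by zero and adding a scalar broadcasts the scalar to that tensor's shape and dtype, which is what a `full_like`-style call expresses directly:

```python
import numpy as np

pid_m = np.arange(4)              # stand-in for a Triton tensor of tile indices
k_offset, BLOCK_SIZE_K = 256, 64  # illustrative values

# The "unusual" promotion pattern: broadcast a scalar via an existing tensor.
tile_k = pid_m * 0 + k_offset // BLOCK_SIZE_K

# The more explicit spelling suggested by the review.
tile_k_explicit = np.full_like(pid_m, k_offset // BLOCK_SIZE_K)

assert np.array_equal(tile_k, tile_k_explicit)
assert tile_k.tolist() == [4, 4, 4, 4]
```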
Comment on lines +126 to +127
```python
# Promote tile_k to tensor (TileView expects tl.tensor for pid_n)
tile_k = pid_m * 0 + k_offset // BLOCK_SIZE_K
```

Copilot AI Feb 4, 2026


The expression pid_m * 0 + k_offset // BLOCK_SIZE_K uses an unusual pattern to convert to tensor. Consider using a more explicit method or adding a comment explaining why this approach is necessary for TileView.

Suggested change:

```diff
- # Promote tile_k to tensor (TileView expects tl.tensor for pid_n)
- tile_k = pid_m * 0 + k_offset // BLOCK_SIZE_K
+ # Promote scalar K-tile index to tensor matching pid_m's shape for TileView
+ tile_k_scalar = k_offset // BLOCK_SIZE_K
+ tile_k = tl.full_like(pid_m, tile_k_scalar)
```

@mawad-amd (Collaborator)

Thanks, Ryan! You will need to change the SHA for the tests to run. Please check this: #350

```python
# Create DeviceContext and TensorView for gather operations
ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases)
src_view = iris.x.TensorView(A_sharded, M, K_local, stride_am, stride_ak)
iris_ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases)
```
Collaborator

There will be a small conflict here once you sync with main:

```python
ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size)
```

@mawad-amd (Collaborator) left a comment


This looks very neat. Thanks, Ryan!

@mawad-amd mawad-amd merged commit ff6ef71 into main Feb 7, 2026
72 checks passed
@mawad-amd mawad-amd deleted the ryaswann/use_tblas_stages branch February 7, 2026 14:07
