
Shard join nodes #3642


Merged: 3 commits merged into main on Jun 12, 2025

Conversation

rpavlovicTT (Contributor)

Threefold change in this commit:

  1. Allow sharding of join nodes in fork-join structures (a small illustrative sketch of the fork-join shape follows below).
  2. Change the scheduler to prioritize "shorter" branches first. This will enable us to create longer chains that include join nodes.
  3. Fix connections after creating a spill-to-DRAM op following a join node. The join node is followed by other chains, and in some paths it is possible to skip reading from DRAM.

Piggyback: re-enable sharding of eltwise binary ops.

Related to #2276
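
For readers new to the terminology, here is a minimal standalone sketch (plain C++, not tt-mlir code; all names are illustrative) of the fork-join shape being discussed. The node with more than one producer is the "join" node that this PR now allows to be sharded.

#include <cstdio>
#include <map>
#include <string>
#include <vector>

// Illustrative fork-join graph:
//
//        A        <- fork: its result feeds both branches
//       / \
//      B   C
//      |   |
//      |   D      <- the longer branch
//       \ /
//        E        <- join: consumes the outputs of both branches
//
int main() {
  // Map each node to its consumers.
  std::map<std::string, std::vector<std::string>> consumers = {
      {"A", {"B", "C"}}, {"B", {"E"}}, {"C", {"D"}}, {"D", {"E"}}, {"E", {}}};

  // A node with more than one producer is a join node.
  std::map<std::string, int> producerCount;
  for (const auto &[node, outs] : consumers)
    for (const auto &out : outs)
      ++producerCount[out];

  for (const auto &[node, outs] : consumers)
    if (producerCount[node] > 1)
      std::printf("join node: %s\n", node.c_str()); // prints "join node: E"
  return 0;
}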

codecov-commenter commented on Jun 3, 2025

Codecov Report

Attention: Patch coverage is 38.28125% with 79 lines in your changes missing coverage. Please review.

Project coverage is 72.26%. Comparing base (79da966) to head (1113cec).
Report is 1 commit behind head on main.

Files with missing lines                                 Patch %   Missing lines
lib/Dialect/TTNN/Transforms/Optimizer.cpp                 18.51%              44
lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp             0.00%              18
include/ttmlir/Dialect/TTNN/Analysis/Edge.h                0.00%               8
lib/Scheduler/Scheduler.cpp                               86.84%               5
...clude/ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h      0.00%               1
lib/Dialect/TTNN/Analysis/BFInterleavedPolicy.cpp          0.00%               1
...ialect/TTNN/Analysis/GreedyL1InterleavedPolicy.cpp      0.00%               1
lib/Dialect/TTNN/Analysis/ShardSolver.cpp                  0.00%               1
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3642      +/-   ##
==========================================
- Coverage   72.29%   72.26%   -0.04%     
==========================================
  Files         213      213              
  Lines       29759    29788      +29     
==========================================
+ Hits        21513    21525      +12     
- Misses       8246     8263      +17     


rpavlovicTT force-pushed the rpavlovic/shard_joins branch 2 times, most recently from 19758cb to 16b20e6 on June 4, 2025 11:23
// Map of dependencies

// Vector of all operations in deterministic order.
llvm::SmallVector<mlir::Operation *> funcOps;
Contributor:
Nit: the name is a bit misleading. It suggests a vector of FuncOps rather than all ops inside one FuncOp.


  llvm::SmallVector<mlir::Operation *> Scheduler::getScheduleableOps() {
    llvm::SmallVector<mlir::Operation *> scheduleableOps;
-   for (auto &op : unscheduledOps) {
-     if (canSchedule(op)) {
+   for (auto &op : funcOps) {
Contributor:
Just to confirm, with this line the schedule will now always be deterministic?

Contributor Author:
Yes
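
To make the determinism point concrete, here is a simplified sketch under stated assumptions (standard containers instead of llvm::SmallVector/DenseSet, int* standing in for mlir::Operation*, and a stubbed canSchedule): iterating a pointer-keyed hash set can visit ops in a different order on every run, while iterating a vector that was filled in walk order, and filtering out already-scheduled ops, keeps the visit order stable.

#include <unordered_set>
#include <vector>

// Hypothetical stand-in for the real schedulability check.
bool canSchedule(int *) { return true; }

std::vector<int *>
getSchedulableOps(const std::vector<int *> &funcOps, // ops in walk order
                  const std::unordered_set<int *> &scheduledOpsMap) {
  std::vector<int *> schedulableOps;
  // The vector preserves the order in which ops were collected, so the
  // result (and the schedule built from it) is deterministic.
  for (int *op : funcOps)
    if (!scheduledOpsMap.count(op) && canSchedule(op))
      schedulableOps.push_back(op);
  return schedulableOps;
}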

Comment on lines +73 to +76
// We will sort schedulable ops by prioritizing ops whose successors are still
// blocked after scheduling it. This is a heuristic that lets us create longer
// chains of ops that contain join nodes in fork-join structure.
Contributor:
Thinking about this once more, does this heuristic improve performance in general, though? It seems to me that the length of the chain is not the important metric here, and that we are left with the same number of tensors going to DRAM either way?

If so, but it works better for ResNet, we don't have to change it now. Let's leave a comment and file an issue for this, though.

Contributor:
Also, I think this sorting logic should be in the policy and not in the scheduler. The idea is that the scheduler provides utilities for creating a legal op schedule. The actual logic of how we pick the next op should live within the policies, as different policies might want to do different things.

Contributor Author:
In general it doesn't improve performance, at least from what I can tell right now. But in ResNet we get better sharding (9 more ops are sharded).
I'd leave the sort as is, and leave a comment & issue that we should generalize this algorithm. Also, I think it's much easier to keep the logic in the Scheduler even though it's not the right place, simply because sorting is much easier given the state the Scheduler already tracks (dependencies between ops).

Contributor Author:
Most probably because longer branches reach the join node as the LHS operand, so we succeed in finding valid sharding layouts, while shorter branches come in as the RHS. If we pick ops randomly, we risk creating chains with the shorter branches, which ruins our sharding. Of course, this is super unstable.

Contributor Author:
Opened: #3744
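
For context on the heuristic under discussion (to be generalized per #3744), a hedged plain-C++ sketch of one way to "prefer ops whose successors stay blocked": for each schedulable op, count how many of its users would still be waiting on some other unscheduled producer after the op is scheduled, and stable-sort ops with more such users first. The types, names, and tie-breaking here are assumptions, not the actual implementation.

#include <algorithm>
#include <cstddef>
#include <vector>

// Minimal op model for the sketch (hypothetical, not the MLIR types).
struct Op {
  std::vector<Op *> users;     // ops consuming this op's results
  std::vector<Op *> producers; // ops defining this op's operands
  bool scheduled = false;
};

// How many users of `op` would still be blocked (still waiting on some
// other unscheduled producer) if we scheduled `op` right now.
static std::size_t blockedUsersAfterScheduling(const Op *op) {
  std::size_t blocked = 0;
  for (const Op *user : op->users)
    for (const Op *producer : user->producers)
      if (producer != op && !producer->scheduled) {
        ++blocked;
        break;
      }
  return blocked;
}

// Stable sort so equal-priority ops keep their deterministic walk order.
void prioritizeSchedulableOps(std::vector<Op *> &schedulableOps) {
  std::stable_sort(schedulableOps.begin(), schedulableOps.end(),
                   [](const Op *a, const Op *b) {
                     return blockedUsersAfterScheduling(a) >
                            blockedUsersAfterScheduling(b);
                   });
}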

@@ -85,126 +85,129 @@ void DFShardingPolicy::run() {
}
}

if (nextOp) {
bool validForSharding =
Contributor:
Leave a comment explaining what the condition checks here.

    if (nextOp) {
      bool validForSharding =
          legalConfigs.lookup(currentOp).size() > 0 &&
          (nextOp ? legalConfigs.lookup(nextOp).size() > 0 : true);
Contributor:
Is this part of the condition needed? Why does the next op need to have a sharded config?

Contributor Author:
I think it's not needed.

-   for (auto &op : unscheduledOps) {
-     if (canSchedule(op)) {
+   for (auto &op : funcOps) {
+     if (!scheduledOpsMap.contains(op) && canSchedule(op)) {
        scheduleableOps.push_back(op);
      }
    }

Contributor:
Nit: fix the typo in the function name (getSchedulableOps).

// with API.
//
constexpr float tensorL1UsageCap = 0.8;
bool l1UsageValid = (currentOpL1OutputUsage + nextOpL1OutputUsage) <
Contributor:
This seems to be calculating the usage for tensors that will be in L1 when nextOp is executed, but we are using the condition to check whether we want to add currentOp. Am I missing something, or is this wrong?

If so, feel free to delete all these memory checks as far as I'm concerned. They were written a long time ago, and the actual memory checking is in MLA anyway.

Contributor Author:
I agree these are not correct checks. Let me try removing them and see if something breaks.
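
For reference, a hedged illustration of the intent of the cap check as written; whether these are the right quantities to compare is exactly what is questioned above, and the usable-L1 parameter and all names here are hypothetical.

#include <cstdint>

// Keep the two output tensors below a fixed fraction of usable L1 so the
// chain can stay sharded in L1 rather than spilling to DRAM.
constexpr float kTensorL1UsageCap = 0.8f;

bool l1UsageValid(std::uint64_t currentOpL1OutputBytes,
                  std::uint64_t nextOpL1OutputBytes,
                  std::uint64_t usableL1PerCoreBytes) {
  return static_cast<float>(currentOpL1OutputBytes + nextOpL1OutputBytes) <
         kTensorL1UsageCap * static_cast<float>(usableL1PerCoreBytes);
}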

Comment on lines +170 to +171
// TODO(rpavlovicTT) After we inserted reshard in
// preprocessFirstOp we dont need to try every producerId here, right?
Contributor:
To me it seems like we do, but I'm not completely sure, as I don't have time to dig into this all the way.

Contributor Author:
Let's leave this for discussion.

@@ -296,6 +305,8 @@ bool ShardSolver::supportsInterleavedInputShardedOutput(Operation *op,
"Checking if interleaved to sharded is possible for op : {}",
op->getName());

// TODO(rpavlovicTT) this is bad as we are hardcoding this layout, while it
Contributor:
Please file an issue; the problem is not in this function but in the place where it's used.

      if (use.getOwner() != toLayoutOp) {
        use.getOwner()->setOperand(use.getOperandNumber(),
                                   toLayoutOp->getResult(0));
        // TODO(rpavlovicTT): It's possible that spilled op was followed by
Contributor:
Can we add some tests to cover some of the new cases here?

Contributor Author:
Opened issue #3750 for it.
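
As a reference for what the rewiring aims to do, a simplified plain-C++ model (not the MLIR API; all types and names are hypothetical): once a spill-to-DRAM op is inserted after the join node, every other consumer of the join node's result is re-pointed at the spill op's result so that downstream chains read the DRAM copy.

#include <vector>

// Hypothetical op model: each op records which ops feed its operands.
struct Op {
  std::vector<Op *> operands;
};

// Redirect all consumers of `joinOp` (except the spill op itself) to read
// from `spillOp` instead, mirroring the use-rewiring in the diff above.
void redirectUsersToSpill(Op &joinOp, Op &spillOp, std::vector<Op *> &allOps) {
  for (Op *user : allOps) {
    if (user == &spillOp)
      continue; // the spill op must keep reading the join node's output
    for (Op *&operand : user->operands)
      if (operand == &joinOp)
        operand = &spillOp;
  }
}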

Commit messages: the threefold change described in the PR description above, and "opened issues to follow up".
rpavlovicTT force-pushed the rpavlovic/shard_joins branch from 4bcc61d to 913d40b on June 12, 2025 07:31
azecevicTT (Contributor) left a comment:
Dialect changes LGTM.

// chains of ops that contain join nodes in fork-join structure.
// This is not general solution and we want to change it in the future.
// TODO(rpavlovicTT) https://github.com/tenstorrent/tt-mlir/issues/3744
if (schedulableOps.size() > 1) {
Contributor:
Nit: the block is fairly big; it might be better to denest by inverting the branch condition.
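
The denesting suggestion, sketched on an assumed signature (identifiers are illustrative): invert the condition and return early so the sorting block sits at the top nesting level.

#include <vector>

// Early-return form of the same check: nothing to prioritize when there is
// at most one schedulable op, so bail out before the (large) sorting block.
std::vector<int *> prioritized(std::vector<int *> schedulableOps) {
  if (schedulableOps.size() <= 1)
    return schedulableOps;

  // ... sorting heuristic, previously nested inside `if (size() > 1)` ...
  return schedulableOps;
}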

rpavlovicTT force-pushed the rpavlovic/shard_joins branch from 28b6d61 to 1113cec on June 12, 2025 13:28
rpavlovicTT merged commit 27a4f73 into main on Jun 12, 2025
61 checks passed
rpavlovicTT deleted the rpavlovic/shard_joins branch on June 12, 2025 15:07