Shard join nodes #3642
Conversation
Force-pushed from 3f453bc to 445a98c.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #3642      +/-   ##
==========================================
- Coverage   72.29%   72.26%   -0.04%
==========================================
  Files         213      213
  Lines       29759    29788      +29
==========================================
+ Hits        21513    21525      +12
- Misses       8246     8263      +17

☔ View full report in Codecov by Sentry.
19758cb
to
16b20e6
Compare
include/ttmlir/Scheduler/Scheduler.h (Outdated)

    // Map of dependencies
    // Vector of all operations in deterministic order.
    llvm::SmallVector<mlir::Operation *> funcOps;
Nit: the name is a bit misleading. It suggests a vector of FuncOps, but it actually holds all ops inside one FuncOp.
lib/Scheduler/Scheduler.cpp (Outdated)

      llvm::SmallVector<mlir::Operation *> Scheduler::getScheduleableOps() {
        llvm::SmallVector<mlir::Operation *> scheduleableOps;
    -   for (auto &op : unscheduledOps) {
    -     if (canSchedule(op)) {
    +   for (auto &op : funcOps) {
Just to confirm, with this line the schedule will now always be deterministic?
Yes
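For reference, a minimal sketch of why iterating a pre-ordered vector yields a deterministic schedule. All names here are illustrative stand-ins (plain std containers, string "ops"), not the actual tt-mlir types:

```cpp
#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-ins for the scheduler state: `funcOps` holds ops in the
// deterministic order they appear in the function; `scheduled` tracks ops
// that have already been scheduled.
using Op = std::string;

std::vector<Op> getSchedulableOps(const std::vector<Op> &funcOps,
                                  const std::unordered_set<Op> &scheduled) {
  std::vector<Op> result;
  // Iterating the pre-ordered vector, instead of an unordered container of
  // unscheduled ops, makes the returned order identical from run to run.
  for (const Op &op : funcOps) {
    if (scheduled.count(op) == 0) {
      result.push_back(op);
    }
  }
  return result;
}
```

The key point is that the iteration order now comes from the vector built in program order, so the only nondeterministic container (the scheduled set) is used purely for membership tests.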
    // We will sort schedulable ops by prioritizing ops whose successors are still
    // blocked after scheduling it. This is a heuristic that lets us create longer
    // chains of ops that contain join nodes in fork-join structure.
Thinking about this once more, does this heuristic improve performance in general though? It seems to me like the length of the chain is not the important metric here, and that we are left with the same number of tensors going to DRAM either way?
If that's the case but it works better for resnet, we don't have to change it now. Let's leave a comment and file an issue for this though.
Also, I think that this sorting logic should be in the policy and not in the scheduler. The idea is that the scheduler provides utilities for creating a legal op schedule. The actual logic of how we pick the next op should be within the policies, as different policies might want to do different things.
In general it doesn't improve, at least from what I can assume right now. But in resnet we get better sharding (9 more ops are sharded).
I'd leave the sort as is, and leave a comment & issue that we should generalise this algo. Also, I think it's much easier to leave the logic living in the Scheduler even though it's not the right place, just because sorting is much easier given the state the Scheduler holds (dependencies between ops).
Most probably because longer branches come to the join node as the LHS operand, thus we succeed in finding valid sharding layouts, while shorter branches come as the RHS. If we pick ops randomly, we risk creating chains with shorter branches, which ruins our sharding. Of course, this is super unstable.
Opened: #3744
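A minimal sketch of the heuristic being debated, using plain std containers rather than the actual tt-mlir code. The blocked-successor counts would come from the scheduler's dependency state; here they are passed in as a hypothetical precomputed map:

```cpp
#include <algorithm>
#include <map>
#include <string>
#include <vector>

// Illustrative sketch only: among schedulable ops, schedule first the ops
// whose successors would still be blocked after scheduling them. This drains
// shorter fork branches early, so longer chains through join nodes can form.
using Op = std::string;

std::vector<Op>
prioritizeBlockedSuccessors(std::vector<Op> schedulable,
                            const std::map<Op, int> &blockedSuccessorCount) {
  std::stable_sort(schedulable.begin(), schedulable.end(),
                   [&](const Op &a, const Op &b) {
                     // Ops with more still-blocked successors come first;
                     // stable_sort preserves the deterministic input order
                     // for ties.
                     return blockedSuccessorCount.at(a) >
                            blockedSuccessorCount.at(b);
                   });
  return schedulable;
}
```

Using a stable sort matters here: combined with the deterministic iteration order discussed earlier, ties do not reintroduce run-to-run variation.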
@@ -85,126 +85,129 @@ void DFShardingPolicy::run() {
        }
      }

      if (nextOp) {
        bool validForSharding =
Leave a comment explaining what the condition here checks.
      if (nextOp) {
        bool validForSharding =
            legalConfigs.lookup(currentOp).size() > 0 &&
            (nextOp ? legalConfigs.lookup(nextOp).size() > 0 : true);
Is this part of the condition needed? Why does the next op need to have a sharded config?
I think it's not needed.
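A sketch of the simplification implied by this exchange. Inside an `if (nextOp)` block, `nextOp` is known to be non-null, so the `nextOp ? ... : true` ternary is redundant. The integer config counts are hypothetical stand-ins for `legalConfigs.lookup(op).size()`:

```cpp
// Illustrative only: equivalent to the original
//   legalConfigs.lookup(currentOp).size() > 0 &&
//   (nextOp ? legalConfigs.lookup(nextOp).size() > 0 : true)
// when nextOp is non-null. The reviewers further conclude the nextOp half
// of the check may not be needed at all.
bool validForSharding(int currentOpConfigCount, int nextOpConfigCount) {
  return currentOpConfigCount > 0 && nextOpConfigCount > 0;
}
```

If the nextOp half is dropped as suggested, the whole condition collapses to `legalConfigs.lookup(currentOp).size() > 0`.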
    -   for (auto &op : unscheduledOps) {
    -     if (canSchedule(op)) {
    +   for (auto &op : funcOps) {
    +     if (!scheduledOpsMap.contains(op) && canSchedule(op)) {
            scheduleableOps.push_back(op);
          }
        }
Nit: fix typo in the function name: getSchedulableOps.
    // with API.
    //
    constexpr float tensorL1UsageCap = 0.8;
    bool l1UsageValid = (currentOpL1OutputUsage + nextOpL1OutputUsage) <
This seems to be calculating the usage for tensors that will be in L1 when nextOp is executed, but we are using the condition to check whether we want to add currentOp. Am I missing something, or is this wrong?
If so, feel free to delete all these memory checks as far as I'm concerned. These were written a long time ago, and the actual memory checking is in MLA anyway.
I agree these are not correct checks. Let me try removing them and see if something breaks.
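For context, a sketch of the check under discussion (which the thread concludes is likely incorrect and may be deleted): the combined output usage of currentOp and nextOp must stay under a fraction of available L1. The 0.8 cap and variable names mirror the excerpt; the function signature and parameters are hypothetical stand-ins:

```cpp
// Illustrative sketch of the (possibly wrong, per the review) L1 check.
bool l1UsageValid(float currentOpL1OutputUsage, float nextOpL1OutputUsage,
                  float availableL1CacheSize) {
  // Cap combined tensor usage at 80% of L1, leaving headroom for other
  // allocations.
  constexpr float tensorL1UsageCap = 0.8f;
  return (currentOpL1OutputUsage + nextOpL1OutputUsage) <
         tensorL1UsageCap * availableL1CacheSize;
}
```

The reviewer's objection is not to the arithmetic but to what it models: it estimates L1 occupancy at nextOp's execution, yet gates the decision about currentOp.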
    // TODO(rpavlovicTT) After we inserted reshard in
    // preprocessFirstOp we dont need to try every producerId here, right?
To me it seems like we do, but I'm not completely sure, as I don't have time to dig into this all the way.
Let's leave this for discussion.
@@ -296,6 +305,8 @@ bool ShardSolver::supportsInterleavedInputShardedOutput(Operation *op,
        "Checking if interleaved to sharded is possible for op : {}",
        op->getName());

    // TODO(rpavlovicTT) this is bad as we are hardcoding this layout, while it
File an issue please; the problem is not in this function but in the place where it's used.
    if (use.getOwner() != toLayoutOp) {
      use.getOwner()->setOperand(use.getOperandNumber(),
                                 toLayoutOp->getResult(0));
    // TODO(rpavlovicTT): It's possible that spilled op was followed by
Can we add some tests to cover some of the new cases here?
Opened issue #3750 for it.
Force-pushed from 1b8058e to 4bcc61d.
Three-fold change in this commit:
1. Allow sharding join nodes in fork-join structure.
2. Change the scheduler to prioritize "shorter" branches first. This will enable us to create longer chains including join nodes.
3. Fix connections after creating a spill-to-DRAM op following a join node. This node is followed by other chains, and it's possible to skip reading from DRAM in some paths.
Piggyback: re-enable sharding of eltwise binary ops.
Opened issues to follow up.
Force-pushed from 4bcc61d to 913d40b.
Dialect changes LGTM.
lib/Scheduler/Scheduler.cpp (Outdated)

    // chains of ops that contain join nodes in fork-join structure.
    // This is not general solution and we want to change it in the future.
    // TODO(rpavlovicTT) https://github.com/tenstorrent/tt-mlir/issues/3744
    if (schedulableOps.size() > 1) {
Nit: block is fairly big, it might be better to denest by inverting branch condition.
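A small sketch of the denesting suggestion: invert the `schedulableOps.size() > 1` condition and return early, so the large heuristic body sits at one indentation level. The sort stands in for the real prioritization logic; all names are illustrative:

```cpp
#include <algorithm>
#include <vector>

// Illustrative only: early-return on the inverted condition instead of
// wrapping the whole heuristic in `if (schedulableOps.size() > 1) { ... }`.
void prioritizeSchedulableOps(std::vector<int> &schedulableOps) {
  if (schedulableOps.size() <= 1) {
    return; // nothing to prioritize
  }
  // ... formerly nested heuristic body, now at one indentation level ...
  std::sort(schedulableOps.begin(), schedulableOps.end());
}
```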
Force-pushed from 28b6d61 to 1113cec.
Related to #2276