[FEATURE] Add associative_scan support #1858
Implement associative_scan in the front end and implement lowering to LLVM for the blocked layout where the scan happens on the fastest-moving dimension. This will later be generalized to support more layouts.
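For readers unfamiliar with the operation, here is a minimal host-side sketch of what an inclusive associative scan along the fastest-moving (last) dimension computes. It is a sequential reference model only, not the lowering in this PR; the function name `inclusiveScanLastDim` and the row-major buffer layout are assumptions made for illustration.

```cpp
#include <cstddef>
#include <vector>

// Reference (sequential) inclusive scan along the last, fastest-moving
// dimension of a row-major [rows x cols] buffer. `combine` is assumed to be
// associative; it does not need to be commutative.
// Hypothetical helper for illustration, not part of the PR.
template <typename T, typename Combine>
std::vector<T> inclusiveScanLastDim(const std::vector<T> &in, std::size_t rows,
                                    std::size_t cols, Combine combine) {
  std::vector<T> out(in.size());
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < cols; ++c) {
      const std::size_t i = r * cols + c;
      // The first element of each row starts a fresh scan, so rows are
      // independent of each other (the "parallel" non-axis dimension).
      out[i] = (c == 0) ? in[i] : combine(out[i - 1], in[i]);
    }
  }
  return out;
}
```

With addition as the combine function this is a per-row cumulative sum; each row is scanned independently, which is why the non-axis dimension can be processed in parallel.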
Thanks! Will review soon.
include/triton/Analysis/Utility.h
Outdated
// Return the number of elements per thread along non-axis dims.
unsigned getNumParallelElementsPerThread();
// Return the number of threads per warp along non-axis dims.
unsigned getNumParrallelThreadsPerWarp();
ditto
include/triton/Analysis/Utility.h
Outdated
// Return the number of threads per warp along non-axis dims.
unsigned getNumParrallelThreadsPerWarp();
// Return the flat numbers of threads computing independent scan results.
unsigned getNumParrallelThreadsPerCTA();
I can't tell what this returns from the function name alone.
hopefully the comment is explicit enough?
include/triton/Analysis/Utility.h
Outdated
// Return the flat numbers of threads computing independent scan results.
unsigned getNumParrallelThreadsPerCTA();
// Return the number of warps per CTA along axis dim.
unsigned getNumAxisWarps();
getAxisNumWarps?
include/triton/Analysis/Utility.h
Outdated
// Return the number of threads per warp along axis dim.
unsigned getAxisNumThreadsPerWarp();
// Return the number of blocks along axis dim.
unsigned getNumAxisBlocks();
getAxisNumBlocks?
//
def TT_ScanOp: TT_Op<"scan",
[Pure,
SameOperandsEncoding,
Does it have SameOperandsAndResultEncoding and SameOperandsAndResultElementType?
good point, added it.
return std::make_tuple(laneIdAxis, warpIdAxis, flatIdParallel);
}

// Naive lowering of the scan op as a fallback for cases that we don't know
I thought emitFastScan is not a naive lowering, since it uses warp shuffles and is not a fallback.
oops yes this comment was out of date.
// reduction into shared memory. Each parallel scan and each warp will store its
// own partial reductions. The shared memory is organized as follow:
// -----------------------------------------------------------------
// chunk 0: | scan 0 warp 0 | scan 1 warp 0 | scan 0 warp 1 | scan 1 warp 1 |
It's not clear to me what scan 1 and scan 0 are.
I get the idea though after reading the code
Those numbers are meant to be indices along the non-axis dimension. I improved the comment a bit. Let me know if you think it could still be clarified.
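To make the diagram concrete, the slot ordering it shows for chunk 0 can be read as an index computation in which the parallel (non-axis) scan index varies fastest, followed by the warp index along the axis. The helper below is a hypothetical sketch of that reading. Its name and parameters are not from the PR, and the assumption that later chunks follow one another contiguously goes beyond what the quoted comment shows.

```cpp
#include <cstddef>

// Hypothetical helper mirroring the diagram: within a chunk, the parallel
// (non-axis) scan index varies fastest, then the warp index along the axis.
// Chunks are assumed to be laid out back to back.
inline std::size_t partialReduceSlot(std::size_t chunk,
                                     std::size_t warpAlongAxis,
                                     std::size_t parallelScan,
                                     std::size_t numAxisWarps,
                                     std::size_t numParallelScans) {
  const std::size_t slotsPerChunk = numAxisWarps * numParallelScans;
  return chunk * slotsPerChunk + warpAlongAxis * numParallelScans +
         parallelScan;
}
```

With two warps along the axis and two parallel scans, "scan 1 warp 0" maps to slot 1 and "scan 0 warp 1" maps to slot 2, matching the left-to-right order in the comment.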
@@ -0,0 +1,15 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_SCAN_OP_H
nit: use #pragma once
This doesn't seem to be the convention followed in the Triton project.