[FEATURE] Add associative_scan support by ThomasRaoux · Pull Request #1858 · triton-lang/triton

[FEATURE] Add associative_scan support #1858


Merged · 2 commits merged into triton-lang:main on Jun 29, 2023

Conversation

ThomasRaoux (Collaborator):

Implement associative_scan in the front end and implement lowering to LLVM for the blocked layout where the scan happens along the fastest-moving dimension. This will later be generalized to support more layouts.
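For context, a minimal usage sketch of the new front-end API (assuming the tl.associative_scan(input, axis, combine_fn) signature exposed in later Triton releases; the kernel and tensor names are illustrative):

import torch
import triton
import triton.language as tl

# The combine function must be an associative binary op; addition gives a cumulative sum.
@triton.jit
def combine_add(a, b):
    return a + b

@triton.jit
def cumsum_kernel(x_ptr, y_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # Inclusive scan along the fastest-moving (and only) dimension.
    y = tl.associative_scan(x, 0, combine_add)
    tl.store(y_ptr + offs, y)

x = torch.randn(1024, device="cuda")
y = torch.empty_like(x)
cumsum_kernel[(1,)](x, y, BLOCK=1024)
assert torch.allclose(y, torch.cumsum(x, 0), atol=1e-3)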

ThomasRaoux force-pushed the scan2 branch 2 times, most recently from 317dc4a to 8091ffb on June 29, 2023 06:36
ThomasRaoux marked this pull request as ready for review on June 29, 2023 14:56
Jokeren (Contributor) commented Jun 29, 2023:

Thanks! Will review soon

// Return the number of elements per thread along non-axis dims.
unsigned getNumParallelElementsPerThread();
// Return the number of threads per warp along non-axis dims.
unsigned getNumParrallelThreadsPerWarp();
Jokeren (Contributor):

ditto

// Return the number of threads per warp along non-axis dims.
unsigned getNumParrallelThreadsPerWarp();
// Return the flat numbers of threads computing independent scan results.
unsigned getNumParrallelThreadsPerCTA();
Jokeren (Contributor):

I'm not sure what it returns, judging from the function name.

ThomasRaoux (Collaborator, Author):

hopefully the comment is explicit enough?
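To make the intended semantics concrete, a small worked example of what these helpers presumably compute (a hypothetical blocked layout; the real logic lives in the C++ scan-lowering helper, so treat the numbers as an illustration, not a spec):

import math

# Hypothetical blocked layout for a 2-D tensor; the scan runs along axis = 1,
# the fastest-moving dimension handled by this PR.
threads_per_warp = [4, 8]
warps_per_cta = [2, 2]
axis = 1

# Along the scan axis: 8 lanes per warp and 2 warps cooperate on one scan.
axis_threads_per_warp = threads_per_warp[axis]                          # 8
axis_warps = warps_per_cta[axis]                                        # 2

# Along the remaining ("parallel") dims, each thread owns independent scan rows.
parallel_threads_per_warp = math.prod(
    n for d, n in enumerate(threads_per_warp) if d != axis)             # 4
parallel_warps = math.prod(
    n for d, n in enumerate(warps_per_cta) if d != axis)                # 2

# Flat number of threads computing independent scan results in the CTA.
parallel_threads_per_cta = parallel_threads_per_warp * parallel_warps   # 8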

// Return the flat numbers of threads computing independent scan results.
unsigned getNumParrallelThreadsPerCTA();
// Return the number of warps per CTA along axis dim.
unsigned getNumAxisWarps();
Jokeren (Contributor):

getAxisNumWarps?

// Return the number of threads per warp along axis dim.
unsigned getAxisNumThreadsPerWarp();
// Return the number of blocks along axis dim.
unsigned getNumAxisBlocks();
Jokeren (Contributor):

getAxisNumBlocks?

//
def TT_ScanOp: TT_Op<"scan",
[Pure,
SameOperandsEncoding,
Jokeren (Contributor):

Does it have SameOperandsAndResultEncoding and SameOperandsAndResultElementType?

ThomasRaoux (Collaborator, Author):

good point, added it.

return std::make_tuple(laneIdAxis, warpIdAxis, flatIdParallel);
}

// Naive lowering of the scan op as a fallback for cases that we don't know
Jokeren (Contributor):

I thought emitFastScan is already not a naive lowering, because it does use warp shuffles and is not a fallback.

ThomasRaoux (Collaborator, Author):

oops yes this comment was out of date.

// reduction into shared memory. Each parallel scan and each warp will store its
// own partial reductions. The shared memory is organized as follow:
// -----------------------------------------------------------------
// chunk 0: | scan 0 warp 0 | scan 1 warp 0 | scan 0 warp 1 | scan 1 warp 1 |
Jokeren (Contributor):

It's not clear to me what scan 1 and scan 0 are.
I get the idea though after reading the code

ThomasRaoux (Collaborator, Author):

those numbers are meant to be the non-axis dimension. I improved the comment a bit. Let me know if you think it could still be clarified.
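For readers following the thread, a conceptual sketch of the hierarchy the comment describes (plain NumPy, add-scan only; the per-warp partials array stands in for the shared-memory chunks, and the indexing is an assumption based on the comment above, not the actual LLVM lowering):

import numpy as np

def block_scan(x, lanes_per_warp):
    # Inclusive add-scan over the last axis of x. Each row of x is an independent
    # ("parallel") scan; the row is split into warps of lanes_per_warp elements:
    #   1. each warp scans its own slice,
    #   2. the per-warp totals (partial reductions) are gathered per row,
    #      which is what goes through shared memory in the real lowering,
    #   3. an exclusive scan of those totals gives each warp the prefix to add.
    num_scans, n = x.shape
    num_warps = n // lanes_per_warp
    out = np.empty_like(x)
    partials = np.empty((num_scans, num_warps), dtype=x.dtype)  # shared-memory stand-in
    for s in range(num_scans):
        for w in range(num_warps):
            sl = slice(w * lanes_per_warp, (w + 1) * lanes_per_warp)
            out[s, sl] = np.cumsum(x[s, sl])   # step 1: intra-warp scan
            partials[s, w] = out[s, sl][-1]    # step 2: warp total
    for s in range(num_scans):
        prefix = np.concatenate(([0], np.cumsum(partials[s])[:-1]))  # step 3
        for w in range(num_warps):
            sl = slice(w * lanes_per_warp, (w + 1) * lanes_per_warp)
            out[s, sl] += prefix[w]
    return out

x = np.arange(2 * 64, dtype=np.int64).reshape(2, 64)
assert np.array_equal(block_scan(x, lanes_per_warp=32), np.cumsum(x, axis=1))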

ThomasRaoux requested a review from Jokeren on June 29, 2023 20:40
ThomasRaoux merged commit 3be0608 into triton-lang:main on Jun 29, 2023
@@ -0,0 +1,15 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_SCAN_OP_H
Reviewer:

nit: use #pragma once

ThomasRaoux (Collaborator, Author):

This doesn't seem to be the convention followed in the Triton project.

pingzhuu pushed a commit to siliconflow/triton that referenced this pull request Apr 2, 2024