[Feature] Linear scan functionality · Issue #1997 · triton-lang/triton · GitHub


Closed
bohnstingl opened this issue Jul 27, 2023 · 2 comments

Comments

@bohnstingl

Would it be possible to add linear scan functionality? There is currently an associative_scan function, but associativity is not always guaranteed, for example in RNNs.
Could the associative_scan functionality be adapted to traverse the axis sequentially instead of parallelizing over the sequence?
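For context: a parallel scan is only guaranteed to be correct when its combine function is associative, because the implementation is free to regroup operands. A purely linear recurrence such as `c[t] = decay * c[t-1] + u[t]` can still be made associative by treating each step as an affine map `(a, b)` meaning `c -> a*c + b`, since composing affine maps is associative. This is a standard lifting trick, not something the issue or Triton prescribes. The plain-Python sketch below (all helper names are hypothetical) checks that the lifted scan matches the step-by-step recurrence. Once a nonlinear `f` feeds back into the next step, as in the RNN equations that follow, the step is no longer an affine map and this trick does not apply.

```python
def compose(first, second):
    """Combine two affine steps c -> a*c + b, applying `first` and then `second`."""
    a1, b1 = first
    a2, b2 = second
    # second(first(c)) = a2*(a1*c + b1) + b2 = (a2*a1)*c + (a2*b1 + b2)
    return (a2 * a1, a2 * b1 + b2)


def sequential_recurrence(decay, inputs, c0=0.0):
    """Reference: c[t] = decay * c[t-1] + u[t], evaluated one step at a time."""
    out, c = [], c0
    for u in inputs:
        c = decay * c + u
        out.append(c)
    return out


def scan_via_affine(decay, inputs, c0=0.0):
    """Inclusive scan over lifted (a, b) steps; each prefix map is applied to c0."""
    out, acc = [], None
    for step in ((decay, u) for u in inputs):
        acc = step if acc is None else compose(acc, step)
        a, b = acc
        out.append(a * c0 + b)
    return out


print(scan_via_affine(0.5, [1.0, 2.0, 3.0, 4.0]))  # [1.0, 2.5, 4.25, 6.125]
```

Because `compose` is associative, the prefixes could be computed in any grouping (and hence in parallel) without changing the result.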

A concrete example where this is needed is a generic RNN with the following equations:
x -> input (time x batch x features)

carry1[t] = W @ x[t] + H @ carry2[t-1] + decay * carry1[t-1]
carry2[t] = f(carry1[t])

where f is an arbitrary (in general nonlinear) function.
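A minimal plain-Python reference for these equations may help pin down the semantics. It assumes zero initial carries, a scalar `decay`, an elementwise `f` (defaulting to `math.tanh`, an illustrative choice), and weight shapes `W: (hidden, in_features)` and `H: (hidden, hidden)`; none of these details are stated in the issue. The time loop must run in order, while each batch row is computed independently.

```python
import math


def matvec(M, v):
    """Dense matrix-vector product on nested lists."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


def rnn_reference(x, W, H, decay, f=math.tanh):
    """Sequential reference for
        carry1[t] = W @ x[t] + H @ carry2[t-1] + decay * carry1[t-1]
        carry2[t] = f(carry1[t])
    x has shape (time, batch, in_features). Assumes zero initial carries.
    Returns carry2 for every step, shape (time, batch, hidden).
    """
    T, B = len(x), len(x[0])
    hidden = len(W)
    c1 = [[0.0] * hidden for _ in range(B)]
    c2 = [[0.0] * hidden for _ in range(B)]
    outputs = []
    for t in range(T):              # time has to be walked in order
        step = []
        for b in range(B):          # batch rows do not interact
            wx = matvec(W, x[t][b])
            hc = matvec(H, c2[b])   # uses carry2[t-1]
            c1[b] = [wx_i + hc_i + decay * c_i
                     for wx_i, hc_i, c_i in zip(wx, hc, c1[b])]
            c2[b] = [f(v) for v in c1[b]]
            step.append(list(c2[b]))
        outputs.append(step)
    return outputs
```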

Thank you in advance for considering this.

cc @ThomasRaoux

@ThomasRaoux
Collaborator

If the operation is not associative, how would you distribute it across the threads of a GPU block?
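To illustrate the concern, the sketch below compares a sequential scan with a Hillis-Steele style parallel scan, one common scheme in which every round combines element `i` with element `i - offset` and all elements of a round could be handled by different threads. This is not necessarily the algorithm Triton's associative_scan uses internally. With an associative operator such as addition both orders agree; with the naive RNN-style step `op(c, u) = 0.5 * c + u` (which is not associative as a binary operator) the regrouping silently produces different results.

```python
def sequential_scan(op, xs):
    """Inclusive left-to-right scan: out[i] = op(out[i-1], xs[i])."""
    out = [xs[0]]
    for x in xs[1:]:
        out.append(op(out[-1], x))
    return out


def hillis_steele_scan(op, xs):
    """Inclusive scan in ceil(log2(n)) rounds. Each round reads only the previous
    round's values, so all positions could be updated in parallel. The result
    matches sequential_scan only if `op` is associative."""
    ys = list(xs)
    offset = 1
    while offset < len(ys):
        ys = [ys[i] if i < offset else op(ys[i - offset], ys[i])
              for i in range(len(ys))]
        offset *= 2
    return ys


add = lambda a, b: a + b
decay_step = lambda c, u: 0.5 * c + u   # the recurrence used as a plain binary op

xs = [1, 2, 3, 4]
print(sequential_scan(add, xs), hillis_steele_scan(add, xs))                # [1, 3, 6, 10] [1, 3, 6, 10]
print(sequential_scan(decay_step, xs), hillis_steele_scan(decay_step, xs))  # [1, 2.5, 4.25, 6.125] [1, 2.5, 4.5, 6.75]
```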

If the idea is to parallelize the scan only along a "batch" dimension, then this can already be done with the existing Triton language.
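In a Triton kernel this pattern would typically mean launching a grid of size `batch`, selecting the row with `tl.program_id(0)`, and writing the time loop as an ordinary `for` loop inside the `@triton.jit` function. The sketch below only simulates that decomposition on the CPU in plain Python: `rnn_program` stands in for one program instance, and the scalar weights `w` and `h` (i.e. diagonal `W` and `H`) and `tanh` are simplifying assumptions, not part of the issue.

```python
import math
from concurrent.futures import ThreadPoolExecutor


def rnn_program(pid, x, w, h, decay):
    """Work of one hypothetical program instance: batch row `pid`.
    x has shape (time, batch, features); weights are elementwise for brevity."""
    n_feat = len(x[0][pid])
    c1 = [0.0] * n_feat
    c2 = [0.0] * n_feat
    out = []
    for t in range(len(x)):  # sequential over time, inside the program
        c1 = [w * xi + h * c2i + decay * c1i
              for xi, c2i, c1i in zip(x[t][pid], c2, c1)]
        c2 = [math.tanh(v) for v in c1]
        out.append(c2)
    return out


def launch(x, w, h, decay):
    """Run one program per batch row; rows never communicate, so they can run concurrently."""
    batch = len(x[0])
    with ThreadPoolExecutor() as pool:
        rows = list(pool.map(lambda pid: rnn_program(pid, x, w, h, decay), range(batch)))
    # rows[b][t] -> reorder to (time, batch, features)
    return [[rows[b][t] for b in range(batch)] for t in range(len(x))]
```

Parallelism comes entirely from the batch dimension; the recurrence over time stays a plain loop, so associativity is never required.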

@bohnstingl
Author

Right, I just realized that this boils down to a simple for-loop that applies a given function at every iteration.
I apologize for the confusion and will close this issue.
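The "simple for-loop" can be written as a generic sequential scan combinator. The sketch below borrows the calling convention of `jax.lax.scan` (a step function mapping `(carry, x)` to `(new_carry, y)`) purely as a familiar shape; `linear_scan` and `make_rnn_step` are hypothetical names, and the scalar RNN step is a simplified version of the equations in the issue with `f = tanh` assumed.

```python
import math
from typing import Callable, Iterable, List, Tuple, TypeVar

C = TypeVar("C")
X = TypeVar("X")
Y = TypeVar("Y")


def linear_scan(step: Callable[[C, X], Tuple[C, Y]], init: C,
                xs: Iterable[X]) -> Tuple[C, List[Y]]:
    """Apply `step` to each element in order, threading the carry through.
    No associativity is assumed, so any step function is allowed."""
    carry = init
    ys = []
    for x in xs:
        carry, y = step(carry, x)
        ys.append(y)
    return carry, ys


def make_rnn_step(w, h, decay):
    """Scalar variant of carry1/carry2 from the issue; w, h, decay are illustrative."""
    def step(carry, x):
        c1_prev, c2_prev = carry
        c1 = w * x + h * c2_prev + decay * c1_prev
        c2 = math.tanh(c1)
        return (c1, c2), c2
    return step


final_carry, outputs = linear_scan(make_rnn_step(0.8, 0.5, 0.9), (0.0, 0.0), [1.0, 0.0, -1.0])
```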
