[REQ] Add wp.tile_reshape() · Issue #663 · NVIDIA/warp · GitHub
[REQ] Add wp.tile_reshape() #663
Closed
@daedalus5

Description

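Add wp.tile_reshape(), which would work like np.reshape(): return a tile holding the same elements arranged in a new shape with the same total number of elements.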
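To make the requested semantics concrete, here is how np.reshape() behaves on a matrix the size of the tile used in the test. This is a NumPy-only sketch: it assumes wp.tile_reshape() would follow NumPy's default row-major (C) element order, which the request implies but does not state, and TILE_M = 8, TILE_N = 4 are example values only.

```python
import numpy as np

# Example tile dimensions (assumed; the issue does not give values).
TILE_M, TILE_N = 8, 4

# A TILE_M x TILE_N matrix with distinct values so the element order is visible.
a = np.arange(TILE_M * TILE_N, dtype=np.float32).reshape(TILE_M, TILE_N)

# Flatten into a single column, as the kernel in the test does.
b = np.reshape(a, (TILE_M * TILE_N, 1))

# In C order, element a[i, j] lands at row i * TILE_N + j of the column.
order_preserved = all(
    b[i * TILE_N + j, 0] == a[i, j]
    for i in range(TILE_M)
    for j in range(TILE_N)
)

# A target shape with a different element count is rejected.
size_mismatch_error = None
try:
    np.reshape(a, (5, 7))
except ValueError as err:
    size_mismatch_error = str(err)
```

A tile version would presumably need the same size check, ideally at compile time since tile shapes are static.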

Test

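import numpy as np
import warp as wp

# TILE_M, TILE_N and TILE_DIM are not defined in the original snippet;
# these are example values.
TILE_M = wp.constant(8)
TILE_N = wp.constant(4)
TILE_DIM = 64  # threads per block for launch_tiled

@wp.kernel
def test_tile_reshape_kernel(
    x: wp.array2d(dtype=float),
    y: wp.array2d(dtype=float)
):
    # load the whole (TILE_M, TILE_N) input into a tile
    a = wp.tile_load(x, shape=(TILE_M, TILE_N), offset=(0, 0))
    # reshape it into a single (TILE_M*TILE_N, 1) column tile
    b = wp.tile_reshape(a, shape=(wp.static(TILE_M * TILE_N), 1))

    wp.tile_store(y, b, offset=(0, 0))


device = "cuda:0"

x = wp.ones((TILE_M, TILE_N), dtype=float, device=device, requires_grad=True)
y = wp.zeros((TILE_M * TILE_N, 1), dtype=float, device=device, requires_grad=True)

tape = wp.Tape()
with tape:
    wp.launch_tiled(test_tile_reshape_kernel, dim=1, inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)

# seed the output adjoint and back-propagate through the reshape
y.grad = wp.ones_like(y)
tape.backward()

# a reshape only relabels elements, so y and x.grad are expected to be all ones
assert np.allclose(y.numpy(), 1.0)
assert np.allclose(x.grad.numpy(), 1.0)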
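The test seeds y.grad with ones, which cannot reveal whether gradients are routed to the right elements. A CPU reference for both passes can be sketched in NumPy. It assumes wp.tile_reshape() uses row-major order like np.reshape(); reshape_forward and reshape_backward are hypothetical helper names, not Warp API.

```python
import numpy as np

# Example tile dimensions (assumed, matching the test above).
TILE_M, TILE_N = 8, 4


def reshape_forward(x, shape):
    """Forward pass: same elements, new shape (row-major order)."""
    return x.reshape(shape)


def reshape_backward(grad_y, x_shape):
    """Adjoint of a reshape: each output gradient flows back to the input
    element it came from, which is just the reverse reshape."""
    return grad_y.reshape(x_shape)


# Mirror the test: ones in, ones as the output adjoint.
x = np.ones((TILE_M, TILE_N), dtype=np.float32)
y = reshape_forward(x, (TILE_M * TILE_N, 1))
grad_x = reshape_backward(np.ones_like(y), x.shape)

# A non-uniform adjoint shows where each gradient is routed.
grad_y_ramp = np.arange(TILE_M * TILE_N, dtype=np.float32).reshape(TILE_M * TILE_N, 1)
grad_x_ramp = reshape_backward(grad_y_ramp, x.shape)
```

Seeding y.grad with a ramp like grad_y_ramp in the Warp test, and comparing x.grad against grad_x_ramp, would be a stricter check than all ones.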
