Point Transformer - Pytorch

Implementation of the Point Transformer self-attention layer, in Pytorch. This simple vector-attention layer seems to be what allowed their group to outperform all previous methods in point cloud classification and segmentation.

Install

$ pip install point-transformer-pytorch

Usage

import torch
from point_transformer_pytorch import PointTransformerLayer

attn = PointTransformerLayer(
    dim = 128,
    pos_mlp_hidden_dim = 64,
    attn_mlp_hidden_mult = 4
)

feats = torch.randn(1, 16, 128)
pos = torch.randn(1, 16, 3)
mask = torch.ones(1, 16).bool()

attn(feats, pos, mask = mask) # (1, 16, 128)
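The mask lets you batch point clouds of different sizes by padding them to a common length. A minimal sketch, assuming mask entries are True for valid points and False for padding (the padding scheme here is illustrative, not part of the library):

import torch
from point_transformer_pytorch import PointTransformerLayer

attn = PointTransformerLayer(
    dim = 128,
    pos_mlp_hidden_dim = 64,
    attn_mlp_hidden_mult = 4
)

# two point clouds with 10 and 16 points, padded to a common length of 16
feats = torch.randn(2, 16, 128)
pos = torch.randn(2, 16, 3)

mask = torch.zeros(2, 16).bool()
mask[0, :10] = True   # only the first 10 points of the first cloud are valid
mask[1, :] = True     # all 16 points of the second cloud are valid

attn(feats, pos, mask = mask) # (2, 16, 128)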

This type of vector attention is much more expensive than the traditional one. In the paper, they used k-nearest neighbors on the points to exclude attention on faraway points. You can do the same with a single extra setting.

import torch
from point_transformer_pytorch import PointTransformerLayer

attn = PointTransformerLayer(
    dim = 128,
    pos_mlp_hidden_dim = 64,
    attn_mlp_hidden_mult = 4,
    num_neighbors = 16          # only the 16 nearest neighbors would be attended to for each point
)

feats = torch.randn(1, 2048, 128)
pos = torch.randn(1, 2048, 3)
mask = torch.ones(1, 2048).bool()

attn(feats, pos, mask = mask) # (1, 2048, 128)
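The layer above only computes the attention itself. In the paper, it sits inside a residual block with linear projections before and after. A minimal sketch of such a block, assuming standard PyTorch modules; the class name and dimensions here are illustrative and not part of this package:

import torch
from torch import nn
from point_transformer_pytorch import PointTransformerLayer

class PointTransformerBlock(nn.Module):
    # residual block in the style of the paper: linear -> attention -> linear, plus a skip connection
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.proj_in = nn.Linear(dim, hidden_dim)
        self.attn = PointTransformerLayer(
            dim = hidden_dim,
            pos_mlp_hidden_dim = 64,
            attn_mlp_hidden_mult = 4,
            num_neighbors = 16
        )
        self.proj_out = nn.Linear(hidden_dim, dim)

    def forward(self, feats, pos, mask = None):
        x = self.proj_in(feats)
        x = self.attn(x, pos, mask = mask)
        return feats + self.proj_out(x)

block = PointTransformerBlock(dim = 32, hidden_dim = 128)
feats = torch.randn(1, 2048, 32)
pos = torch.randn(1, 2048, 3)

block(feats, pos) # (1, 2048, 32)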

Citations

@misc{zhao2020point,
    title={Point Transformer}, 
    author={Hengshuang Zhao and Li Jiang and Jiaya Jia and Philip Torr and Vladlen Koltun},
    year={2020},
    eprint={2012.09164},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
