First of all, thank you for sharing your code.
While reading through the source, I found what may be errors in it, including in the squash function.
Problem 1:
Following the link in your README, I read the TensorFlow source code. Its squash function has this docstring:
Squashing function corresponding to Eq. 1
Args:
vector: A tensor with shape [batch_size, 1, num_caps, vec_len, 1] or [batch_size, num_caps, vec_len, 1].
Returns:
A tensor with the same shape as vector but squashed in 'vec_len' dimension.
According to this comment, the squashing is applied along the vec_len dimension. But in your code:
def squash(s):
    # This is equation 1 from the paper.
    mag_sq = torch.sum(s**2, dim=2, keepdim=True)
    mag = torch.sqrt(mag_sq)
    s = (mag_sq / (1.0 + mag_sq)) * (s / mag)
    return s
Since the function has no comment describing the shape of its input, I looked at the place where it is called:
# Flatten to (batch, unit, output).
u = u.view(x.size(0), self.num_units, -1)
# Return squashed outputs.
return CapsuleLayer.squash(u)
Here u has shape (batch, unit, output), so it seems clear that the squashing should be applied along dim=1, not dim=2.
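To show what I mean, here is a minimal sketch (not your implementation) of a squash variant that takes the axis as a parameter. The `dim` argument, the small epsilon, and the tensor sizes (8 units, 1152 outputs) are my own assumptions for illustration. If the unit axis holds the capsule vector, squashing along dim=1 is what keeps the length of each capsule vector below 1, as in Eq. 1.

```python
import torch

def squash(s, dim):
    # Eq. 1 of the paper, applied along the axis given by `dim`.
    mag_sq = torch.sum(s ** 2, dim=dim, keepdim=True)
    # The epsilon is my addition; it avoids dividing by zero for all-zero vectors.
    mag = torch.sqrt(mag_sq + 1e-8)
    return (mag_sq / (1.0 + mag_sq)) * (s / mag)

torch.manual_seed(0)
# Hypothetical sizes: batch=2, unit=8, output=1152.
u = torch.randn(2, 8, 1152) * 3

v1 = squash(u, dim=1)  # squash along the unit axis
v2 = squash(u, dim=2)  # squash along the output axis, as in the current code

# Length of each capsule vector, measured along the unit axis.
print(v1.norm(dim=1).max())  # stays below 1 by construction
print(v2.norm(dim=1).max())  # not controlled by Eq. 1 in this case
```

With dim=1, the length of every vector along the unit axis equals mag_sq / (1 + mag_sq) for that vector. With dim=2, it is the length along the output axis that gets normalized instead.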
Problem 2:
# (batch, features, in_units) -> (batch, features, num_units, in_units, 1)
x = torch.stack([x] * self.num_units, dim=2).unsqueeze(4)
# (batch, features, in_units, unit_size, num_units)
W = torch.cat([self.W] * batch_size, dim=0)
# Transform inputs by weight matrix.
# (batch_size, features, num_units, unit_size, 1)
u_hat = torch.matmul(W, x)
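How can x with shape (batch, features, num_units, in_units, 1) and W with shape (batch, features, in_units, unit_size, num_units) be multiplied with torch.matmul? The last two dimensions, (unit_size, num_units) and (in_units, 1), only line up if num_units equals in_units.
I could not run your code successfully because I do not have the data, so I do not know whether it actually works.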
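A shape check like the following does not need the dataset. This is only a sketch under my own assumptions: the sizes are made up (chosen so that in_units differs from num_units), and the layout of `W_alt` is my guess at a shape that would make the multiplication valid, not something I have confirmed in your code. torch.matmul multiplies the last two dimensions as matrices and broadcasts the leading ones.

```python
import torch

# Hypothetical sizes, with in_units != num_units on purpose.
batch, features, in_units, num_units, unit_size = 2, 5, 8, 10, 16

# x as described in the comment: (batch, features, num_units, in_units, 1)
x = torch.randn(batch, features, num_units, in_units, 1)

# W with the shape written in the comment:
# (batch, features, in_units, unit_size, num_units)
W_commented = torch.randn(batch, features, in_units, unit_size, num_units)
try:
    torch.matmul(W_commented, x)
    commented_ok = True
except RuntimeError:
    # (unit_size, num_units) @ (in_units, 1) does not match when
    # num_units != in_units, and the leading dims do not broadcast either.
    commented_ok = False

# Assumed alternative layout: (batch, features, num_units, unit_size, in_units)
W_alt = torch.randn(batch, features, num_units, unit_size, in_units)
u_hat = torch.matmul(W_alt, x)

print(commented_ok)
print(tuple(u_hat.shape))
```

If the second layout is the real one, the result has the shape (batch_size, features, num_units, unit_size, 1) given in the comment above u_hat, and only the comment above W would need fixing.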
Best Wishes!!