Open
Description
I wanna write transformer module ,some subclass need rewrite
like python
//class PositionalEncoding(nnModule
//):
// def __init__(self, d_model, max_len = 28 * 28):
// super
// (PositionalEncoding, self).__init__()
// self.encoding = torch.zeros(max_len, d_model)
// position = torch.arange(0, max_len, dtype = torch.float).unsqueeze(1)
// div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
// self.encoding[:
// , 0 :: 2
// ] = torch.sin(position * div_term)
// self.encoding[:
// , 1 :: 2
// ] = torch.cos(position * div_term)
// self.encoding = self.encoding.unsqueeze(0)
//
// def forward(self, x):
// return x + self.encoding[:
// ,: x.size(1)
// ].to(x.device)
but in scala ,torch.zeros() dtype how to pass correct dtype on Model class init not in apply(),and tensor slice select
class PositionalEncoding[D <: BFloat16 | Float32 : Default](d_model: Long, max_len: Long = 28 * 28) extends HasParams[D] {
val arr = Seq(max_len,d_model)
var encoding = torch.zeros(size = arr.map(_.toInt), dtype = D.default)
val position = torch.arange(0, max_len, dtype = torch.float).unsqueeze(1)
val div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
// encoding[::, 0 :: 2] = torch.sin(position * div_term)
// encoding[::, 1 :: 2] = torch.cos(position * div_term)
encoding[::,torch.indexing.::(0,2)] )= torch.sin(position * div_term)
encoding[::, 1 :: 2] = torch.cos(position * div_term)
encoding = encoding.unsqueeze(0)
def apply(x: Tensor[D]): Tensor[D] =
x + encoding[::,:: ].to(x.device)
}
thanks for your reply
Metadata
Metadata
Assignees
Labels
No labels