Some Tensor operations raise errors · Issue #81 · sbrunk/storch · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content
Some Tensor operations raise errors #81
Open
@mullerhai

Description

@mullerhai

I want to write a Transformer module, and some of its submodules need to be rewritten in Scala,
like this Python code:


//class PositionalEncoding(nn.Module):
//  def __init__(self, d_model, max_len = 28 * 28):
//    super(PositionalEncoding, self).__init__()
//    self.encoding = torch.zeros(max_len, d_model)
//    position = torch.arange(0, max_len, dtype = torch.float).unsqueeze(1)
//    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
//    self.encoding[:, 0::2] = torch.sin(position * div_term)
//    self.encoding[:, 1::2] = torch.cos(position * div_term)
//    self.encoding = self.encoding.unsqueeze(0)
//
//  def forward(self, x):
//    return x + self.encoding[:, :x.size(1)].to(x.device)

But in Scala, how do I pass the correct dtype to torch.zeros() when it is called in the model class constructor (not inside apply()), and how do I select and assign tensor slices such as [:, 0::2]?

// Attempted Scala/storch port of the PyTorch PositionalEncoding module quoted above.
// NOTE(review): as written this class does not compile; the lines flagged below are
// the ones the question is asking about.
class PositionalEncoding[D <: BFloat16 | Float32 : Default](d_model: Long, max_len: Long = 28 * 28) extends HasParams[D] {

  // Shape of the encoding table: (max_len, d_model).
  val arr = Seq(max_len,d_model)
  // Zero-filled encoding table, built in the constructor rather than in apply().
  // NOTE(review): D is a type parameter, so `D.default` is not a valid term expression;
  // the dtype presumably has to come from the Default[D] context instance — confirm
  // against storch's Default API.
  var encoding = torch.zeros(size = arr.map(_.toInt), dtype = D.default)
  // Column of positions 0 until max_len (Python: arange(...).unsqueeze(1)).
  // NOTE(review): `torch.float` mirrors the PyTorch spelling; confirm the storch dtype name.
  val position = torch.arange(0, max_len, dtype = torch.float).unsqueeze(1)
  // Frequency terms exp(k * -ln(10000) / d_model) for k = 0, 2, 4, ... < d_model.
  // NOTE(review): `.float()` and `torch.tensor(10000.0)` are copied from the Python
  // version; confirm the storch equivalents exist with these signatures.
  val div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
  // Earlier attempt using Python slice syntax, kept for reference:
//  encoding[::, 0 :: 2] = torch.sin(position * div_term)
//  encoding[::, 1 :: 2] = torch.cos(position * div_term)
  // Intended: write sin(position * div_term) into the even columns
  // (Python: encoding[:, 0::2] = ...).
  // NOTE(review): in Scala `[...]` denotes type arguments, not indexing, and this line
  // also has an unmatched `)` before `=`, so it cannot parse. storch presumably exposes
  // indexing/assignment through apply/update-style methods — confirm.
  encoding[::,torch.indexing.::(0,2)] )= torch.sin(position * div_term)
  // Intended: write cos(position * div_term) into the odd columns
  // (Python: encoding[:, 1::2] = ...). Same `[...]` syntax problem as the line above.
  encoding[::, 1 :: 2] = torch.cos(position * div_term)
  // Prepend a leading dimension of size 1: (max_len, d_model) -> (1, max_len, d_model).
  encoding = encoding.unsqueeze(0)

  // Adds the positional encoding to x after moving the table to x's device.
  // NOTE(review): the Python forward slices the second dimension to x.size(1)
  // (`[:, :x.size(1)]`); `encoding[::,:: ]` does not, and uses the same unsupported
  // `[...]` syntax. Presumably x is (batch, seq_len, d_model) — confirm.
  def apply(x: Tensor[D]): Tensor[D] =
    x + encoding[::,:: ].to(x.device)
}

Thanks for your reply.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions

      0