Description
In the function below, I am not sure about the shapes of the tensors:
```python
def compute_joint_tracking_error(
    joint_pos: torch.Tensor, joint_pos_gt: torch.Tensor, frame_weights: torch.Tensor, num_envs: int
) -> float:
    """Compute weighted mean absolute joint position error across environments.

    For each environment:
    1. Take absolute difference between predicted and ground truth joint positions
    2. Weight the differences by frame_weights to normalize across varying trajectory lengths
    3. Take mean across joints

    Finally, sum across environments and divide by num_envs for mean error.
    """
    return torch.sum(torch.mean(torch.abs(joint_pos - joint_pos_gt), dim=1) * frame_weights).item() / num_envs
```
I understand that joint_pos has shape (num_envs * num_frames, num_links, 3) (e.g., for body positions), and frame_weights has shape (num_envs * num_frames). However, after taking the mean over joints (dim=1), the resulting tensor has shape (num_envs * num_frames, 3). Multiplying this by frame_weights (which is 1D) should cause a broadcasting error unless frame_weights is explicitly reshaped.
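For example, with hypothetical sizes (num_envs * num_frames = 4, num_links = 5), I would expect something like this to fail:

```python
import torch

# Hypothetical sizes: num_envs * num_frames = 4, num_links = 5
joint_pos = torch.randn(4, 5, 3)
joint_pos_gt = torch.randn(4, 5, 3)
frame_weights = torch.rand(4)

per_frame = torch.mean(torch.abs(joint_pos - joint_pos_gt), dim=1)  # shape (4, 3)
# Multiplying (4, 3) by (4,) should raise a RuntimeError, since the trailing
# dimensions (3 and 4) are not broadcastable
weighted = per_frame * frame_weights
```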
Also, the final return value should be a scalar float, but since the mean is taken over joints only, the weighted tensor would still carry a trailing dimension of size 3 before the final summation, which does not match the per-frame weighting described in the docstring.
This leads me to think that joint_pos should instead be reshaped to (num_envs * num_frames, num_links * 3) before computing the mean, so that the mean is taken over all joint coordinates at once, resulting in a 1D tensor compatible with frame_weights.
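Here is a minimal sketch of the reshape I have in mind, using the same hypothetical shapes as above (not the actual repo code):

```python
# Sketch only, continuing the hypothetical shapes from above
num_envs = 2  # hypothetical: 2 envs x 2 frames = 4 rows

joint_pos_flat = joint_pos.reshape(joint_pos.shape[0], -1)        # (N, num_links * 3)
joint_pos_gt_flat = joint_pos_gt.reshape(joint_pos_gt.shape[0], -1)

per_frame = torch.mean(torch.abs(joint_pos_flat - joint_pos_gt_flat), dim=1)  # (N,)
error = torch.sum(per_frame * frame_weights).item() / num_envs  # scalar float
```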
My question is: Where in the code does this reshape happen for body_pos?
PS: I am not able to run the code, which is why I cannot test this myself.