Hi torchrec developers. I found a fatal bug when sharding an EmbeddingCollection. The full stack trace is:
[rank0]: File "/home/admin/hippo/worker/slave/aop_418921_aop_launcher_job_temp_m_20250528093245_6524584_job.worker_0_57_12/train/test_ebd.py", line 44, in <module>
[rank0]: main()
[rank0]: File "/home/admin/hippo/worker/slave/aop_418921_aop_launcher_job_temp_m_20250528093245_6524584_job.worker_0_57_12/train/test_ebd.py", line 36, in main
[rank0]: dmp = DistributedModelParallel(module = ec,
[rank0]: File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torchrec/distributed/model_parallel.py", line 278, in __init__
[rank0]: self._dmp_wrapped_module: nn.Module = self._init_dmp(module)
[rank0]: File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torchrec/distributed/model_parallel.py", line 343, in _init_dmp
[rank0]: return self._shard_modules_impl(module)
[rank0]: File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torchrec/distributed/model_parallel.py", line 381, in _shard_modules_impl
[rank0]: module = self._sharder_map[sharder_key].shard(
[rank0]: File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torchrec/distributed/embedding.py", line 1372, in shard
[rank0]: return ShardedEmbeddingCollection(
[rank0]: File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torchrec/distributed/embedding.py", line 632, in __init__
[rank0]: self.load_state_dict(module.state_dict())
[rank0]: File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2581, in load_state_dict
[rank0]: raise RuntimeError(
[rank0]: RuntimeError: Error(s) in loading state_dict for ShardedEmbeddingCollection:
[rank0]: While copying the parameter named "embeddings.t1.weight", whose dimensions in the model are torch.Size([625000, 16]) and whose dimensions in the checkpoint are torch.Size([625000, 16]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
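The underlying failure is PyTorch's generic meta-tensor restriction: load_state_dict ends up calling copy_() with a source parameter that lives on the meta device and therefore has no storage. Here is a minimal standalone illustration, independent of torchrec (the tensor names are mine, for illustration only):

import torch

meta_t = torch.empty(2, 2, device="meta")  # shape/dtype only, no storage behind it
real_t = torch.empty(2, 2)
try:
    real_t.copy_(meta_t)  # same failure mode as load_state_dict copying from a meta parameter
except Exception as e:
    print(e)  # Cannot copy out of meta tensor; no data!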
To reproduce it, run the following script with the command torchrun --standalone --nnodes 1 --node_rank 0 --nproc_per_node 8 test_ebd.py. test_ebd.py is attached below:
import os

import torch
import torch.distributed as dist
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
from torchrec.distributed.model_parallel import DistributedModelParallel


def main():
    rank = int(os.environ["LOCAL_RANK"])
    if torch.cuda.is_available():
        device: torch.device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)

    # A single 5,000,000 x 16 table, declared on the meta device so that
    # no memory is allocated before sharding.
    e1_config = EmbeddingConfig(
        name="t1", embedding_dim=16, num_embeddings=5000000, feature_names=["f1"]
    )
    ec = EmbeddingCollection(
        tables=[e1_config],
        device="meta",
    )

    _pg = dist.GroupMember.WORLD
    _sharder = [EmbeddingCollectionSharder()]
    planner = EmbeddingShardingPlanner()
    plan = planner.collective_plan(
        module=ec,
        sharders=_sharder,
        pg=_pg,
    )

    # Sharding the meta-initialized EmbeddingCollection raises
    # "Cannot copy out of meta tensor; no data!" here.
    dmp = DistributedModelParallel(
        module=ec,
        device=device,
        plan=plan,
        sharders=_sharder,
    )


if __name__ == "__main__":
    dist.init_process_group("nccl")
    main()
Versions: torch 2.6.0, fbgemm 1.1.0, torchrec 1.1.0
This bug is not specific to torchrec; I found similar reports in other repositories. A temporary workaround is to construct the module on "cuda" instead of "meta", but then you cannot train large embedding tables, because every rank has to materialize the full table before sharding.
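For clarity, the workaround is just the following change in test_ebd.py above (reusing e1_config and device from the script); it works, but it allocates the full 5,000,000 x 16 table on every rank, which is exactly what meta initialization is supposed to avoid:

    # Workaround: construct on the GPU instead of "meta"; works, but does not scale.
    ec = EmbeddingCollection(
        tables=[e1_config],
        device=device,
    )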
It does not happen when sharding an EmbeddingBagCollection. I compared ShardedEmbeddingBagCollection and ShardedEmbeddingCollection: when the device of the configs is set to 'meta', EC still loads the state dict but EBC does not. I will dig into it later.
gouchangjiang changed the title from "Cannot copy out of meta tensor; no data!" to "[BUG]: Cannot copy out of meta tensor; no data!" on May 29, 2025.
Replacing the condition that guards state loading in ShardedEmbeddingCollection with the one used in ShardedEmbeddingBagCollection solves the problem. I don't know whether it has other (bad) side effects.
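To make the intent concrete, here is a hedged sketch of the kind of guard I mean. The names are illustrative only and do not correspond to torchrec's actual internals; plain nn.Embedding stands in for the unsharded and sharded modules:

import torch
import torch.nn as nn


def maybe_load_state_dict(sharded: nn.Module, source: nn.Module) -> None:
    # Illustrative guard: skip the copy when the source parameters live on the
    # meta device, since there is no data to copy out of them.
    source_on_meta = any(p.device.type == "meta" for p in source.parameters())
    if not source_on_meta:
        sharded.load_state_dict(source.state_dict())


src = nn.Embedding(10, 4, device="meta")  # stand-in for the meta-initialized EC
dst = nn.Embedding(10, 4)                 # stand-in for the sharded module
maybe_load_state_dict(dst, src)           # no-op instead of "Cannot copy out of meta tensor"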