Dynamic Sharding API + Test for EBC, TW, ShardedTensor by aporialiao · Pull Request #2852 · pytorch/torchrec · GitHub

Dynamic Sharding API + Test for EBC, TW, ShardedTensor #2852


Closed
wants to merge 1 commit into from

Conversation

aporialiao
Member

Summary:
Add an initial dynamic sharding API and test. The current version supports EBC, TW, and ShardedTensor. Other configurations (e.g. CW, RW, DTensor) will be added in the next few diffs.

Motivation for Dynamic Sharding: Doc [Work in Progress]
Design: [WIP]

What's added here:

  1. A reshard API, which implements the update_shards API for ShardedEmbeddingBagCollection.

  2. Utility functions for dynamic sharding, used by the update_shards API:

    1. extend_shard_name: extends table_i to embedding_bags.table_i.weight.
    2. shards_all_to_all: wraps the all-to-all collective call that redistributes shards in a distributed environment, based on the changed_sharding_params.
    3. update_state_dict_post_resharding: updates a given state_dict with the new shard placements and local_shards.
  3. A multi-process unit test, test_dynamic_sharding_ebc_tw, which tests TW-sharded EBCs calling the reshard API, sampling from various world_sizes, num_tables, and data_types.

    1. This unit test also uses a few utilities to generate random inputs and rank placements. A future TODO is to merge this input generation to use the generate call here (D71703434).
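
To make the flow above concrete, here is a minimal single-process sketch in plain Python. It is not the code from this diff: only the name extend_shard_name and the key format embedding_bags.table_i.weight come from the summary. The helpers changed_placements and simulate_all_to_all, the Placement type, and the use of nested lists in place of tensors are hypothetical stand-ins, and the real shards_all_to_all runs a distributed collective rather than a local loop. The sketch assumes TW sharding, where each table lives entirely on one rank.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for a TW sharding plan: table name -> rank that owns the whole table.
Placement = Dict[str, int]
# Hypothetical stand-in for a per-rank state_dict: key -> weights (nested lists instead of tensors).
StateDict = Dict[str, List[List[float]]]


def extend_shard_name(shard_name: str) -> str:
    """Expand a bare table name into the state_dict key format given in the PR summary."""
    return f"embedding_bags.{shard_name}.weight"


def changed_placements(old: Placement, new: Placement) -> Dict[str, Tuple[int, int]]:
    """Return {table: (src_rank, dst_rank)} for every table whose owning rank changes."""
    return {t: (old[t], new[t]) for t in new if old[t] != new[t]}


def simulate_all_to_all(
    per_rank_state: List[StateDict], moves: Dict[str, Tuple[int, int]]
) -> List[StateDict]:
    """Single-process stand-in for the collective: move each changed table from src to dst.

    The input is shallow-copied per rank, so the original per-rank dicts are left untouched.
    """
    result = [dict(sd) for sd in per_rank_state]
    for table, (src, dst) in moves.items():
        key = extend_shard_name(table)
        result[dst][key] = result[src].pop(key)
    return result


# Two ranks, two tables; the new plan moves table_0 from rank 0 to rank 1.
old_plan = {"table_0": 0, "table_1": 1}
new_plan = {"table_0": 1, "table_1": 1}
state = [
    {extend_shard_name("table_0"): [[0.1, 0.2]]},
    {extend_shard_name("table_1"): [[0.3, 0.4]]},
]

moves = changed_placements(old_plan, new_plan)
resharded = simulate_all_to_all(state, moves)
```

Only tables whose placement changes are moved. This mirrors the role the summary gives changed_sharding_params, although the actual parameter structure in the diff may differ.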

Future work items (features not yet supported in this diff):

  • CW, RW, and many other sharding types
  • Optimizer saving
  • DTensor implementation

Differential Revision: D69095169

@facebook-github-bot added the CLA Signed label Mar 27, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D69095169

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Mar 27, 2025
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported