
FSL-SAGE: Accelerating Federated Split Learning via Smashed Activation Gradient Estimation


Introduction

Our Federated Split Learning (FSL) algorithm cuts down on the communication overhead of traditional Split Learning methods by directly estimating server-returned gradients at each client using auxiliary models. The auxiliary models are much smaller versions of the server model, explicitly trained to estimate the gradients that the server model would return for the client's local input.
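To make the mechanics concrete, here is a minimal PyTorch sketch of the idea. This is not the repository's code: the module shapes, names, and the gradient-matching objective used to align the auxiliary model are illustrative assumptions.

```python
import torch
import torch.nn as nn

# Toy stand-ins; the actual client/auxiliary architectures come from the configs.
client_model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Flatten())
auxiliary = nn.Linear(16 * 32 * 32, 10)   # much smaller than the server model
criterion = nn.CrossEntropyLoss()

def local_step(x, y, opt):
    """One client update using the auxiliary model instead of a server round trip."""
    z_out = client_model(x)                    # smashed activations (graph retained)
    z = z_out.detach().requires_grad_(True)    # cut layer
    loss = criterion(auxiliary(z), y)          # auxiliary model emulates the server
    (g_est,) = torch.autograd.grad(loss, z)    # estimated server-returned gradient
    opt.zero_grad()
    z_out.backward(g_est)                      # inject the estimate at the cut layer
    opt.step()

def align_step(z_true, y, g_true, aux_opt):
    """Occasionally fit the auxiliary model so that its gradients match the
    gradients the server actually returned for the same activations."""
    z = z_true.detach().requires_grad_(True)
    (g_est,) = torch.autograd.grad(criterion(auxiliary(z), y), z, create_graph=True)
    match = ((g_est - g_true) ** 2).mean()     # gradient-matching objective (assumed)
    aux_opt.zero_grad()
    match.backward()
    aux_opt.step()
```

Because the server is consulted only occasionally to align the auxiliary models, the per-iteration activation/gradient exchange of standard split learning is largely avoided.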

The algorithm is summarized in the following schematic:

FSL-SAGE schematic

Please refer to our paper for further details, and consider citing it if you find this work useful.

Requirements

The project requirements can be installed using the environment config file conda_env.yaml as follows:

conda env create -f conda_env.yaml

which will create a conda environment named sage. You can activate it using:

conda activate sage

after which all dependency requirements should be met.

Configuration

This project is powered by Hydra, which allows hierarchical configurations and makes it easy to run multiple ML experiments. The Hydra config files are located in the hydra_config folder.

There is a high degree of customizability here: datasets, models, and FL algorithms can be plugged in via configs. Please check out our contributing readme for more details.
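As a rough illustration of how Hydra wires configs into the code (names and layout here are assumptions, not the repository's actual modules):

```python
# Minimal Hydra entry-point sketch, assuming a top-level config.yaml in hydra_config/.
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(version_base=None, config_path="hydra_config", config_name="config")
def main(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))  # resolved config, e.g. cfg.model, cfg.algorithm
    # dataset, model, and algorithm objects would be constructed from cfg here

if __name__ == "__main__":
    main()
```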

Datasets

Currently, datasets are read from the datas folder in the root of the repository. Simply create a folder there for your dataset and download the data into it. After performing any necessary preprocessing, use or extend the get_dataset() function in datasets/__init__.py.
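For example, adding a dataset might look like the following sketch; the actual get_dataset() signature and registry in this repository may differ.

```python
# Hypothetical shape of get_dataset() in datasets/__init__.py; illustrative only.
from torchvision import datasets, transforms

def get_dataset(name: str, root: str = "datas", train: bool = True):
    tfm = transforms.ToTensor()
    if name == "cifar10":
        return datasets.CIFAR10(root, train=train, download=True, transform=tfm)
    if name == "cifar100":
        return datasets.CIFAR100(root, train=train, download=True, transform=tfm)
    raise ValueError(f"Unknown dataset: {name}")
```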

Training

To train FSL-SAGE with the defaults from config.yaml, simply run

python main.py

Training results are saved in a saves folder in the repository root, in subfolders segregated by FL algorithm, model, dataset, and distribution.

The default number of clients (num_clients) is 10 and the default number of rounds is rounds=200. Each method trains up to the fixed number of rounds or until the communication budget specified in comm_threshold_mb (in MB) is reached.
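In pseudocode terms, the stopping rule combines both budgets roughly as follows; the variable and helper names are assumptions, not repository code.

```python
def run_federated_round() -> float:
    """Stub for one FL round; returns the MB communicated (hypothetical helper)."""
    return 40.0

rounds, comm_threshold_mb = 200, 4000.0   # illustrative values
total_mb = 0.0
for rnd in range(rounds):
    total_mb += run_federated_round()
    if total_mb >= comm_threshold_mb:     # stop once the communication budget is spent
        break
```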

To choose a specific model or algorithm, use Hydra's command-line override functionality as follows:

python main.py model=resnet18 algorithm=cse_fsl dataset=cifar100 dataset.distribution=iid

The following options are currently supported:

Algorithm

Syntax: algorithm=<key>. The FL algorithm to use for training. Algorithms currently supported:

| Key | Algorithm |
|---|---|
| `fed_avg` | FedAvg |
| `sl_multi_server` | SplitFedv1 |
| `sl_single_server` | SplitFedv2 |
| `cse_fsl` | CSE-FSL |
| `fsl_sage` | FSL-SAGE |

Dataset

Syntax: dataset=<key>. The dataset used in training. Datasets currently supported:

| Key | Dataset |
|---|---|
| `cifar10` | CIFAR-10 |
| `cifar100` | CIFAR-100 |

Model

Syntax: model=<key>. The ML model to use for training. Models currently supported:

| Key | Model |
|---|---|
| `resnet18` | ResNet-18 |
| `resnet50` | ResNet-50 |
| `resnet56` | ResNet-56 |
| `resnet110` | ResNet-110 |

Note that the ResNet models above, apart from resnet18, have not been tuned yet, so their results may not optimally represent FSL-SAGE's communication benefits.

Data distribution

Syntax: dataset.distribution=<key>. Determines the distribution of the dataset across clients. Distributions currently supported:

| Key | Distribution |
|---|---|
| `iid` | homogeneous |
| `noniid_dirichlet` | heterogeneous |

For noniid_dirichlet you can specify the Dirichlet concentration parameter alpha using the key dataset.alpha, e.g., dataset.alpha=1.
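The usual recipe behind such an option is Dirichlet label partitioning, sketched below; smaller alpha yields more heterogeneous clients. This is a generic illustration and not necessarily the exact scheme used in this repository.

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients with per-class Dirichlet(alpha) shares."""
    rng = np.random.default_rng(seed)
    client_idx = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = rng.permutation(np.where(labels == c)[0])
        shares = rng.dirichlet(alpha * np.ones(num_clients))  # class share per client
        cuts = (np.cumsum(shares)[:-1] * len(idx)).astype(int)
        for i, part in enumerate(np.split(idx, cuts)):
            client_idx[i].extend(part.tolist())
    return client_idx
```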


We also support parallel multiruns using the hydra-joblib-launcher. Thus, it is possible to run multiple experiments over different combinations of hyperparameters, models, datasets, or algorithms, given sufficient GPU memory.

python main.py -m model=resnet18,simple_conv algorithm=fed_avg,sl_single_server,sl_multi_server,cse_fsl,fsl_sage

The above creates parallel jobs that run main.py on all combinations of the specified options. The number of jobs can be controlled by modifying the hydra.launcher.n_jobs option in config.yaml or by passing hydra.launcher.n_jobs=<jobs> on the command line.

Inference Plots

Please check out the readme, the functions in plot_results.py, and the configs in exp_configs.yaml to see how to generate the plots for accuracy, communication load, etc.

Note on source code for LLMs

This code has been extended to support LLMs, and a basic natural language generation (NLG) setup using the WebNLG E2E dataset has been tested. We use LoRA fine-tuning on the GPT-2 medium model to learn the E2E task. The code is available on the llm branch of this repository and is built upon the LoRA codebase. It currently does not use the HuggingFace PEFT or Transformers libraries, since building model splitting on top of them is challenging; however, contributions to this end on the llm branch are welcome. We would like to encourage research on automatic model splitters for PyTorch and other popular deep learning frameworks, since these could become increasingly relevant as split learning and federated split learning methods gain popularity.
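To see why splitting is trivial for plain feed-forward stacks but hard for real LLM codebases, consider the naive split below (a sketch, not the llm branch's code). It works for an nn.Sequential, but residual streams, attention caches, and tied embeddings in transformer implementations break such slicing, which is exactly what an automatic splitter would have to handle.

```python
import torch.nn as nn

def split_model(layers: nn.Sequential, cut: int):
    """Naively cut a sequential stack into a client prefix and a server suffix."""
    client = layers[:cut]    # slicing an nn.Sequential returns an nn.Sequential
    server = layers[cut:]
    return client, server
```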
