example/implementation for FedBalancer, with a new sampler category by jaemin-shin · Pull Request #380 · cisco-open/flame · GitHub

example/implementation for FedBalancer, with a new sampler category #380


Merged · 1 commit · Apr 6, 2023
61 changes: 36 additions & 25 deletions lib/python/flame/common/util.py
@@ -29,8 +29,8 @@
from .typing import ModelWeights
from .constants import DeviceType

PYTORCH = 'torch'
TENSORFLOW = 'tensorflow'
PYTORCH = "torch"
TENSORFLOW = "tensorflow"


class MLFramework(Enum):
@@ -43,7 +43,8 @@ class MLFramework(Enum):

ml_framework_in_use = MLFramework.UNKNOWN
valid_frameworks = [
framework.name.lower() for framework in MLFramework
framework.name.lower()
for framework in MLFramework
if framework != MLFramework.UNKNOWN
]

@@ -73,13 +74,16 @@ def get_ml_framework_in_use():

return ml_framework_in_use


def get_params_detached_pytorch(model):
"""Return copy of parameters of pytorch model disconnected from graph."""
return [param.detach().clone() for param in model.parameters()]


def get_params_as_vector_pytorch(params):
"""Return the list of parameters passed in concatenated into one vector."""
import torch

vector = None
for param in params:
if not isinstance(vector, torch.Tensor):
@@ -88,37 +92,39 @@ def get_params_as_vector_pytorch(params):
vector = torch.cat((vector, param.reshape(-1)), 0)
return vector


def get_dataset_filename(link):
"""Return path for file location"""
# currently only supports https and local file
if link.startswith('https://'):
if link.startswith("https://"):
import requests

r = requests.get(link, allow_redirects=True)

try:
filename = link.split('/')[-1]
open(filename, 'wb').write(r.content)
filename = link.split("/")[-1]
open(filename, "wb").write(r.content)
except:
filename = 'data'
open(filename, 'wb').write(r.content)
filename = "data"
open(filename, "wb").write(r.content)

return filename
elif link.startswith('file://'):

elif link.startswith("file://"):
return link[7:]

raise TypeError('link format not supported; use either https:// or file://')
raise TypeError("link format not supported; use either https:// or file://")


@contextmanager
def background_thread_loop():

def run_forever(loop):
asyncio.set_event_loop(loop)
loop.run_forever()

_loop = asyncio.new_event_loop()

_thread = Thread(target=run_forever, args=(_loop, ), daemon=True)
_thread = Thread(target=run_forever, args=(_loop,), daemon=True)
_thread.start()
yield _loop

@@ -134,11 +140,11 @@ def run_async(coro, loop, timeout=None):
def install_packages(packages: List[str]) -> None:
for package in packages:
if not install_package(package):
print(f'Failed to install package: {package}')
print(f"Failed to install package: {package}")


def install_package(package: str) -> bool:
if pipmain(['install', package]) == 0:
if pipmain(["install", package]) == 0:
Collaborator comment: This apparently has the potential to change the LOG_LEVEL of the terminal. Not necessarily relevant to this PR though.

return True

return False
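
Regarding the LOG_LEVEL concern in the comment above, a minimal sketch of one possible guard (the wrapper is hypothetical, not part of this PR, and assumes the same pipmain entry point util.py already imports):

import logging


def install_package_keeping_log_level(package: str) -> bool:
    """Hypothetical variant of install_package that restores the root logger level."""
    root = logging.getLogger()
    saved_level = root.level
    try:
        # pipmain is assumed to be the pip entry point already used by install_package.
        return pipmain(["install", package]) == 0
    finally:
        # pip can reconfigure logging as a side effect; undo that here.
        root.setLevel(saved_level)
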
@@ -151,20 +157,22 @@ def mlflow_runname(config: Config) -> str:
if val in config.realm:
groupby_value = groupby_value + val + "-"

return config.role + '-' + groupby_value + config.task_id[:8]
return config.role + "-" + groupby_value + config.task_id[:8]


def delta_weights_pytorch(a: ModelWeights,
b: ModelWeights) -> Union[ModelWeights, None]:
def delta_weights_pytorch(
a: ModelWeights, b: ModelWeights
) -> Union[ModelWeights, None]:
10000 """Return delta weights for pytorch model weights."""
if a is None or b is None:
return None

return {x: a[x] - b[y] for (x, y) in zip(a, b)}


def delta_weights_tensorflow(a: ModelWeights,
b: ModelWeights) -> Union[ModelWeights, None]:
def delta_weights_tensorflow(
a: ModelWeights, b: ModelWeights
) -> Union[ModelWeights, None]:
"""Return delta weights for tensorflow model weights."""
if a is None or b is None:
return None
@@ -174,27 +182,30 @@ def delta_weights_tensorflow(a: ModelWeights,

def get_pytorch_device(dtype: DeviceType):
import torch

if dtype == DeviceType.CPU:
device_name = "cpu"
elif dtype == DeviceType.GPU:
device_name = "cuda"
else:
raise TypeError(f"Device type {dtype} is not supported.")

return torch.device(device_name)


def weights_to_device(weights, dtype: DeviceType):
"""Send model weights to device type dtype."""

framework = get_ml_framework_in_use()
if framework == MLFramework.TENSORFLOW:
return weights
elif framework == MLFramework.PYTORCH:
torch_device = get_pytorch_device(dtype)
return {name: weights[name].to(torch_device) for name in weights}

return None


def weights_to_model_device(weights, model):
"""Send model weights to same device as model"""
framework = get_ml_framework_in_use()
@@ -204,5 +215,5 @@ def weights_to_model_device(weights, model):
# make assumption all tensors are on same device
torch_device = next(model.parameters()).device
return {name: weights[name].to(torch_device) for name in weights}

return None
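
To make the reformatted delta helpers concrete, a minimal worked sketch (the tensors are illustrative):

import torch

from flame.common.util import delta_weights_pytorch

a = {"w": torch.tensor([1.0, 2.0]), "b": torch.tensor([0.5])}
b = {"w": torch.tensor([0.5, 1.0]), "b": torch.tensor([0.25])}

# Keys are zipped positionally, so both dicts must share the same key order.
delta = delta_weights_pytorch(a, b)
# delta == {"w": tensor([0.5000, 1.0000]), "b": tensor([0.2500])}
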
17 changes: 17 additions & 0 deletions lib/python/flame/config.py
@@ -75,6 +75,13 @@ class SelectorType(str, Enum):
OORT = "oort"


class DataSamplerType(str, Enum):
"""Define datasampler types."""

DEFAULT = "default"
FEDBALANCER = "fedbalancer"


class Job(FlameSchema):
job_id: str = Field(alias="id")
name: str
@@ -90,6 +97,11 @@ class Selector(FlameSchema):
kwargs: dict = Field(default={})


class DataSampler(FlameSchema):
sort: DataSamplerType = Field(default=DataSamplerType.DEFAULT)
kwargs: dict = Field(default={})


class Optimizer(FlameSchema):
sort: OptimizerType = Field(default=OptimizerType.DEFAULT)
kwargs: dict = Field(default={})
@@ -172,6 +184,7 @@ def __init__(self, config_path: str):
job: Job
registry: t.Optional[Registry]
selector: t.Optional[Selector]
datasampler: t.Optional[DataSampler] = Field(default=DataSampler())
optimizer: t.Optional[Optimizer] = Field(default=Optimizer())
dataset: str
max_run_time: int
@@ -224,6 +237,10 @@ def transform_config(raw_config: dict) -> dict:
if raw_config.get("optimizer", None):
config_data = config_data | {"optimizer": raw_config.get("optimizer")}

if raw_config.get("datasampler", None):
raw_config["datasampler"]["kwargs"].update(hyperparameters)
config_data = config_data | {"datasampler": raw_config.get("datasampler")}

config_data = config_data | {
"dataset": raw_config.get("dataset", ""),
"max_run_time": raw_config.get("maxRunTime", 300),
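
A minimal sketch of the config fragment that would activate the new sampler category (the surrounding fields and empty kwargs are illustrative; transform_config merges the job's hyperparameters into kwargs as shown above):

# Fragment of a raw job config dict as consumed by transform_config;
# all other required fields are omitted for brevity.
raw_config_fragment = {
    "datasampler": {
        "sort": "fedbalancer",  # or "default"
        "kwargs": {},  # the job's hyperparameters get merged in here
    },
}
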
78 changes: 78 additions & 0 deletions lib/python/flame/datasampler/__init__.py
@@ -0,0 +1,78 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""datasampler abstract class."""

from abc import ABC, abstractmethod
from typing import Any

from flame.channel import Channel


class AbstractDataSampler(ABC):
class AbstractTrainerDataSampler(ABC):
"""Abstract base class for trainer-side datasampler implementation."""

def __init__(self, **kwargs) -> None:
"""Initialize an instance with keyword-based arguments."""
for key, value in kwargs.items():
setattr(self, key, value)

@abstractmethod
def sample(self, dataset: Any, **kwargs) -> Any:
"""Abstract method to sample data.

Parameters
----------
dataset: Dataset of a trainer to select samples from
kwargs: other arguments specific to each datasampler algorithm

Returns
-------
dataset: Dataset that only contains selected samples
"""

@abstractmethod
def load_dataset(self, dataset: Any) -> Any:
"""Process dataset instance for datasampler."""

@abstractmethod
def get_metadata(self) -> dict[str, Any]:
"""Return metadata to send to aggregator-side datasampler."""

@abstractmethod
def handle_metadata_from_aggregator(self, metadata: dict[str, Any]) -> None:
"""Handle aggregator metadata for datasampler."""

class AbstractAggregatorDataSampler(ABC):
"""Abstract base class for aggregator-side datasampler implementation."""

def __init__(self, **kwargs) -> None:
"""Initialize an instance with keyword-based arguments."""
for key, value in kwargs.items():
setattr(self, key, value)

@abstractmethod
def get_metadata(self, end: str, round: int) -> dict[str, Any]:
"""Return metadata to send to trainer-side datasampler."""

@abstractmethod
def handle_metadata_from_trainer(
self,
metadata: dict[str, Any],
end: str,
channel: Channel,
) -> None:
"""Handle trainer metadata for datasampler."""
77 changes: 77 additions & 0 deletions lib/python/flame/datasampler/default.py
@@ -0,0 +1,77 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""DefaultDataSampler class."""

import logging
from typing import Any

from flame.channel import Channel
from flame.datasampler import AbstractDataSampler

logger = logging.getLogger(__name__)


class DefaultDataSampler(AbstractDataSampler):
def __init__(self) -> None:
self.trainer_data_sampler = DefaultDataSampler.DefaultTrainerDataSampler()
self.aggregator_data_sampler = DefaultDataSampler.DefaultAggregatorDataSampler()

class DefaultTrainerDataSampler(AbstractDataSampler.AbstractTrainerDataSampler):
"""A default trainer-side datasampler class."""

def __init__(self, **kwargs):
"""Initailize instance."""
super().__init__()

def sample(self, dataset: Any, **kwargs) -> Any:
"""Return all dataset from the given dataset."""
logger.debug("calling default datasampler")

return dataset

def load_dataset(self, dataset: Any) -> Any:
"""Return the dataset unchanged; the default sampler needs no processing."""
return dataset

def get_metadata(self) -> dict[str, Any]:
"""Return metadata to send to aggregator-side datasampler."""
return {}

def handle_metadata_from_aggregator(self, metadata: dict[str, Any]) -> None:
"""Handle aggregator metadata for datasampler."""
pass

class DefaultAggregatorDataSampler(
AbstractDataSampler.AbstractAggregatorDataSampler
):
"""A default aggregator-side datasampler class."""

def __init__(self, **kwargs):
"""Initailize instance."""
super().__init__()

def get_metadata(self, end: str, round: int) -> dict[str, Any]:
"""Return metadata to send to trainer-side datasampler."""
return {}

def handle_metadata_from_trainer(
self,
metadata: dict[str, Any],
end: str,
channel: Channel,
) -> None:
"""Handle trainer metadata for datasampler."""
pass
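
And a minimal usage sketch of the default pair (the list stands in for a real dataset):

from flame.datasampler.default import DefaultDataSampler

sampler = DefaultDataSampler()
dataset = [0, 1, 2, 3]

# The default trainer-side sampler is a pass-through with no metadata exchange.
assert sampler.trainer_data_sampler.sample(dataset) == dataset
assert sampler.trainer_data_sampler.get_metadata() == {}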