Re-enabling new keras optimizers by nvcastet · Pull Request #3860 · horovod/horovod · GitHub

Re-enabling new keras optimizers #3860

Merged · 18 commits · Apr 17, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

- Improved reducescatter performance by allocating output tensors before enqueuing the operation. ([#3824](https://github.com/horovod/horovod/pull/3824))
- Force tf.logical_and in hvd allreduce condition running on CPU. ([#3885](https://github.com/horovod/horovod/pull/3885))
- Support TF Keras 2.11+ optimizers. ([#3860](https://github.com/horovod/horovod/pull/3860))

### Deprecated

7 changes: 1 addition & 6 deletions examples/tensorflow2/tensorflow2_keras_mnist.py
@@ -20,11 +20,6 @@
import horovod
import horovod.tensorflow.keras as hvd

from packaging import version
if version.parse(tf.keras.__version__.replace("-tf", "+tf")) < version.parse("2.11"):
from tensorflow.keras import optimizers
else:
from tensorflow.keras.optimizers import legacy as optimizers

def main():
# Horovod: initialize Horovod.
@@ -59,7 +54,7 @@ def main():

# Horovod: adjust learning rate based on number of GPUs.
scaled_lr = 0.001 * hvd.size()
opt = optimizers.Adam(scaled_lr)
opt = tf.optimizers.Adam(scaled_lr)
Collaborator:
Not related to this PR, but a question about tf.keras and keras compatibility in horovod:

  • After tf 2.6, tf.keras == keras.
  • Before tf 2.6, this might be an issue?
  • Or our CI already dropped support for TF2 versions < 2.6?

Collaborator Author:

I don't know the exact test matrix of our CI.
Technically it should still work for TF < 2.6 when using the keras package; I kept the code path for pure Keras (older versions of TF and Keras).
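
The change in this example condenses to the sketch below (a minimal, hedged illustration rather than a new file in the PR; the init/size calls are the usual Horovod boilerplate): with Keras 2.11+ the new-style optimizer goes straight into hvd.DistributedOptimizer, so the version-gated legacy import is no longer needed.

```python
import tensorflow as tf
import horovod.tensorflow.keras as hvd

hvd.init()

# Horovod: scale the learning rate by the number of workers.
scaled_lr = 0.001 * hvd.size()

# tf.optimizers.Adam resolves to the new (Keras >= 2.11) optimizer class;
# no tensorflow.keras.optimizers.legacy import is required anymore.
opt = tf.optimizers.Adam(scaled_lr)

# Horovod wraps whichever optimizer base class the instance uses.
opt = hvd.DistributedOptimizer(opt)
```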


# Horovod: add Horovod DistributedOptimizer.
opt = hvd.DistributedOptimizer(
@@ -21,11 +21,6 @@
import horovod.tensorflow.keras as hvd
from horovod.tensorflow.data.compute_service import TfDataServiceConfig

from packaging import version
if version.parse(tf.keras.__version__.replace("-tf", "+tf")) < version.parse("2.11"):
from tensorflow.keras import optimizers
else:
from tensorflow.keras.optimizers import legacy as optimizers

# arguments reuse_dataset and round_robin only used when single dispatcher is present
def train_fn(compute_config: TfDataServiceConfig, reuse_dataset: bool = False, round_robin: bool = False):
@@ -69,7 +64,7 @@ def train_fn(compute_config: TfDataServiceConfig, reuse_dataset: bool = False, round_robin: bool = False):

# Horovod: adjust learning rate based on number of GPUs.
scaled_lr = 0.001 * hvd.size()
opt = optimizers.Adam(scaled_lr)
opt = tf.optimizers.Adam(scaled_lr)

# Horovod: add Horovod DistributedOptimizer.
opt = hvd.DistributedOptimizer(
64 changes: 35 additions & 29 deletions horovod/_keras/__init__.py
@@ -21,34 +21,19 @@
from horovod.tensorflow.gradient_aggregation import LocalGradientAggregationHelper
from horovod.tensorflow.gradient_aggregation_eager import LocalGradientAggregationHelperEager
from horovod.tensorflow.mpi_ops import rank, size_op
from horovod.common.util import support_non_legacy_keras_optimizers


_PRE_TF_2_4_0 = version.parse(tf.__version__) < version.parse('2.4.0')
_IS_TF2 = version.parse(tf.__version__) >= version.parse('2.0.0')


def get_keras_optimizer_base_type(k):
if support_non_legacy_keras_optimizers(k):
return k.optimizers.Optimizer
else:
return tf.keras.optimizers.legacy.Optimizer


def check_keras_optimizer_type(k, optimizer):
if not support_non_legacy_keras_optimizers(k):
if not isinstance(optimizer, tf.keras.optimizers.legacy.Optimizer):
raise ValueError(f"Optimizer has to be an instance of tensorflow.keras.optimizers.legacy.Optimizer starting from Keras 2.11: {type(optimizer).__name__}")


def create_distributed_optimizer(keras, optimizer, name, device_dense, device_sparse,
compression, sparse_as_dense, gradient_predivide_factor,
op, backward_passes_per_step=1,
average_aggregated_gradients=False,
groups=None, process_set=hvd.global_process_set,
scale_local_gradients=True):
check_keras_optimizer_type(keras, optimizer)

class _DistributedOptimizer(get_keras_optimizer_base_type(keras)):
class _DistributedOptimizer(*optimizer.__class__.__bases__):
_HAS_AGGREGATE_GRAD = True

def __init__(self, **kwargs):
@@ -94,6 +79,11 @@ def __init__(self, **kwargs):
scale_local_gradients=scale_local_gradients
)

def variables(self):
if _IS_TF2:
return super(self.__class__, self).variables()
return self.get_weights()

def register_local_var(self, var):
"""Registers a source/variable as worker local. Horovod will not perform any global
operations on gradients corresponding to these sources and will instead return the local
@@ -105,6 +95,9 @@ def register_local_var(self, var):
else:
self._local_vars.add(var)

def compute_gradients(self, loss, var_list, tape=None):
return self._compute_gradients(loss, var_list, None, tape)

def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
"""
Compute gradients of all trainable variables.
@@ -114,17 +107,25 @@ def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
In DistributedOptimizer, get_gradients() is overridden to also
allreduce the gradients before returning them.
"""
base_class = super(self.__class__, self)
if _PRE_TF_2_4_0:
return super(self.__class__, self)._compute_gradients(
return base_class._compute_gradients(
loss, var_list, grad_loss, tape)

tape = tf.GradientTape() if tape is None else tape
grads_and_vars = super(self.__class__, self)._compute_gradients(
# pylint: disable=protected-access
loss,
var_list,
grad_loss,
tape=tape)
if hasattr(base_class, '_compute_gradients'):
grads_and_vars = base_class._compute_gradients(
# pylint: disable=protected-access
loss,
var_list,
grad_loss,
tape=tape)
else:
grads_and_vars = base_class.compute_gradients(
# pylint: disable=protected-access
loss,
var_list,
tape=tape)
grads, weights = list(zip(*grads_and_vars))

allreduced_grads = self._allreduce(grads, weights)
@@ -143,13 +144,15 @@ def get_gradients(self, loss, params):
return self._allreduce(gradients, params)

def _aggregate_gradients(self, grads_and_vars):
base_class = super(self.__class__, self)
if _PRE_TF_2_4_0:
grads, vars = list(zip(*grads_and_vars))
aggregated_grads = self._allreduce(grads, vars)
return aggregated_grads
elif hasattr(base_class, '_aggregate_gradients'):
return base_class._aggregate_gradients(grads_and_vars)
else:
return super(self.__class__, self)._aggregate_gradients(
grads_and_vars)
return base_class.aggregate_gradients(grads_and_vars)

def _allreduce(self, grads, vars):
self._aggregated_gradients = True
@@ -278,12 +281,15 @@ def reducescatter(backend, value, name, op):
return _eval(backend, hvd.reducescatter(tf.constant(value, name=name), op=op))


def load_model(keras, wrap_optimizer, optimizer_modules, filepath, custom_optimizers, custom_objects):
keras_subclasses = get_keras_optimizer_base_type(keras).__subclasses__()
def load_model(keras, wrap_optimizer, filepath, custom_optimizers, custom_objects, legacy_opts=False):
if legacy_opts:
keras_subclasses = keras.optimizers.legacy.Optimizer.__subclasses__()
else:
keras_subclasses = keras.optimizers.Optimizer.__subclasses__()

horovod_objects = {
subclass.__name__.lower(): wrap_optimizer(subclass)
for subclass in keras_subclasses
if subclass.__module__ in optimizer_modules
}

if custom_optimizers is not None:
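The heart of the re-enablement is that create_distributed_optimizer no longer hard-codes a single Optimizer base: the wrapper class is derived from the wrapped optimizer's own bases, and gradient handling dispatches to _compute_gradients/_aggregate_gradients when the base provides them and to the new compute_gradients/aggregate_gradients otherwise. The following is a minimal, hypothetical sketch of that subclass-and-from_config pattern, not Horovod's exact code; the apply_gradients hook merely stands in for the allreduce step.

```python
import tensorflow as tf


def wrap_optimizer(optimizer):
    def apply_gradients(self, grads_and_vars, **kwargs):
        # Placeholder hook: Horovod would allreduce the gradients here.
        grads_and_vars = list(grads_and_vars)
        print(f"aggregating {len(grads_and_vars)} gradients before applying")
        return super(self.__class__, self).apply_gradients(grads_and_vars, **kwargs)

    # Subclass the user's own optimizer class, so the wrapper follows whichever
    # Optimizer base (legacy or Keras 2.11+) that optimizer was built on.
    cls = type(optimizer.__class__.__name__, (optimizer.__class__,),
               {"apply_gradients": apply_gradients})
    return cls.from_config(optimizer.get_config())


opt = wrap_optimizer(tf.keras.optimizers.Adam(0.001))
print([c.__name__ for c in type(opt).__mro__[:3]])
```

Because the wrapper is rebuilt from the optimizer's config, the same code path serves both optimizer generations without an isinstance gate, which is why the old check_keras_optimizer_type guard could be dropped.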
5 changes: 0 additions & 5 deletions horovod/common/util.py
@@ -23,7 +23,6 @@
import warnings

from contextlib import contextmanager
from packaging import version

from horovod.common.exceptions import get_version_mismatch_message, HorovodVersionMismatchError

@@ -287,7 +286,3 @@ def is_version_greater_equal_than(ver, target):
"of: major.minor.patch. Received: {}".format(target))

return version.parse(ver) >= version.parse(target)


def support_non_legacy_keras_optimizers(k):
return version.parse(k.__version__.replace("-tf", "+tf")) < version.parse("2.11")
6 changes: 3 additions & 3 deletions horovod/keras/__init__.py
@@ -249,7 +249,7 @@ def reducescatter(value, name=None, op=Average):
return _impl.reducescatter(K, value, name, op)


def load_model(filepath, custom_optimizers=None, custom_objects=None, compression=Compression.none):
def load_model(filepath, custom_optimizers=None, custom_objects=None, compression=Compression.none, legacy_opts=False):
"""
Loads a saved Keras model with a Horovod DistributedOptimizer.

@@ -272,6 +272,7 @@ def load_model(filepath, custom_optimizers=None, custom_objects=None, compression=Compression.none):
compression: Compression algorithm used to reduce the amount of data
sent and received by each worker node. Defaults to not
using compression.
legacy_opts: If True, model uses tf.keras.optimizers.legacy.* optimizers

Returns:
A Keras model instance.
@@ -282,5 +283,4 @@ def load_model(filepath, custom_optimizers=None, custom_objects=None, compression=Compression.none):
"""
def wrap_optimizer(cls):
return lambda **kwargs: DistributedOptimizer(cls(**kwargs), compression=compression)
optimizer_modules = {_impl.get_keras_optimizer_base_type(keras).__module__}
return _impl.load_model(keras, wrap_optimizer, optimizer_modules, filepath, custom_optimizers, custom_objects)
return _impl.load_model(keras, wrap_optimizer, filepath, custom_optimizers, custom_objects, legacy_opts)
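
A hedged usage sketch for the new keyword (the checkpoint file names are hypothetical): the default path rebuilds the wrapped optimizer from the new Keras 2.11+ Optimizer subclasses, while legacy_opts=True tells load_model to scan the tf.keras.optimizers.legacy.* subclasses instead.

```python
import horovod.keras as hvd

hvd.init()

# Checkpoint saved with a new-style (Keras >= 2.11) optimizer -- the default path.
model = hvd.load_model('checkpoint.h5')

# Checkpoint saved with a tf.keras.optimizers.legacy.* optimizer.
legacy_model = hvd.load_model('legacy_checkpoint.h5', legacy_opts=True)
```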
2 changes: 1 addition & 1 deletion horovod/spark/keras/bare.py
@@ -76,7 +76,7 @@ def get_json_type(obj):
},
}, default=get_json_type).encode('utf8')

symbolic_weights = getattr(optimizer, 'weights')
symbolic_weights = optimizer.variables()
if symbolic_weights:
optimizer_weights_group = h5py_file['optimizer_weights']
weight_values = K.batch_get_value(symbolic_weights)
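The same swap from getattr(optimizer, 'weights') to optimizer.variables() appears in both the bare-Keras and tf.keras save paths; the new optimizer classes do not reliably expose the legacy weights attribute, whereas variables() is the call this PR relies on for both generations. A small hedged check (uses tf.keras for convenience and assumes TF >= 2.11 for the legacy namespace):

```python
import tensorflow as tf

# New-style optimizer: variables() returns the iteration counter plus any
# slot variables created so far (just the counter before the first step).
opt = tf.keras.optimizers.Adam(0.001)
print(len(opt.variables()))

# The legacy class answers the same call, so one code path covers both.
legacy_opt = tf.keras.optimizers.legacy.Adam(0.001)
print(len(legacy_opt.variables()))
```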
26 changes: 13 additions & 13 deletions horovod/spark/keras/estimator.py
@@ -35,7 +35,6 @@
from horovod.spark.keras.util import TFKerasUtil
from horovod.spark.keras.datamodule import PetastormDataModule

from horovod._keras import check_keras_optimizer_type

class KerasEstimatorParamsWriter(HorovodParamsWriter):
def saveImpl(self, path):
@@ -52,7 +51,7 @@ def write(self):

class KerasEstimatorParamsReader(HorovodParamsReader):
def _deserialize_dict(self, dict):
def _param_deserializer_fn(name, param_val, keras_utils, custom_objects):
def _param_deserializer_fn(name, param_val, keras_utils, custom_objects, model=None):
if param_val is None:
return param_val

@@ -65,7 +64,7 @@ def load_model_fn(x):
load_model_fn=load_model_fn)
elif name == KerasEstimator.optimizer.name:
opt_base64_encoded = codec.loads_base64(param_val)
return keras_utils.deserialize_optimizer(opt_base64_encoded)
return keras_utils.deserialize_optimizer(opt_base64_encoded, model=model)
else:
return codec.loads_base64(param_val)

@@ -77,8 +76,15 @@ def load_model_fn(x):
dict[KerasEstimator.custom_objects.name],
None, None)

model = None
model_name = EstimatorParams.model.name
if model_name in dict:
model = _param_deserializer_fn(model_name, dict[model_name], TFKerasUtil, custom_objects)

for key, val in dict.items():
dict[key] = _param_deserializer_fn(key, val, TFKerasUtil, custom_objects)
if key == model_name:
dict[model_name] = model
dict[key] = _param_deserializer_fn(key, val, TFKerasUtil, custom_objects, model)
return dict


@@ -225,14 +231,6 @@ def _get_keras_utils(self):
if not isinstance(model, tf.keras.Model):
raise ValueError(
"model has to be an instance of tensorflow.keras.Model")

optimizer = self.getOptimizer()
if optimizer:
if isinstance(optimizer, str):
pass
else:
check_keras_optimizer_type(tf.keras, optimizer)

return TFKerasUtil

def setCustomObjects(self, value):
@@ -328,7 +326,7 @@ def _compile_model(self, keras_utils):

metrics = self.getMetrics()
gradient_compression = self.getGradientCompression()
optimizer_weight_values = optimizer.get_weights()
optimizer_weight_values = optimizer.variables()

dist_optimizer_args = dict(optimizer=optimizer)
if gradient_compression:
@@ -342,6 +340,8 @@
metrics=metrics)

if optimizer_weight_values:
if hasattr(model.optimizer, 'build'):
model.optimizer.build(model.trainable_weights)
model.optimizer.set_weights(optimizer_weight_values)

return keras_utils.serialize_model(model)
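The new hasattr(model.optimizer, 'build') guard exists because Keras 2.11+ optimizers create their slot variables lazily, so restored values have nowhere to land until build() runs; legacy optimizers skip the extra step through the same guard. A rough self-contained sketch of that weight transfer, mirroring the diff's use of optimizer.variables() (the toy model and data are made up):

```python
import numpy as np
import tensorflow as tf

x = np.random.rand(16, 8).astype('float32')
y = np.random.rand(16, 4).astype('float32')

# Train briefly so the source optimizer materializes its slot variables.
src = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
src.compile(optimizer=tf.keras.optimizers.Adam(0.001), loss='mse')
src.fit(x, y, epochs=1, verbose=0)
saved_values = [v.numpy() for v in src.optimizer.variables()]

# A fresh model/optimizer pair, as after deserialization on a worker.
dst = tf.keras.models.clone_model(src)
dst.compile(optimizer=tf.keras.optimizers.Adam(0.001), loss='mse')

if hasattr(dst.optimizer, 'build'):             # new-style optimizers only
    dst.optimizer.build(dst.trainable_weights)  # allocate slots before restoring
dst.optimizer.set_weights(saved_values)
```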
26 changes: 6 additions & 20 deletions horovod/spark/keras/optimizer.py
@@ -20,16 +20,10 @@
from packaging import version
from horovod.runner.common.util import codec

from horovod._keras import get_keras_optimizer_base_type

def serialize_bare_keras_optimizer(x):
import keras
from horovod.spark.keras.bare import save_bare_keras_optimizer

optimizer_class = get_keras_optimizer_base_type(keras)

return _serialize_keras_optimizer(x,
optimizer_class=optimizer_class,
save_optimizer_fn=save_bare_keras_optimizer)


@@ -40,43 +34,35 @@ def deserialize_bare_keras_optimizer(x):


def serialize_tf_keras_optimizer(x):
import tensorflow as tf
from horovod.spark.keras.tensorflow import save_tf_keras_optimizer

optimizer_class = get_keras_optimizer_base_type(tf.keras)

return _serialize_keras_optimizer(x,
optimizer_class=optimizer_class,
save_optimizer_fn=save_tf_keras_optimizer)


def deserialize_tf_keras_optimizer(x):
def deserialize_tf_keras_optimizer(x, model=None):
from horovod.spark.keras.tensorflow import load_tf_keras_optimizer

return _deserialize_keras_optimizer(x,
return _deserialize_keras_optimizer(x, model,
load_keras_optimizer_fn=load_tf_keras_optimizer)


def _serialize_keras_optimizer(opt, optimizer_class, save_optimizer_fn):
def _serialize_keras_optimizer(opt, save_optimizer_fn):
if isinstance(opt, str):
return opt
elif isinstance(opt, optimizer_class):
else:
bio = io.BytesIO()
with h5py.File(bio, 'w') as f:
save_optimizer_fn(opt, f)
return codec.dumps_base64(bio.getvalue())
else:
raise \
ValueError(f'Keras optimizer has to be an instance of str or {optimizer_class}')


def is_string(obj):
return isinstance(obj, str)


def _deserialize_keras_optimizer(serialized_opt, load_keras_optimizer_fn):
def _deserialize_keras_optimizer(serialized_opt, model, load_keras_optimizer_fn):
if is_string(serialized_opt):
return serialized_opt
bio = io.BytesIO(serialized_opt)
with h5py.File(bio, 'r') as f:
return load_keras_optimizer_fn(f)
return load_keras_optimizer_fn(f, model=model)
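For reference, the helpers above implement an in-memory HDF5 plus base64 round trip so the optimizer can travel through Spark params. A hypothetical standalone sketch of that pattern (the helper names here are made up; the real code uses codec.dumps_base64 and the save/load functions from horovod.spark.keras):

```python
import base64
import io

import h5py


def serialize_optimizer(opt, save_fn):
    # Write the optimizer into an in-memory HDF5 file, then base64-encode it.
    bio = io.BytesIO()
    with h5py.File(bio, 'w') as f:
        save_fn(opt, f)                      # e.g. save_tf_keras_optimizer
    return base64.b64encode(bio.getvalue())


def deserialize_optimizer(payload, load_fn, model=None):
    # Decode and reopen the in-memory HDF5 file, passing the model through so
    # the loader can build a new-style optimizer against its weights.
    bio = io.BytesIO(base64.b64decode(payload))
    with h5py.File(bio, 'r') as f:
        return load_fn(f, model=model)       # e.g. load_tf_keras_optimizer
```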
6 changes: 4 additions & 2 deletions horovod/spark/keras/tensorflow.py
@@ -59,7 +59,7 @@ def save_tf_keras_optimizer(optimizer, h5py_file):
default=serialization.get_json_type).encode('utf8')

# Save optimizer weights.
symbolic_weights = getattr(optimizer, 'weights')
symbolic_weights = optimizer.variables()
if symbolic_weights:
optimizer_weights_group = h5py_file.create_group('optimizer_weights')
weight_values = K.batch_get_value(symbolic_weights)
@@ -79,7 +79,7 @@ def save_tf_keras_optimizer(optimizer, h5py_file):
h5py_file.flush()


def load_tf_keras_optimizer(h5py_file, custom_objects=None):
def load_tf_keras_optimizer(h5py_file, custom_objects=None, model=None):
if not custom_objects:
custom_objects = {}

@@ -125,5 +125,7 @@ def convert_custom_objects(obj):
optimizer_weight_values = [optimizer_weights_group[n].value for n in
optimizer_weight_names]
if optimizer_weight_values:
if hasattr(optimizer, 'build'):
optimizer.build(model.trainable_weights)
optimizer.set_weights(optimizer_weight_values)
return optimizer