TensorFlow: scale the gradients of local variables by MrAta · Pull Request #3719 · horovod/horovod · GitHub

TensorFlow: scale the gradients of local variables #3719

Merged · 11 commits · Oct 11, 2022
4 changes: 4 additions & 0 deletions docs/keras.rst
@@ -48,6 +48,8 @@ Horovod supports Keras and regular TensorFlow in similar ways. To use Horovod wi

The distributed optimizer delegates gradient computation to the original optimizer, averages gradients using *allreduce* or *allgather*, and then applies those averaged gradients.

**Note:** For model-parallel use cases there are local variables (layers) whose gradients do not need to be synchronized (via *allreduce* or *allgather*). You can register those variables with the returned wrapper optimizer by calling its ``register_local_var()`` API.

.. raw:: html

<p/>
@@ -56,6 +58,8 @@ Horovod supports Keras and regular TensorFlow in similar ways. To use Horovod wi

This is necessary to ensure consistent initialization of all workers when training is started with random weights or restored from a checkpoint.

**Note:** For model-parallel use cases there are local variables (layers) whose weights do not need to be broadcast. You can pass those local variables to this callback by using ``hvd.callbacks.BroadcastGlobalVariablesCallback(0, local_variables=[list of local variables])`` instead.

.. raw:: html

<p/>
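Taken together, the two notes above combine as in the following sketch (not part of this diff). It assumes an illustrative ``horovod.tensorflow.keras`` script with a hypothetical two-layer split where ``local_layer`` is the model-parallel, worker-local shard; names, layer sizes, and optimizer settings are placeholders.

```python
import tensorflow as tf
import horovod.tensorflow.keras as hvd

hvd.init()

shared = tf.keras.layers.Dense(64, activation='relu')
local_layer = tf.keras.layers.Dense(64, activation='relu')  # worker-local shard
model = tf.keras.Sequential([shared, local_layer, tf.keras.layers.Dense(10)])
model.build(input_shape=(None, 32))   # build so the layers have weights to register

opt = hvd.DistributedOptimizer(tf.keras.optimizers.SGD(0.01))
for var in local_layer.trainable_weights:
    opt.register_local_var(var)       # keep these gradients out of allreduce/allgather

model.compile(loss='sparse_categorical_crossentropy', optimizer=opt)

callbacks = [
    # Broadcast only globally shared weights from rank 0; skip the local ones.
    hvd.callbacks.BroadcastGlobalVariablesCallback(
        0, local_variables=local_layer.trainable_weights),
]
```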
2 changes: 2 additions & 0 deletions docs/tensorflow.rst
@@ -51,6 +51,8 @@ To use Horovod with TensorFlow, make the following modifications to your trainin

For **TensorFlow v2**, when using a ``tf.GradientTape``, wrap the tape in ``hvd.DistributedGradientTape`` instead of wrapping the optimizer.

**Note:** For model-parallel use cases there are local variables (layers) whose gradients do not need to be synchronized (via *allreduce* or *allgather*). You can register those variables with the returned wrapper optimizer by calling its ``register_local_var()`` API. Additionally, when using ``tf.GradientTape``, wrap the tape in ``hvd.PartialDistributedGradientTape`` instead of ``DistributedGradientTape`` and pass the local layers to it in order to register their local variables.

.. raw:: html

<p/>
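A hedged TF2 sketch of the note above, with an assumed toy model where ``local_layer`` is the worker-local shard: the tape is wrapped in ``hvd.PartialDistributedGradientTape`` and the local layer is passed via ``local_layers`` so its gradients skip allreduce (and, with this change, are divided by the process set size by default).

```python
import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

shared = tf.keras.layers.Dense(16, activation='relu')
local_layer = tf.keras.layers.Dense(16, activation='relu')  # model-parallel shard
model = tf.keras.Sequential([shared, local_layer, tf.keras.layers.Dense(1)])
opt = tf.keras.optimizers.SGD(0.01)

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    # Allreduce everything except local_layer's gradients, which are returned
    # locally (scaled down unless scale_local_gradients=False is passed).
    tape = hvd.PartialDistributedGradientTape(tape, local_layers=[local_layer])
    grads = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```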
27 changes: 25 additions & 2 deletions horovod/_keras/__init__.py
@@ -13,13 +13,14 @@
# limitations under the License.
# ==============================================================================

import os
from packaging import version

import horovod.tensorflow as hvd
import tensorflow as tf
from horovod.tensorflow.gradient_aggregation import LocalGradientAggregationHelper
from horovod.tensorflow.gradient_aggregation_eager import LocalGradientAggregationHelperEager
from horovod.tensorflow.mpi_ops import rank
from horovod.tensorflow.mpi_ops import rank, size_op


_PRE_TF_2_4_0 = version.parse(tf.__version__) < version.parse('2.4.0')
@@ -30,7 +31,8 @@ def create_distributed_optimizer(keras, optimizer, name, device_dense, device_sp
compression, sparse_as_dense, gradient_predivide_factor,
op, backward_passes_per_step=1,
average_aggregated_gradients=False,
groups=None, process_set=hvd.global_process_set):
groups=None, process_set=hvd.global_process_set,
scale_local_gradients=True):
class _DistributedOptimizer(keras.optimizers.Optimizer):
_HAS_AGGREGATE_GRAD = True

@@ -52,6 +54,8 @@ def __init__(self, **kwargs):
process_set=process_set)

self._local_vars = set()
self.process_set = process_set
self.scale_local_gradients = scale_local_gradients
self._agg_helper = None
if backward_passes_per_step > 1:
if hvd._executing_eagerly():
@@ -60,6 +64,8 @@ def __init__(self, **kwargs):
allreduce_func=self._allreduce_grads,
sparse_as_dense=sparse_as_dense,
average_aggregated_gradients=average_aggregated_gradients,
process_set=process_set,
scale_local_gradients=scale_local_gradients
)
else:
self._agg_helper = LocalGradientAggregationHelper(
@@ -69,6 +75,8 @@ def __init__(self, **kwargs):
average_aggregated_gradients=average_aggregated_gradients,
rank=rank(),
optimizer_type=LocalGradientAggregationHelper._OPTIMIZER_TYPE_KERAS,
process_set=process_set,
scale_local_gradients=scale_local_gradients
)

def register_local_var(self, var):
@@ -151,13 +159,28 @@ def __filtered_reduce_grads(grads, vars):
rg.append(grad)

rg = self._allreduce_grads(rg, rv)
horovod_size = size_op(process_set_id=self.process_set.process_set_id) if int(os.environ.get("HOROVOD_ELASTIC", 0)) else self.process_set.size()
if _IS_TF2:
for rv, rg in zip(rv, rg):
v2g[rv.ref()] = rg

if self.scale_local_gradients and len(self._local_vars):
# Scale local gradients by a size factor. See pull/3695 and discussions/3705 for context.
for v_ref in v2g:
if v_ref in self._local_vars and v2g[v_ref] is not None:
v2g[v_ref] /= horovod_size

return [v2g[rv.ref()] for rv in vars]
else:
for rv, rg in zip(rv, rg):
v2g[rv] = rg

if self.scale_local_gradients and len(self._local_vars):
# Scale local gradients by a size factor. See pull/3695 and discussions/3705 for context.
for v in v2g:
if v in self._local_vars and v2g[v] is not None:
v2g[v] /= horovod_size

return [v2g[rv] for rv in vars]
return __filtered_reduce_grads(grads, vars)

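The division above keeps local and global updates on the same scale: shared variables receive an *averaged* gradient from allreduce, so an unscaled local gradient would behave like an N-times larger step. A toy illustration (plain NumPy, not Horovod code; the worker count N is an assumption):

```python
import numpy as np

N = 4                                  # assumed number of workers in the process set
shared_grads = [np.array([2.0])] * N   # identical shared-variable gradient per worker
local_grad = np.array([2.0])           # gradient of a worker-local variable

avg_shared = sum(shared_grads) / N     # what allreduce with op=Average produces
scaled_local = local_grad / N          # what scale_local_gradients=True applies

print(avg_shared, scaled_local)        # both [0.5]: consistent effective step size
```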
59 changes: 50 additions & 9 deletions horovod/tensorflow/__init__.py
@@ -596,7 +596,7 @@ def __init__(self, optimizer, name=None, use_locking=False, device_dense='',
device_sparse='', compression=Compression.none,
sparse_as_dense=False, op=Average, gradient_predivide_factor=1.0,
backward_passes_per_step=1, average_aggregated_gradients=False,
groups=None, process_set=global_process_set):
groups=None, process_set=global_process_set, scale_local_gradients=True):
if name is None:
name = "Distributed{}".format(type(optimizer).__name__)
super(_DistributedOptimizer, self).__init__(name=name, use_locking=use_locking)
@@ -607,6 +608,8 @@ def __init__(self, optimizer, name=None, use_locking=False, device_dense='',
gradient_predivide_factor, groups, process_set=process_set)

self._local_vars = set()
self.process_set = process_set
self.scale_local_gradients = scale_local_gradients
self._agg_helper = None
if backward_passes_per_step > 1:
if _executing_eagerly():
@@ -622,6 +624,8 @@ def __init__(self, optimizer, name=None, use_locking=False, device_dense='',
average_aggregated_gradients=average_aggregated_gradients,
rank=rank(),
optimizer_type=LocalGradientAggregationHelper._OPTIMIZER_TYPE_LEGACY,
process_set=process_set,
scale_local_gradients=scale_local_gradients
)

def register_local_var(self, var):
@@ -665,13 +669,28 @@ def _filtered_reduce_grads(grads, vars):
rg.append(grad)

rg = self._allreduce_grads(rg, rv)
horovod_size = size_op(process_set_id=self.process_set.process_set_id) if int(os.environ.get("HOROVOD_ELASTIC", 0)) else self.process_set.size()
if _IS_TF2:
for rv,rg in zip(rv, rg):
v2g[rv.ref()] = rg

if self.scale_local_gradients and len(self._local_vars):
# Scale local gradients by a size factor. See pull/3695 and discussions/3705 for context.
for v_ref in v2g:
if v_ref in self._local_vars and v2g[v_ref] is not None:
v2g[v_ref] /= horovod_size

return [v2g[rv.ref()] for rv in vars]
else:
for rv, rg in zip(rv, rg):
v2g[rv] = rg

if self.scale_local_gradients and len(self._local_vars):
# Scale local gradients by a size factor. See pull/3695 and discussions/3705 for context.
for v in v2g:
if v in self._local_vars and v2g[v] is not None:
v2g[v] /= horovod_size

return [v2g[rv] for rv in vars]

avg_grads = _filtered_reduce_grads(grads, vars)
@@ -805,7 +824,8 @@ def DistributedOptimizer(optimizer, name=None, use_locking=False, device_dense='
sparse_as_dense=False, backward_passes_per_step=1,
op=Average, gradient_predivide_factor=1.0,
average_aggregated_gradients=False,
num_groups=0, groups=None, process_set=global_process_set):
num_groups=0, groups=None,
process_set=global_process_set, scale_local_gradients=True):
"""Construct a new DistributedOptimizer, which uses another optimizer
under the hood for computing single-process gradient values and
applying gradient updates after the gradient values have been combined
@@ -865,6 +885,7 @@ def DistributedOptimizer(optimizer, name=None, use_locking=False, device_dense='
Defaults as None, which is no explicit groups.
process_set: Gradients will only be reduced over Horovod processes belonging
to this process set. Defaults to the global process set.
scale_local_gradients: Whether to scale the gradients of local variables by the process set size. Defaults to True.
"""
if gradient_predivide_factor != 1.0:
if rocm_built():
@@ -907,6 +928,7 @@ def DistributedOptimizer(optimizer, name=None, use_locking=False, device_dense='
average_aggregated_gradients=average_aggregated_gradients,
groups=groups,
process_set=process_set,
scale_local_gradients=scale_local_gradients
)
elif isinstance(optimizer, tf.keras.optimizers.Optimizer):
if op == Adasum:
@@ -924,6 +946,7 @@ def DistributedOptimizer(optimizer, name=None, use_locking=False, device_dense='
backward_passes_per_step=backward_passes_per_step,
average_aggregated_gradients=average_aggregated_gradients,
process_set=process_set,
scale_local_gradients=scale_local_gradients
)
else:
raise ValueError('Provided optimizer doesn\'t inherit from either legacy '
@@ -934,7 +957,7 @@ def DistributedOptimizer(optimizer, name=None, use_locking=False, device_dense='
class _DistributedGradientTape(tf.GradientTape):
def __init__(self, tape, device_dense, device_sparse, compression, sparse_as_dense, op,
gradient_predivide_factor, groups, persistent=False,
watch_accessed_variables=True, process_set=global_process_set):
watch_accessed_variables=True, process_set=global_process_set, scale_local_gradients=True):
if hasattr(tape, '_watch_accessed_variables'):
super(self.__class__, self).__init__(persistent, watch_accessed_variables)
else:
@@ -945,8 +968,11 @@ def __init__(self, tape, device_dense, device_sparse, compression, sparse_as_den
'DistributedGradientTape', device_dense, device_sparse, compression,
sparse_as_dense, op, gradient_predivide_factor, groups, process_set)

self.process_set = process_set
self.scale_local_gradients = scale_local_gradients
self._local_sources = set()


def register_local_source(self, source):
"""Registers a source/variable as worker local. Horovod will not perform any global
operations on gradients corresponding to these sources and will instead return the local
@@ -977,23 +1003,35 @@ def gradient(self, target, sources, output_gradients=None, use_generic_names=Fal

# Reduce grads
rg = self._allreduce_grads(rg, rs, use_generic_names)

horovod_size = size_op(process_set_id=self.process_set.process_set_id) if int(os.environ.get("HOROVOD_ELASTIC", 0)) else self.process_set.size()
# Replace dict entries with reduced grads
if _IS_TF2:
for rs, rg in zip(rs, rg):
s2g[rs.ref()] = rg

if self.scale_local_gradients and len(self._local_sources):
# Scale local gradients by a size factor. See pull/3695 and discussions/3705 for context.
for s_ref in s2g:
if s_ref in self._local_sources and s2g[s_ref] is not None:
s2g[s_ref] /= horovod_size

return [s2g[s.ref()] for s in sources]
else:
for rs, rg in zip(rs, rg):
s2g[rs] = rg

if self.scale_local_gradients and len(self._local_sources):
# Scale local gradients by a size factor. See pull/3695 and discussions/3705 for context.
for s in s2g:
if s in self._local_sources and s2g[s] is not None:
s2g[s] /= horovod_size

return [s2g[s] for s in sources]
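A hedged sketch of the lower-level path exercised by this method: wrap the tape in ``hvd.DistributedGradientTape`` directly and register worker-local sources one by one. The variables and loss are illustrative assumptions.

```python
import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

w_shared = tf.Variable(tf.ones([4]))
w_local = tf.Variable(tf.ones([4]))    # assumed worker-local parameter
opt = tf.keras.optimizers.SGD(0.01)

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(w_shared ** 2) + tf.reduce_sum(w_local ** 2)

tape = hvd.DistributedGradientTape(tape, scale_local_gradients=True)
tape.register_local_source(w_local)    # skip allreduce; divide by the process set size

grads = tape.gradient(loss, [w_shared, w_local])
opt.apply_gradients(zip(grads, [w_shared, w_local]))
```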

def DistributedGradientTape(gradtape, device_dense='', device_sparse='',
compression=Compression.none, sparse_as_dense=False,
op=Average, gradient_predivide_factor=1.0,
num_groups=0, groups=None, process_set=global_process_set):
num_groups=0, groups=None, process_set=global_process_set, scale_local_gradients=True):
"""A tape that wraps another tf.GradientTape, using an allreduce to
combine gradient values before applying gradients to model weights.

@@ -1036,6 +1074,7 @@ def DistributedGradientTape(gradtape, device_dense='', device_sparse='',
Defaults as None, which is no explicit groups.
process_set: Gradients will only be reduced over Horovod processes belonging
to this process set. Defaults to the global process set.
scale_local_gradients: Whether to scale the gradients of local variables by the process set size. Defaults to True.
"""
if gradient_predivide_factor != 1.0:
if rocm_built():
@@ -1056,21 +1095,23 @@ def DistributedGradientTape(gradtape, device_dense='', device_sparse='',

cls = type(gradtape.__class__.__name__, (gradtape.__class__,),
dict(_DistributedGradientTape.__dict__))

if hasattr(gradtape, '_watch_accessed_variables'):
return cls(gradtape._tape, device_dense, device_sparse, compression,
sparse_as_dense, op, gradient_predivide_factor, groups,
gradtape._persistent, gradtape._watch_accessed_variables,
process_set=process_set)
process_set=process_set, scale_local_gradients=scale_local_gradients)
else:
return cls(gradtape._tape, device_dense, device_sparse, compression,
sparse_as_dense, op, gradient_predivide_factor, groups,
gradtape._persistent, process_set=process_set)
gradtape._persistent, process_set=process_set, scale_local_gradients=scale_local_gradients)


def PartialDistributedGradientTape(gradtape, device_dense='', device_sparse='',
compression=Compression.none, sparse_as_dense=False,
op=Average, gradient_predivide_factor=1.0,
num_groups=0, groups=None, process_set=global_process_set, local_layers=None):
num_groups=0, groups=None, process_set=global_process_set,
local_layers=None, scale_local_gradients=True):
"""A tape that wraps another tf.GradientTape, using an allreduce to
combine gradient values before applying gradients to model weights similar to
DistributedGradientTape except it skips allreducing gradients of the local layers
@@ -1099,7 +1140,7 @@ def PartialDistributedGradientTape(gradtape, device_dense='', device_sparse='',
_tape = DistributedGradientTape(gradtape, device_dense, device_sparse,
compression, sparse_as_dense,
op, gradient_predivide_factor,
num_groups, groups, process_set)
num_groups, groups, process_set, scale_local_gradients)
for var in local_vars:
_tape.register_local_source(var)
return _tape
25 changes: 24 additions & 1 deletion horovod/tensorflow/gradient_aggregation.py
@@ -1,5 +1,9 @@
import os
import tensorflow as tf
from packaging import version
from horovod.tensorflow.mpi_ops import size_op
from horovod.tensorflow.mpi_ops import global_process_set


_IS_TF2 = version.parse(tf.__version__) >= version.parse('2.0.0')

@@ -33,7 +37,9 @@ def __init__(
sparse_as_dense,
average_aggregated_gradients,
rank,
optimizer_type):
optimizer_type,
process_set=global_process_set,
scale_local_gradients=True):
self._allreduce_grads = allreduce_func

# backward_passes_per_step controls how often gradient updates are
@@ -66,6 +72,8 @@ def __init__(
self.not_none_indexes = {}
self.num_none_grad_updates = 0

self.process_set = process_set
self.scale_local_gradients = scale_local_gradients
self._local_vars = set()

def register_local_var(self, var):
@@ -178,13 +186,28 @@ def __filtered_reduce_grads(grads, vars):
rg.append(grad)

rg = self._allreduce_grads(rg, rv)
horovod_size = size_op(process_set_id=self.process_set.process_set_id) if int(os.environ.get("HOROVOD_ELASTIC", 0)) else self.process_set.size()
if _IS_TF2:
for rv, rg in zip(rv, rg):
v2g[rv.ref()] = rg

if self.scale_local_gradients and len(self._local_vars):
# Scale local gradients by a size factor. See pull/3695 and discussions/3705 for context.
for v_ref in v2g:
if v_ref in self._local_vars and v2g[v_ref] is not None:
v2g[v_ref] /= horovod_size

return [v2g[rv.ref()] for rv in vars]
else:
for rv, rg in zip(rv, rg):
v2g[rv] = rg

if self.scale_local_gradients and len(self._local_vars):
# Scale local gradients by a size factor. See pull/3695 and discussions/3705 for context.
for v in v2g:
if v in self._local_vars and v2g[v] is not None:
v2g[v] /= horovod_size

return [v2g[rv] for rv in vars]

# Read in latest variables values.
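For context, a hedged sketch of how the helper above comes into play: enabling local gradient aggregation (``backward_passes_per_step > 1``) on a Keras optimizer that also has a registered local variable. The model, which weight is treated as worker-local, and the parameter values are all illustrative assumptions.

```python
import tensorflow as tf
import horovod.tensorflow.keras as hvd

hvd.init()

model = tf.keras.Sequential([tf.keras.layers.Dense(8), tf.keras.layers.Dense(1)])
model.build(input_shape=(None, 4))
local_var = model.layers[0].trainable_weights[0]  # treat this kernel as worker-local

opt = hvd.DistributedOptimizer(
    tf.keras.optimizers.SGD(0.01),
    backward_passes_per_step=4)    # aggregate 4 micro-batches before reducing
opt.register_local_var(local_var)  # its aggregated gradient is never allreduced

model.compile(loss='mse', optimizer=opt)
```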