Open
Description
Environment:
- Framework: TensorFlow
- Framework version: 2.4.0
- Horovod version: 0.21.1
- MPI version: openmpi-4.0.5
- CUDA version: 11.0
- NCCL version: nccl_2.8.3-1+cuda11.0_x86_64
- Python version: 3.8.5
- Spark / PySpark version: NA
- OS and version: Ubuntu 18.04 (kernel 5.3.0-62-generic #56~18.04.1-Ubuntu); running on 4 machines with 1 GPU each using Open MPI
- GCC version: 7.5.0
- CMake version: 3.18.2
I've run into errors when trying to XLA-compile my TensorFlow train/test steps. In my custom model, if I use

```python
@tf.function(jit_compile=True)
def train_step(...):
```

to force compilation of the training operations, I can run successfully without Horovod in a single process. When I run with Horovod, however, I receive errors like:
The op is created at:
File "main.py", line 368, in <module>
main()
File "main.py", line 188, in main
epoch_loop.one_train_epoch(config,trainds,net,
File "/gpfs/mira-home/parton/git/atlas_dgcnn/epoch_loop.py", line 9, in one_train_epoch
return one_epoch(config,dataset,net,train_step,loss_func,opt,epoch_num,tbwriter,batches_per_epoch,True)
File "/gpfs/mira-home/parton/git/atlas_dgcnn/epoch_loop.py", line 74, in one_epoch
loss_value,logits = step_func(net,loss_func,inputs,labels,weights,opt,first_batch,hvd)
File "/gpfs/mira-home/parton/git/atlas_dgcnn/epoch_loop.py", line 244, in train_step
if hvd and first_batch:
File "/gpfs/mira-home/parton/git/atlas_dgcnn/epoch_loop.py", line 246, in train_step
hvd.broadcast_variables(opt.variables(), root_rank=root_rank)
File "/home/parton/.local/lib/python3.8/site-packages/horovod/tensorflow/functions.py", line 56, in broadcast_variables
return broadcast_group(variables, root_rank)
File "/tmp/tmp4h_gfbt8.py", line 53, in broadcast_group
retval__2 = ag__.converted_call(ag__.ld(tf).group, tuple([ag__.converted_call(ag__.ld(var).assign, (ag__.converted_call(ag__.ld(broadcast), (ag__.ld(var), ag__.ld(root_rank)), None, fscope_2),), None, fscope_2) for var in ag__.ld(variables)]), None, fscope_2)
File "/tmp/tmp4h_gfbt8.py", line 53, in <listcomp>
retval__2 = ag__.converted_call(ag__.ld(tf).group, tuple([ag__.converted_call(ag__.ld(var).assign, (ag__.converted_call(ag__.ld(broadcast), (ag__.ld(var), ag__.ld(root_rank)), None, fscope_2),), None, fscope_2) for var in ag__.ld(variables)]), None, fscope_2)
File "/home/parton/.local/lib/python3.8/site-packages/horovod/tensorflow/mpi_ops.py", line 251, in broadcast
return MPI_LIB.horovod_broadcast(tensor, name=name, root_rank=root_rank,
File "<string>", line 423, in horovod_broadcast
HorovodBroadcast_Adam_dgcnn_conv_bn_layer_12_batch_normalization_14_beta_v_0: unsupported op: No registered 'HorovodBroadcast' OpKernel for XLA_GPU_JIT devices compatible with node {{node HorovodBroadcast_Adam_dgcnn_conv_bn_layer_12_batch_normalization_14_beta_v_0}}
You can see my code here:
https://github.com/jtchilders/atlas_dgcnn
I can reproduce the issue with your example `examples/tensorflow2_mnist.py` by simply changing `@tf.function` to `@tf.function(jit_compile=True)`. Running it with

```
mpirun -n $RANKS -npernode $PPN python tensorflow2_mnist.py
```

produces a similar error:
The op is created at:
File "tensorflow2_mnist.py", line 84, in <module>
loss_value = training_step(images, labels, batch == 0)
File "tensorflow2_mnist.py", line 75, in training_step
if first_batch:
File "tensorflow2_mnist.py", line 77, in training_step
hvd.broadcast_variables(opt.variables(), root_rank=0)
File "/home/parton/.local/lib/python3.8/site-packages/horovod/tensorflow/functions.py", line 56, in broadcast_variables
return broadcast_group(variables, root_rank)
File "/tmp/tmp_15ypkpb.py", line 53, in broadcast_group
retval__2 = ag__.converted_call(ag__.ld(tf).group, tuple([ag__.converted_call(ag__.ld(var).assign, (ag__.converted_call(ag__.ld(broadcast), (ag__.ld(var), ag__.ld(root_rank)), None, fscope_2),), None, fscope_2) for var in ag__.ld(variables)]), None, fscope_2)
File "/tmp/tmp_15ypkpb.py", line 53, in <listcomp>
retval__2 = ag__.converted_call(ag__.ld(tf).group, tuple([ag__.converted_call(ag__.ld(var).assign, (ag__.converted_call(ag__.ld(broadcast), (ag__.ld(var), ag__.ld(root_rank)), None, fscope_2),), None, fscope_2) for var in ag__.ld(variables)]), None, fscope_2)
File "/home/parton/.local/lib/python3.8/site-packages/horovod/tensorflow/mpi_ops.py", line 251, in broadcast
return MPI_LIB.horovod_broadcast(tensor, name=name, root_rank=root_rank,
File "<string>", line 423, in horovod_broadcast [Op:__inference_training_step_933]