I am trying to train an agent to play Go using this wonderful project and GymGo.
However, training.py raises an error at line 103:
https://github.com/intenseG/MuZero/blob/5cab8b6c89b652ad33e277318bfd4389de1315a9/muzero/training/training.py#L103
Is there a solution to this error?
Thank you.
[Additional debugging information]
value_batch.shape: (2048, 6, 9, 24)
targets.shape: (2048, 24)
[Stack trace]
Traceback (most recent call last):
File "muzero.py", line 36, in <module>
muzero(config)
File "muzero.py", line 24, in muzero
train_network(config, storage, replay_buffer, config.nb_epochs)
File "C:\Users\inten\Desktop\ML\MuZero\muzero\training\training.py", line 19, in train_network
update_weights(optimizer, network, batch)
File "C:\Users\inten\Desktop\ML\MuZero\muzero\training\training.py", line 84, in update_weights
optimizer.minimize(loss=loss, var_list=network.cb_get_variables())
File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\keras\optimizer_v2\optimizer_v2.py", line 317, in minimize
loss, var_list=var_list, grad_loss=grad_loss)
File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\keras\optimizer_v2\optimizer_v2.py", line 351, in _compute_gradients
loss_value = loss()
File "C:\Users\inten\Desktop\ML\MuZero\muzero\training\training.py", line 42, in loss
loss += tf.math.reduce_mean(loss_value(target_value_batch, value_batch, network.value_support_size))
File "C:\Users\inten\Desktop\ML\MuZero\muzero\training\training.py", line 103, in loss_value
return tf.nn.softmax_cross_entropy_with_logits(logits=value_batch, labels=targets)
File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\ops\nn_ops.py", line 3105, in softmax_cross_entropy_with_logits_v2
labels=labels, logits=logits, axis=axis, name=name)
File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\util\deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\ops\nn_ops.py", line 3206, in softmax_cross_entropy_with_logits_v2_helper
precise_logits, labels, name=name)
File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\ops\gen_nn_ops.py", line 11458, in softmax_cross_entropy_with_logits
_six.raise_from(_core._status_to_exception(e.code, message), None)
File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: logits and labels must be broadcastable: logits_size=[110592,24] labels_size=[2048,24] [Op:SoftmaxCrossEntropyWithLogits] name: softmax_cross_entropy_with_logits/
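[Shape arithmetic]
The sizes in the error message follow from the shapes listed above. The sketch below uses NumPy only, not the project's code, to show where logits_size=[110592,24] comes from. It assumes the cross-entropy op treats the last axis as the class axis and flattens all leading axes into one. The final mean over the two extra axes is purely hypothetical and only illustrates a shape that would match targets; whether the real fix belongs in the value head of the network (for example, flattening before its last dense layer) depends on code not shown here.

```python
import numpy as np

# Shapes copied from the debugging information above.
batch, extra_a, extra_b, support = 2048, 6, 9, 24

value_batch = np.zeros((batch, extra_a, extra_b, support), dtype=np.float32)
targets = np.zeros((batch, support), dtype=np.float32)

# Assumption: the last axis holds the classes and every leading axis is
# flattened into a single "row" axis before the loss is computed.
flat_logits = value_batch.reshape(-1, support)
print(flat_logits.shape)  # (110592, 24), i.e. 2048 * 6 * 9 rows
print(targets.shape)      # (2048, 24)

# Hypothetical illustration only: collapsing the two extra axes gives one
# logit vector per sample, which matches the shape of targets.
pooled = value_batch.mean(axis=(1, 2))
print(pooled.shape)       # (2048, 24)
```

So the value output appears to carry two extra axes (6 and 9) that the targets do not have.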