About training of GoGame using GymGo · Issue #1 · johan-gras/MuZero · GitHub
About training of GoGame using GymGo #1
Open
@intenseG

Description

I am trying to train a Go agent using this wonderful project together with GymGo.
However, training.py raises an error at line 103:
https://github.com/intenseG/MuZero/blob/5cab8b6c89b652ad33e277318bfd4389de1315a9/muzero/training/training.py#L103

Is there a solution to this error?
Thank you.

[Additional debugging information]

value_batch.shape : (2048, 6, 9, 24)
targets.shape : (2048, 24)

[Stack trace]

Traceback (most recent call last):
  File "muzero.py", line 36, in <module>
    muzero(config)
  File "muzero.py", line 24, in muzero
    train_network(config, storage, replay_buffer, config.nb_epochs)
  File "C:\Users\inten\Desktop\ML\MuZero\muzero\training\training.py", line 19, in train_network
    update_weights(optimizer, network, batch)
  File "C:\Users\inten\Desktop\ML\MuZero\muzero\training\training.py", line 84, in update_weights
    optimizer.minimize(loss=loss, var_list=network.cb_get_variables())
  File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\keras\optimizer_v2\optimizer_v2.py", line 317, in minimize
    loss, var_list=var_list, grad_loss=grad_loss)
  File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\keras\optimizer_v2\optimizer_v2.py", line 351, in _compute_gradients
    loss_value = loss()
  File "C:\Users\inten\Desktop\ML\MuZero\muzero\training\training.py", line 42, in loss
    loss += tf.math.reduce_mean(loss_value(target_value_batch, value_batch, network.value_support_size))
  File "C:\Users\inten\Desktop\ML\MuZero\muzero\training\training.py", line 103, in loss_value
    return tf.nn.softmax_cross_entropy_with_logits(logits=value_batch, labels=targets)
  File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\ops\nn_ops.py", line 3105, in softmax_cross_entropy_with_logits_v2
    labels=labels, logits=logits, axis=axis, name=name)
  File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\util\deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\ops\nn_ops.py", line 3206, in softmax_cross_entropy_with_logits_v2_helper
    precise_logits, labels, name=name)
  File "C:\Users\inten\Anaconda3\envs\muzero\lib\site-packages\tensorflow_core\python\ops\gen_nn_ops.py", line 11458, in softmax_cross_entropy_with_logits
    _six.raise_from(_core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: logits and labels must be broadcastable: logits_size=[110592,24] labels_size=[2048,24] [Op:SoftmaxCrossEntropyWithLogits] name: softmax_cross_entropy_with_logits/
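
The sizes in the error message can be reproduced from the two shapes reported above. The sketch below is plain Python with no TensorFlow dependency. It assumes that the cross-entropy op collapses every axis except the last into a single row axis before comparing logits and labels, which is consistent with the `logits_size=[110592,24]` in the message. The helper name `flatten_outer_dims` is hypothetical and chosen only for illustration.

```python
from math import prod


def flatten_outer_dims(shape):
    """Collapse all axes except the last into one: (rows, classes).

    Hypothetical helper for illustration; it mirrors the 2-D sizes
    printed in the error message, not any specific TensorFlow API.
    """
    *outer, classes = shape
    return (prod(outer), classes)


# Shapes copied from the debugging information in this issue.
value_batch_shape = (2048, 6, 9, 24)
targets_shape = (2048, 24)

logits_2d = flatten_outer_dims(value_batch_shape)
labels_2d = flatten_outer_dims(targets_shape)

print(logits_2d)  # (110592, 24), matches logits_size in the error
print(labels_2d)  # (2048, 24), matches labels_size in the error

# Number of logit rows produced for every single target row.
extra = logits_2d[0] // labels_2d[0]
print(extra)  # 54, i.e. the 6 * 9 spatial positions
```

Under this reading, the value output still carries a 6 x 9 spatial map per sample, so there are 54 logit rows for each target row. One possible cause (an assumption, since the network code is not shown here) is a Keras `Dense` value head applied directly to a 4-D feature map: `Dense` acts on the last axis only, so the spatial axes survive. Flattening each sample to one vector before the value head, for example with `tf.keras.layers.Flatten()`, would give a `(2048, 24)` output that lines up with `targets`.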
