Description
What happened + What you expected to happen
Hi, when I train DQN with an action mask on my custom environment, I get the following error after a random number of training iterations. How can I fix it?
  File "/home/mhx/anaconda3/envs/RLlib/lib/python3.8/site-packages/ray/tune/trainable/trainable.py", line 347, in train
    result = self.step()
  File "/home/mhx/anaconda3/envs/RLlib/lib/python3.8/site-packages/ray/rllib/algorithms/algorithm.py", line 661, in step
    results, train_iter_ctx = self._run_one_training_iteration()
  File "/home/mhx/anaconda3/envs/RLlib/lib/python3.8/site-packages/ray/rllib/algorithms/algorithm.py", line 2378, in _run_one_training_iteration
    num_recreated += self.try_recover_from_step_attempt(
  File "/home/mhx/anaconda3/envs/RLlib/lib/python3.8/site-packages/ray/rllib/algorithms/algorithm.py", line 2185, in try_recover_from_step_attempt
    raise error
  File "/home/mhx/anaconda3/envs/RLlib/lib/python3.8/site-packages/ray/rllib/algorithms/algorithm.py", line 2373, in _run_one_training_iteration
    results = self.training_step()
  File "/home/mhx/anaconda3/envs/RLlib/lib/python3.8/site-packages/ray/rllib/algorithms/dqn/dqn.py", line 358, in training_step
    new_sample_batch = synchronous_parallel_sample(
  File "/home/mhx/anaconda3/envs/RLlib/lib/python3.8/site-packages/ray/rllib/execution/rollout_ops.py", line 100, in synchronous_parallel_sample
    sample_batches = ray.get(
ray.exceptions.RayTaskError(RuntimeError): ray::RolloutWorker.sample() (pid=3996308, ip=10.112.114.52, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x7fe84081b5b0>)
  File "/home/mhx/anaconda3/envs/RLlib/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 806, in sample
    batches = [self.input_reader.next()]
  File "/home/mhx/anaconda3/envs/RLlib/lib/python3.8/site-packages/ray/rllib/evaluation/sampler.py", line 92, in next
    batches = [self.get_data()]
  File "/home/mhx/anaconda3/envs/RLlib/lib/python3.8/site-packages/ray/rllib/evaluation/sampler.py", line 282, in get_data
    item = next(self._env_runner)
  File "/home/mhx/anaconda3/envs/RLlib/lib/python3.8/site-packages/ray/rllib/evaluation/sampler.py", line 707, in _env_runner
    eval_results = _do_policy_eval(
  File "/home/mhx/anaconda3/envs/RLlib/lib/python3.8/site-packages/ray/rllib/evaluation/sampler.py", line 1207, in _do_policy_eval
    eval_results[policy_id] = policy.compute_actions_from_input_dict(
  File "/home/mhx/anaconda3/envs/RLlib/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 319, in compute_actions_from_input_dict
    return self._compute_action_helper(
  File "/home/mhx/anaconda3/envs/RLlib/lib/python3.8/site-packages/ray/rllib/utils/threading.py", line 24, in wrapper
    return func(self, *a, **k)
  File "/home/mhx/anaconda3/envs/RLlib/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 1000, in _compute_action_helper
    actions, logp = self.exploration.get_exploration_action(
  File "/home/mhx/anaconda3/envs/RLlib/lib/python3.8/site-packages/ray/rllib/utils/exploration/epsilon_greedy.py", line 100, in get_exploration_action
    return self._get_torch_exploration_action(
  File "/home/mhx/anaconda3/envs/RLlib/lib/python3.8/site-packages/ray/rllib/utils/exploration/epsilon_greedy.py", line 213, in _get_torch_exploration_action
    torch.multinomial(random_valid_action_logits, 1), axis=1
RuntimeError: invalid multinomial distribution (sum of probabilities <= 0)
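The failing line draws a random action with `torch.multinomial`, which rejects any row whose weights sum to zero. The sketch below mimics that step in isolation. It assumes (this is not confirmed by the traceback alone) that the Ray 2.0 epsilon-greedy code gives weight 0 to every action whose Q-value is at or below RLlib's `FLOAT_MIN` and weight 1 to every other action. The local `FLOAT_MIN` value and the helper name `random_valid_actions` are assumptions for illustration.

```python
import torch

# Assumed to match ray.rllib.utils.torch_utils.FLOAT_MIN in Ray 2.0.
FLOAT_MIN = -3.4e38


def random_valid_actions(q_values: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mimicking the epsilon-greedy random-action step."""
    # Weight 0 for masked actions (Q-value <= FLOAT_MIN), weight 1 otherwise.
    random_valid_action_logits = torch.where(
        q_values <= FLOAT_MIN,
        torch.zeros_like(q_values),
        torch.ones_like(q_values),
    )
    # torch.multinomial raises RuntimeError if any row's weights sum to zero.
    return torch.squeeze(torch.multinomial(random_valid_action_logits, 1), dim=1)


# One valid action (index 1) versus every action masked.
q_ok = torch.tensor([[FLOAT_MIN, 0.5, FLOAT_MIN]])
q_all_masked = torch.tensor([[FLOAT_MIN, FLOAT_MIN, FLOAT_MIN]])

print(random_valid_actions(q_ok))  # tensor([1])
try:
    random_valid_actions(q_all_masked)
except RuntimeError as err:
    print("RuntimeError:", err)
```

Under this assumption, a single all-masked row in the batch is enough to make the whole `sample()` call fail, which would explain why the error appears only after an unpredictable number of iterations.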
Versions / Dependencies
Package      Version
------------ -------
gym          0.23.1
ray          2.0.0
torch        1.10.0
torchaudio   0.10.0
torchvision  0.11.1
python       3.8
Reproduction script
So far this problem only occurs in my custom environment, and I cannot currently provide a runnable script.
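Because the failure only shows up in the custom environment, one way to confirm or rule out an empty mask is to validate every mask the environment emits from `reset()` and `step()`. This is a minimal sketch. It assumes the observation is a dict with an `action_mask` entry of 0/1 values, as `input_data['action_mask']` in the model suggests. `validate_action_mask` is a hypothetical helper, not an RLlib or Gym API.

```python
import numpy as np


def validate_action_mask(mask, context=""):
    """Hypothetical check: raise if a 0/1 action mask allows no action at all."""
    mask = np.asarray(mask)
    if mask.ndim != 1:
        raise ValueError(f"expected a 1-D action mask, got shape {mask.shape}")
    if not np.any(mask > 0):
        # Every action is masked out, so the policy has nothing it may choose.
        raise ValueError(f"action mask has no valid action {context}".strip())
    return mask


# Inside the environment, before returning an observation, e.g.:
# validate_action_mask(obs["action_mask"], context=f"at step {t}")
validate_action_mask([0, 1, 0])  # passes: action 1 is allowed
```

Raising inside the environment points at the exact state that produces the empty mask, instead of failing later inside the exploration code.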
Here is the code that applies the action mask in my DQN model. Could the error be caused by the action mask masking out all actions?
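import torch
from ray.rllib.utils.torch_utils import FLOAT_MAX, FLOAT_MIN

# Inside the custom model's forward(); middle_data is the output of the shared layers.
action_mask = input_data["action_mask"]
# log(1) = 0 leaves valid actions unchanged; log(0) = -inf is clamped to FLOAT_MIN.
inf_mask = torch.clamp(torch.log(action_mask), FLOAT_MIN, FLOAT_MAX)
if self.logits_layer is not None:
    logits, values = self.logits_layer(middle_data), self.value_layer(middle_data)
    self._value_out = torch.reshape(values, [-1])
    return logits + inf_mask, []  # add action_mask here
else:
    # Note: no mask is applied on this path.
    return middle_data, []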
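The masking arithmetic above can be checked on its own with made-up logits. This is a minimal sketch that assumes `FLOAT_MIN` and `FLOAT_MAX` have RLlib's values (±3.4e38). `torch.log` maps 1 to 0 and 0 to `-inf`, which the clamp turns into `FLOAT_MIN`, so a row whose mask is all zeros ends up with every logit at `FLOAT_MIN`. If, as assumed earlier, epsilon-greedy treats Q-values at or below `FLOAT_MIN` as invalid, such a row leaves nothing to sample. `apply_action_mask` is a hypothetical name for the same two lines used in the model.

```python
import torch

# Assumed to match RLlib's FLOAT_MIN / FLOAT_MAX constants.
FLOAT_MIN = -3.4e38
FLOAT_MAX = 3.4e38


def apply_action_mask(logits: torch.Tensor, action_mask: torch.Tensor) -> torch.Tensor:
    """Same arithmetic as the model snippet: add log(mask), clamped to float range."""
    inf_mask = torch.clamp(torch.log(action_mask), FLOAT_MIN, FLOAT_MAX)
    return logits + inf_mask


logits = torch.tensor([[0.3, -1.2, 2.0],
                       [0.3, -1.2, 2.0]])
action_mask = torch.tensor([[1.0, 0.0, 1.0],   # two valid actions
                            [0.0, 0.0, 0.0]])  # every action masked

masked = apply_action_mask(logits, action_mask)
# Count, per row, the actions whose value stays above FLOAT_MIN.
print((masked > FLOAT_MIN).sum(dim=1))  # tensor([2, 0])
```

The second row shows that the model itself cannot recover from an empty mask: adding any finite logit to `FLOAT_MIN` in float32 still yields `FLOAT_MIN`, so the fix has to come from the mask the environment provides.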
Issue Severity
High: It blocks me from completing my task.