Description
What happened + What you expected to happen
The problem is here:
ray/rllib/models/tf/tf_action_dist.py
Lines 90 to 91 in 6d8d739
tf.random.categorical expects log-probabilities (unnormalized), even though its input argument is named 'logits'. Here, self.inputs are the logits and are passed in unchanged, so a masked action (where logits[a] = 0) is treated as a valid sample by tf: an entry of 0 corresponds to an unnormalized weight of exp(0) = 1, not to zero probability. This also affects MultiCategorical, since it ultimately calls this same method.
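To make the effect concrete, here is a minimal sketch in plain Python (no TensorFlow needed) of the probabilities implied by each interpretation. It assumes the sampler normalizes its input with a softmax, which is how unnormalized log-probabilities are usually turned into a distribution. The helper names `softmax` and `safe_log` are mine and are not part of RLlib or TensorFlow.

```python
import math

def softmax(logits):
    """Turn unnormalized log-probabilities into probabilities."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def safe_log(p):
    """log(p), with log(0) mapped to -inf (hypothetical helper)."""
    return math.log(p) if p > 0 else float("-inf")

# Same values as in the reproduction script: last 3 actions are "masked" with 0.
z = [0.5, 0.5, 0.5, 0.0, 0.0, 0.0]

# Interpretation 1: z is used directly as logits.
probs_raw = softmax(z)

# Interpretation 2: log(z) is used as logits, so masked entries become -inf.
probs_log = softmax([safe_log(p) for p in z])

print("raw   :", [round(p, 4) for p in probs_raw])
print("log(z):", [round(p, 4) for p in probs_log])
print("masked mass (raw)   :", round(sum(probs_raw[3:]), 4))
print("masked mass (log(z)):", round(sum(probs_log[3:]), 4))
```

Under the first interpretation the three masked actions together still carry roughly 38% of the probability mass. Under the second they carry exactly zero.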
Versions / Dependencies
ray 1.11.0 (still present in 1.12, though)
python 3.9
tf 2.7
rhel 7.9
Reproduction script
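import tensorflow as tf

# Mask the last 3 actions by setting their entries to 0.
z = tf.constant([[0.5, 0.5, 0.5, 0, 0, 0]])

# Current: z is passed in as-is, so masked actions (indices 3-5) will be sampled.
print(tf.random.categorical(z, 10))

# Corrected: log(0) = -inf, so masked actions won't be sampled.
print(tf.random.categorical(tf.math.log(z), 10))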
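For comparison, here is a sketch of a different way to apply a mask when the unmasked entries are real logits rather than probabilities: keep the allowed logits unchanged and set the forbidden ones to -inf. This is equivalent to adding log(mask) to the logits. It is only an illustration under that assumption, not RLlib's implementation. The names `apply_mask` and `softmax` and the example values are hypothetical.

```python
import math

def softmax(logits):
    """Turn unnormalized log-probabilities into probabilities."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def apply_mask(logits, mask):
    """Keep allowed logits (mask == 1); set forbidden ones (mask == 0) to -inf."""
    return [x if m else float("-inf") for x, m in zip(logits, mask)]

# Hypothetical logits for 6 actions; the last 3 actions are forbidden.
logits = [2.0, 1.0, 0.5, 0.3, -1.0, 0.0]
mask = [1, 1, 1, 0, 0, 0]

masked_probs = softmax(apply_mask(logits, mask))
print([round(p, 4) for p in masked_probs])
```

Unlike taking the log of the whole input, this leaves the relative odds between the allowed actions exactly as the original logits specify them.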
Issue Severity
High: It blocks me from completing my task.