Bug description
predict_th contains the following assert statement:
assert rew_th.shape == state.shape[:1]
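This fails whenever state is not itself a tensor, even if self.preprocess has already converted it into a valid tensor state_th, because the assertion checks the shape of the original state argument rather than state_th.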
Steps to reproduce
Preprocess a dictionary-style state with the reward network's preprocess function, then call predict_th. Even if preprocess returns a valid state_th, the assertion compares against the original state, which is incorrect.
Instead I believe it should be:
assert rew_th.shape == state_th.shape[:1]
I believe this was a typo in the implementation.
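A minimal, dependency-free sketch of the failure, in case it helps with triage. Everything here is a hypothetical stand-in, not the library's code: FakeTensor only carries a shape, and preprocess and reward are simplified placeholders for self.preprocess and the network's forward pass. It only shows why checking against state breaks for dictionary inputs while checking against state_th does not.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor; it only carries a shape."""

    def __init__(self, *shape):
        self.shape = tuple(shape)


def preprocess(state):
    """Hypothetical preprocess: merge dict entries along the feature axis."""
    batch = next(iter(state.values())).shape[0]
    features = sum(t.shape[1] for t in state.values())
    return FakeTensor(batch, features)


def reward(state_th):
    """Hypothetical reward: one scalar per batch element."""
    return FakeTensor(state_th.shape[0])


# Dictionary-style state with a batch of 4 observations.
state = {"obs": FakeTensor(4, 3), "goal": FakeTensor(4, 2)}
state_th = preprocess(state)
rew_th = reward(state_th)

# Current check: compares against the raw input, which is a dict here,
# so accessing .shape raises AttributeError before the comparison runs.
try:
    assert rew_th.shape == state.shape[:1]
    original_check = "passed"
except AttributeError as err:
    original_check = f"failed: {err}"

# Proposed check: compares against the preprocessed tensor instead.
assert rew_th.shape == state_th.shape[:1]
```

With a dictionary state, the current check never reaches the shape comparison, while the proposed check still verifies that there is one reward per batch element.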
Environment
- Linux 24.04
- Python version: 3.10.17