[RLlib] Examples folder do-over (vol 52): Custom action distribution example (new script, replaces existing Catalogs-based one). #53262
Conversation
LGTM. Approved with a kind request for including temperature decay.
@@ -670,7 +665,7 @@ def from_logits(
             child_distribution_cls_struct, child_distribution_list
         )

-        return TorchMultiDistribution(
+        return cls(
Why change it here the other way around?
@@ -0,0 +1,118 @@
"""Example on how to define and run an experiment with a custom action distribution.

The example uses an additional `temperature` parameter on top of the built-in
Great example of how to introduce temperature into the action sampling. Could we also show how to decay this temperature? Temperature decay over the course of training is a common practice in RL.
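A minimal sketch of what such a decay could look like, in plain Python (the helper name and hyperparameters are illustrative, not RLlib API; wiring the value into the module each iteration, e.g. from a callback, is left open):

def decayed_temperature(
    training_iteration: int,
    initial_temp: float = 1.5,
    min_temp: float = 0.1,
    decay_rate: float = 0.99,
) -> float:
    # Exponentially anneal from `initial_temp` (more exploration early on)
    # down to a floor of `min_temp` (near-greedy sampling late in training).
    return max(min_temp, initial_temp * decay_rate**training_iteration)

# For example: iteration 0 -> 1.5, iteration 100 -> ~0.55, iteration 500 -> 0.1 (floor).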
        # to None, its default value.
        self.action_dist_cls = _make_categorical_with_temperature(
            self.model_config.get("action_dist_temperature", 1.0),
        )
@sven1977 I think, for the purpose of this PR, using this API still makes sense.
But I'd like to propose a (backward-compatible) change to RL Modules: `RLModule.get_inference_action_dist_cls` should become a getter `RLModule.inference_action_dist_cls` that looks like an attribute but does the same thing as today. If users then want to override it, they set that attribute in the `setup` method. Today, we have a mixture of attributes and these getter methods to modify RLModules.
That way, the default way to change all action distributions would be the `setup` method, while the old path of overriding `RLModule.get_inference_action_dist_cls` would still be available through overriding the `RLModule.inference_action_dist_cls` getter. So we would get to a state where users don't have to mix inheritance-based definition of components with `setup()`.
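A rough sketch of the proposed pattern (hypothetical, not current RLlib API; the `TorchCategorical` import path is assumed from RLlib's torch distributions):

from ray.rllib.models.torch.torch_distributions import TorchCategorical

class SketchRLModule:
    # Hypothetical stand-in for RLModule, just to illustrate the proposal.

    def setup(self):
        # Default path: users simply assign the attribute here ...
        self.action_dist_cls = None

    @property
    def inference_action_dist_cls(self):
        # ... or override this getter; it reads like an attribute but does
        # the same work as today's get_inference_action_dist_cls().
        return self.action_dist_cls or TorchCategorical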
Also CC @simonsays1980
I think we are almost there already. The default implementation of `get_inference_action_dist_cls` today is:

def get_inference_action_dist_cls(self) -> Type[TorchDistribution]:
    if self.action_dist_cls is not None:
        return self.action_dist_cls
    elif isinstance(self.action_space, gym.spaces.Discrete):
        return TorchCategorical
    elif isinstance(self.action_space, gym.spaces.Box):
        return TorchDiagGaussian
    else:
        raise ValueError(...)
Are you suggesting to just make the attributes more granular, like introducing `self.inference_action_dist_cls`, `self.exploration_action_dist_cls`, and `self.train_action_dist_cls`?
I'm not sure. Maybe this would complicate things and give users too many options.
Counter suggestion:
- We deprecate the option to set any dist-cls attribute. Everything has to be done through overriding methods.
- Analogous to overriding `_forward` vs. `_forward_[inference|exploration|train]`, we should introduce the methods `_get_action_dist_cls()` (for all cases) and `_get_action_dist_cls_inference()`, etc. (for the specific cases). By default, all the specific cases simply call the generic `_get_action_dist_cls()`, completely analogous to the behavior of the `_forward` methods. This way, if users just need one class, they override `_get_action_dist_cls`; if they need more granularity for some phases, they override the phase-specific methods.
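A sketch of that layout (hypothetical method names from the suggestion above, not current RLlib API):

from ray.rllib.models.torch.torch_distributions import TorchCategorical

class SketchModule:
    def _get_action_dist_cls(self):
        # Generic default, used by all phases unless overridden.
        return TorchCategorical

    # The phase-specific hooks default to the generic method, exactly like
    # _forward_[inference|exploration|train] default to _forward().
    def _get_action_dist_cls_inference(self):
        return self._get_action_dist_cls()

    def _get_action_dist_cls_exploration(self):
        return self._get_action_dist_cls()

    def _get_action_dist_cls_train(self):
        return self._get_action_dist_cls()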
        # your custom class(es) from these. In this case, leave self.action_dist_cls set
        # to None, its default value.
        self.action_dist_cls = _make_categorical_with_temperature(
            self.model_config.get("action_dist_temperature", 1.0),
Can we please not default to `1.0`, just to make this a bit safer? As it is now, this example would not fail if the user set `model_config["action_dist_temp"]` or some other wrong key of the model dict, making the user believe that the temperature has negligible impact, because the failure is silent.
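One way to make the failure loud, reusing the example's `_make_categorical_with_temperature` helper (a sketch; indexing the key directly raises a KeyError on a misspelled key instead of silently using 1.0):

self.action_dist_cls = _make_categorical_with_temperature(
    self.model_config["action_dist_temperature"],
)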
This is my only nit; the rest are just "thoughts for future PRs".
Good point! I generally agree that hidden defaults should be avoided. Will fix ...
done
this RLModule is subject to. Note that the observation space might not be the
exact space from your env, but that it might have already gone through
preprocessing through a connector pipeline (for example, flattening,
frame-stacking, mean/std-filtering, etc.).
Note: I think we should, at some point, disambiguate the word `observation_space` by changing it to `input_space` or something similar.
Have been thinking about this for some time as well. I think a contra-argument could be:
- In 99% of the cases, a sub-module within a MultiRLModule is some form of policy, mapping agent observations to agent actions.
- Yes, there are sometimes sub-modules in a MultiRLModule that are NOT policies, like a world model or a shared encoder. But even in these cases, they normally take observations as inputs, or - and that would still require `observation_space` information to be present - a combination of observations and (last n) rewards and (last n) actions.
- Yes, you could also have a sub-module that's some sort of head, getting its input from an intermediary embedding layer, but then in that case, I would think that the size of that embedding layer (probably some 1D tensor) would be given in `self.model_config`.
    @override(Distribution)
    def from_logits(cls, logits: TensorType, **kwargs) -> "TorchDistribution":
        return cls(logits=logits, **kwargs)
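This is what makes the `return cls(...)` pattern pay off: subclasses inherit `from_logits` as a correctly-typed factory. A small illustration (`TemperatureCategorical` is a hypothetical subclass, not part of the PR; the import path is assumed from RLlib's torch distributions):

import torch
from ray.rllib.models.torch.torch_distributions import TorchCategorical

class TemperatureCategorical(TorchCategorical):
    # Hypothetical subclass; inherits from_logits unchanged.
    pass

dist = TemperatureCategorical.from_logits(torch.randn(2, 4))
assert isinstance(dist, TemperatureCategorical)  # not merely TorchCategorical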
Nice!
Just one nit. Thanks!
Examples folder do-over (vol 52): Custom action distribution example
Why are these changes needed?
Related issue number
Checks
- I've signed off every commit (git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I've added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.