[RLlib; Offline RL] Implement Offline Policy Evaluation (OPE) via Importance Sampling. #53702
base: master
Changes from all commits
ea74ebd
30c9a8f
41665ed
2f43583
768397e
34f399c
9fb729c
d7f8eee
ee7f010
f2dd8cc
b0290d2
0e327f3
750c2ba
7e43fee
da66048
c54c296
3ae36e3
4d3168a
@@ -533,6 +533,8 @@ def __init__(self, algo_class: Optional[type] = None):
        # Offline evaluation.
        self.offline_evaluation_interval = None
        self.num_offline_eval_runners = 0
        self.offline_evaluation_type: str = None
        self.offline_eval_runner_class = None
        # TODO (simon): Only `_offline_evaluate_with_fixed_duration` works. Also,
        # decide, if we use `offline_evaluation_duration` or
        # `dataset_num_iters_per_offline_eval_runner`. Should the user decide here?
@@ -2705,6 +2707,8 @@ def evaluation(
        # Offline evaluation.
        offline_evaluation_interval: Optional[int] = NotProvided,
        num_offline_eval_runners: Optional[int] = NotProvided,
        offline_evaluation_type: Optional[str] = NotProvided,
        offline_eval_runner_class: Optional[Callable] = NotProvided,
        offline_loss_for_module_fn: Optional[Callable] = NotProvided,
        offline_eval_batch_size_per_runner: Optional[int] = NotProvided,
        dataset_num_iters_per_offline_eval_runner: Optional[int] = NotProvided,
@@ -2829,6 +2833,13 @@ def evaluation(
                for parallel evaluation. Setting this to 0 forces sampling to be done in the
                local OfflineEvaluationRunner (main process or the Algorithm's actor when
                using Tune).
            offline_evaluation_type: Type of offline evaluation to run. Either `"eval_loss"`
                for evaluating the validation loss of the policy, `"is"` for importance
                sampling, or `"pdis"` for per-decision importance sampling. If you want to
                implement your own offline evaluation method, write an `OfflineEvaluationRunner`
                and pass it via `AlgorithmConfig.offline_eval_runner_class`.
            offline_eval_runner_class: An `OfflineEvaluationRunner` class that implements
                custom offline evaluation logic.
            offline_loss_for_module_fn: A callable to compute the loss per `RLModule` in
                offline evaluation. If not provided, the training loss function (
                `Learner.compute_loss_for_module`) is used. The signature must be (

Question: So, if a user provides

This is a good one. Let me think about this. Both solutions have their advantages.
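For orientation, here is a minimal sketch of how the new arguments added in this hunk might be used together. The argument names (`offline_evaluation_interval`, `num_offline_eval_runners`, `offline_evaluation_type`) are taken from this diff; the choice of `BCConfig` and the dataset path are placeholder assumptions, not part of this PR:

from ray.rllib.algorithms.bc import BCConfig

# Hypothetical usage of the new offline-evaluation options (placeholder
# algorithm and dataset path).
config = (
    BCConfig()
    .offline_data(input_="local:///tmp/cartpole_offline_data")
    .evaluation(
        offline_evaluation_interval=1,   # run offline eval every training iteration
        num_offline_eval_runners=2,      # parallel OfflineEvaluationRunners
        offline_evaluation_type="is",    # "eval_loss", "is", or "pdis"
    )
)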
@@ -2975,6 +2986,10 @@ def evaluation(
            self.offline_evaluation_interval = offline_evaluation_interval
        if num_offline_eval_runners is not NotProvided:
            self.num_offline_eval_runners = num_offline_eval_runners
        if offline_evaluation_type is not NotProvided:
            self.offline_evaluation_type = offline_evaluation_type
        if offline_eval_runner_class is not NotProvided:
            self.offline_eval_runner_class = offline_eval_runner_class
        if offline_loss_for_module_fn is not NotProvided:
            self.offline_loss_for_module_fn = offline_loss_for_module_fn
        if offline_eval_batch_size_per_runner is not NotProvided:
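If a fully custom evaluation procedure is needed, `offline_eval_runner_class` accepts a subclass of `OfflineEvaluationRunner`. The abstract interface of that base class is not visible in this diff, so the outline below is purely hypothetical; in particular, the overridden `run()` method name is an assumption:

from ray.rllib.offline.offline_evaluation_runner import OfflineEvaluationRunner


class MyOfflineEvalRunner(OfflineEvaluationRunner):
    # Hypothetical custom runner: the method name `run()` is an assumption,
    # not taken from this diff.
    def run(self, **kwargs):
        # Custom OPE logic (e.g., a weighted or doubly-robust estimator)
        # would be implemented here.
        ...


config = config.evaluation(offline_eval_runner_class=MyOfflineEvalRunner)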
@@ -5282,6 +5297,33 @@ def _validate_offline_settings(self):
                "recorded episodes cannot be read in for training."
            )

        # Offline evaluation.
        from ray.rllib.offline.offline_policy_evaluation_runner import (
            OfflinePolicyEvaluationTypes,
        )

        offline_eval_types = list(OfflinePolicyEvaluationTypes)
        if (
            self.offline_evaluation_type
            and self.offline_evaluation_type != "eval_loss"
            and self.offline_evaluation_type not in OfflinePolicyEvaluationTypes
        ):
            self._value_error(
                f"Unknown offline evaluation type: {self.offline_evaluation_type}. "
                "Available types of offline evaluation are either `'eval_loss'` to evaluate "
                f"the training loss on a validation dataset or {offline_eval_types}."
            )

        from ray.rllib.offline.offline_evaluation_runner import OfflineEvaluationRunner

        if self.offline_eval_runner_class and not issubclass(
            self.offline_eval_runner_class, OfflineEvaluationRunner
        ):
            self._value_error(
                "Unknown `offline_eval_runner_class`. A custom runner class needs to "
                "inherit from the `OfflineEvaluationRunner` class."
            )

    @property
    def is_online(self) -> bool:
        """Defines if this config is for online RL.
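As background on what the `"is"` and `"pdis"` settings estimate, the standard trajectory-wise and per-decision importance-sampling OPE estimators can be sketched in a few lines of NumPy. This is illustrative only (it assumes per-episode arrays of behavior/target action log-probabilities and rewards) and is not the implementation added by this PR:

import numpy as np


def is_estimate(behavior_logp, target_logp, rewards, gamma=0.99):
    # Ordinary importance sampling: weight each episode's discounted return
    # by the product of per-step action-probability ratios.
    estimates = []
    for b_lp, t_lp, r in zip(behavior_logp, target_logp, rewards):
        rho = np.exp(np.sum(t_lp - b_lp))  # product of ratios, computed in log space
        ret = np.sum(r * gamma ** np.arange(len(r)))
        estimates.append(rho * ret)
    return float(np.mean(estimates))


def pdis_estimate(behavior_logp, target_logp, rewards, gamma=0.99):
    # Per-decision importance sampling: each reward is weighted only by the
    # ratios of actions taken up to and including its timestep, which usually
    # reduces the variance of the estimate.
    estimates = []
    for b_lp, t_lp, r in zip(behavior_logp, target_logp, rewards):
        rho_t = np.exp(np.cumsum(t_lp - b_lp))
        estimates.append(np.sum(gamma ** np.arange(len(r)) * rho_t * r))
    return float(np.mean(estimates))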
remove this comment?
Oh yeah! How did this even get in there?