Reward classifier and training #528
Conversation
Nice work @ChorntonYoel! Could you move the classifier directory to […]? Since now we will only use the reward classifier for […].
lerobot/common/policies/factory.py (outdated)
elif name == "classifier":
    from lerobot.common.policies.classifier.configuration_classifier import ClassifierConfig
    from lerobot.common.policies.classifier.modeling_classifier import Classifier

    return Classifier, ClassifierConfig
I think it's not ideal to put the classifier in the factory.py of policies. I think we can remove it and, instead of relying on make_policy in the training script, directly define the classifier there, since the training script of the classifier is not train.py.
What do you think?
Is this better now that the classifier has the policy name "hilserl/classifier"?
Or do you still think it's confusing and we should initialize it in a different way?
done
lerobot/scripts/train_classifier.py (outdated)
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
from lerobot.common.policies.factory import make_policy
We could remove make_policy and manually define it later.
Suggested change:
- from lerobot.common.policies.factory import make_policy
+ from lerobot.common.policies.classifier.configuration_classifier import ClassifierConfig
+ from lerobot.common.policies.classifier.modeling_classifier import Classifier
lerobot/scripts/train_classifier.py (outdated)
model = make_policy(
    hydra_cfg=cfg,
    dataset_stats=dataset.meta.stats if not cfg.resume else None,
    pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
).to(device)
We can define the classifier here:
Suggested change:
- model = make_policy(
-     hydra_cfg=cfg,
-     dataset_stats=dataset.meta.stats if not cfg.resume else None,
-     pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
- ).to(device)
+ from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
+ classifier_cfg = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
+ if not cfg.resume:
+     model = Classifier(classifier_cfg, dataset.meta.stats)
+ else:
+     model = Classifier(classifier_cfg)
+     model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
+ model = model.to(device)
Outdated, but would you still prefer I do that? I don't mind
Merged commit 6490927 into huggingface:user/michel-aractingi/2024-11-27-port-hil-serl
Co-authored-by: Daniel Ritchie <daniel@brainwavecollective.ai>
Co-authored-by: resolver101757 <kelster101757@hotmail.com>
Co-authored-by: Jannik Grothusen <56967823+J4nn1K@users.noreply.github.com>
Co-authored-by: Remi <re.cadene@gmail.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
What this does
This PR adds a reward classifier (used to classify whether an image of a robot performing a task should receive a reward or not), a training script for the classifier (with logging and resuming), a config.yaml file that can be used to start a training run, and a few tests for the training loop.
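For orientation, here is a stripped-down sketch of what the classifier training loop boils down to. It reuses the Classifier, ClassifierConfig, and LeRobotDataset names from the snippets in this thread, but the dataset id, the batch keys (observation.image, next.reward), the default-constructed config, and the model's forward signature are assumptions for illustration; the real train_classifier.py additionally handles the Hydra config, logging, checkpointing, and resuming.

import torch
from torch.utils.data import DataLoader

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.classifier.modeling_classifier import Classifier

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder dataset id: a small set of episodes labeled with the reward system of #518.
dataset = LeRobotDataset("user/reward_classifier_dataset")
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Constructor mirrors the suggestion above: config plus dataset stats for normalization.
model = Classifier(ClassifierConfig(), dataset.meta.stats).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for batch in loader:
    images = batch["observation.image"].to(device)    # assumed image key
    labels = batch["next.reward"].float().to(device)  # assumed binary reward label

    logits = model(images)  # assumed to return one logit per sample
    loss = torch.nn.functional.binary_cross_entropy_with_logits(logits.squeeze(-1), labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()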
How it was tested
Using 10 episodes recorded with the reward system from PR #518.
I also added a test file for the classifier training script. A lot of things are mocked, but I believe it covers the basics.
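As an illustration of the kind of mocked training-step test this refers to (not the PR's actual tests; the batch keys, shapes, and stand-in model are made up), a single step can be exercised without any pretrained backbone:

import torch
import torch.nn.functional as F


def test_training_step_runs():
    # Fake batch with assumed keys; shapes chosen arbitrarily for speed.
    batch = {
        "observation.image": torch.rand(4, 3, 64, 64),
        "next.reward": torch.tensor([0.0, 1.0, 1.0, 0.0]),
    }

    # Tiny stand-in for the real classifier so the test needs no downloaded weights.
    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 64 * 64, 1))

    logits = model(batch["observation.image"]).squeeze(-1)
    loss = F.binary_cross_entropy_with_logits(logits, batch["next.reward"])
    loss.backward()

    assert logits.shape == (4,)
    assert loss.item() >= 0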
How to checkout & try? (for the reviewer)
With the wandb entity and the dataset name adapted.
I was able to reproduce 95%+ accuracy after a few epochs with facebook/convnext-base-224 as backbone and a dataset of 10 episodes of ~15 sec.
This branch was built on top of the branch from #518, so it will need to wait for that one to be merged before merging.
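Once a checkpoint exists, using the classifier to turn a camera frame into a binary reward might look roughly like this. The from_pretrained call and class names come from the snippets above, but the checkpoint path, input shape, and forward/threshold interface are assumptions for illustration:

import torch

from lerobot.common.policies.classifier.modeling_classifier import Classifier

# Hypothetical checkpoint directory produced by train_classifier.py.
model = Classifier.from_pretrained("outputs/train_classifier/last/pretrained_model")
model.eval()

# Stand-in for a single camera frame, shape (C, H, W), values in [0, 1].
frame = torch.rand(3, 224, 224)

with torch.no_grad():
    logit = model(frame.unsqueeze(0))  # assumed to return a single "success" logit

# Reward is 1 when the classifier judges the task state as rewarded.
reward = int(torch.sigmoid(logit).item() > 0.5)
print(reward)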