Fix path resolution in training & inference by talmo · Pull Request #643 · talmolab/sleap · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Fix path resolution in training & inference #643

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,6 +1588,7 @@ def _show_learning_dialog(self, mode: str):
return

if self._child_windows.get(mode, None) is None:
# Re-use existing dialog widget.
self._child_windows[mode] = LearningDialog(
mode,
self.state["filename"],
Expand All @@ -1596,6 +1597,11 @@ def _show_learning_dialog(self, mode: str):
self._child_windows[mode]._handle_learning_finished.connect(
self._handle_learning_finished
)
else:
# Update data in existing dialog widget.
self._child_windows[mode].labels = self.labels
self._child_windows[mode].labels_filename = self.state["filename"]
self._child_windows[mode].skeleton = self.labels.skeleton

self._child_windows[mode].update_file_lists()

Expand Down
2 changes: 1 addition & 1 deletion sleap/nn/data/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def from_filename(
Returns:
A `LabelsReader` instance that can create a dataset for pipelining.
"""
labels = sleap.Labels.load_file(filename)
labels = sleap.load_file(filename)
if user_instances:
return cls.from_user_instances(labels)
else:
Expand Down
2 changes: 1 addition & 1 deletion sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2947,7 +2947,7 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
)

if data_path.endswith(".slp"):
labels = sleap.Labels.load_file(data_path)
labels = sleap.load_file(data_path)

if args.only_labeled_frames:
provider = LabelsReader.from_user_labeled_frames(labels)
Expand Down
11 changes: 7 additions & 4 deletions sleap/nn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def from_labels(
"""Create data readers from sleap.Labels datasets as data providers."""
if isinstance(training, str):
logger.info(f"Loading training labels from: {training}")
training = sleap.Labels.load_file(training, video_search=video_search_paths)
training = sleap.load_file(training, video_search=video_search_paths)

if labels_config is not None and labels_config.split_by_inds:
# First try to split by indices if specified in config.
Expand Down Expand Up @@ -192,7 +192,7 @@ def from_labels(
# If validation is still a path, load it.
logger.info(f"Loading validation labels from: {validation}")
validation = sleap.Labels.load_file(
validation, video_search=video_search_paths
validation, search_paths=video_search_paths
)
elif isinstance(validation, float):
logger.info(
Expand All @@ -217,7 +217,7 @@ def from_labels(
if isinstance(test, str):
# If test is still a path, load it.
logger.info(f"Loading test labels from: {test}")
test = sleap.Labels.load_file(test, video_search=video_search_paths)
test = sleap.load_file(test, search_paths=video_search_paths)

test_reader = None
if test is not None:
Expand Down Expand Up @@ -1568,6 +1568,9 @@ def main():
job_config.outputs.save_visualizations |= args.save_viz
if args.labels_path == "":
args.labels_path = None
args.video_paths = args.video_paths.split(",")
if len(args.video_paths) == 0:
args.video_paths = None

logger.info("Versions:")
sleap.versions()
Expand Down Expand Up @@ -1613,7 +1616,7 @@ def main():
training_labels=args.labels_path,
validation_labels=args.val_labels,
test_labels=args.test_labels,
video_search_paths=args.video_paths.split(","),
video_search_paths=args.video_paths,
)
trainer.train()

Expand Down
0