Training job generation tweaks by talmo · Pull Request #642 · talmolab/sleap · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Training job generation tweaks #642

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sleap/gui/learning/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ def export_package(self, output_path: Optional[str] = None, gui: bool = True):
self.accept()

if gui:
msgBox = QtWidgets.QMessageBox(text=f"Created training job package:")
msgBox = QtWidgets.QMessageBox(text=f"Created training job package.")
msgBox.setDetailedText(output_path)
msgBox.setWindowTitle("Training Job Package")
okButton = msgBox.addButton(QtWidgets.QMessageBox.Ok)
Expand Down
66 changes: 54 additions & 12 deletions sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import tempfile
import time
import shutil
import yaml
from pathlib import Path
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Text, Tuple

Expand Down Expand Up @@ -358,9 +360,10 @@ def write_pipeline_files(
# Update config.
cfg_info.config.outputs.run_name_suffix = suffix

training_jobs = []
for cfg_info in config_info_list:
if cfg_info.dont_retrain:
# Use full absolute path to already training model
# Use full absolute path to already trained model
trained_path = os.path.normpath(os.path.join(old_cwd, cfg_info.path))
new_cfg_filenames.append(trained_path)

Expand All @@ -373,8 +376,13 @@ def write_pipeline_files(
# Note that setup_new_run_folder does things relative to cwd which
# is the main reason we're setting it to the output directory rather
# than just using normpath.
cfg_info.config.outputs.runs_folder = ""
# cfg_info.config.outputs.runs_folder = ""
training.setup_new_run_folder(cfg_info.config.outputs)
# training.setup_new_run_folder(
# cfg_info.config.outputs,
# # base_run_name=f"{model_type}.n={len(labels.user_labeled_frames)}",
# base_run_name=cfg_info.head_name,
# )

# Now we set the filename for the training config file
new_cfg_filename = f"{cfg_info.head_name}.json"
Expand All @@ -390,6 +398,15 @@ def write_pipeline_files(
f"sleap-train {new_cfg_filename} {os.path.basename(labels_filename)}\n"
)

# Setup job params
training_jobs.append(
{
"cfg": new_cfg_filename,
"run_path": Path(cfg_info.config.outputs.run_path).as_posix(),
"train_labels": os.path.basename(labels_filename),
}
)

# Write the script to train the models which need to be trained
with open(os.path.join(output_dir, "train-script.sh"), "w") as f:
f.write(train_script)
Expand All @@ -404,11 +421,15 @@ def write_pipeline_files(
inference_params=inference_params,
)

inference_jobs = []
for item_for_inference in items_for_inference.items:
if type(item_for_inference) == DatasetItemForInference:
data_path = labels_filename
else:
data_path = item_for_inference.path

# We want to save predictions in output dir so use local path
prediction_output_path = (
f"{os.path.basename(item_for_inference.path)}.predictions.slp"
)
prediction_output_path = f"{os.path.basename(data_path)}.predictions.slp"

# Use absolute path to video
item_for_inference.use_absolute_path = True
Expand All @@ -421,10 +442,37 @@ def write_pipeline_files(
# And join them into a single call to inference
inference_script += " ".join(cli_args) + "\n"

# Setup job params
only_suggested_frames = False
if type(item_for_inference) == DatasetItemForInference:
only_suggested_frames = item_for_inference.frame_filter == "suggested"

# TODO: support frame ranges, user-labeled frames
tracking_args = {
k: v for k, v in inference_params.items() if k.startswith("tracking.")
}
inference_jobs.append(
{
"data_path": os.path.basename(data_path),
"models": [Path(p).as_posix() for p in new_cfg_filenames],
"output_path": prediction_output_path,
"type": "labels"
if type(item_for_inference) == DatasetItemForInference
else "video",
"only_suggested_frames": only_suggested_frames,
"tracking": tracking_args,
}
)

# And write it
with open(os.path.join(output_dir, "inference-script.sh"), "w") as f:
f.write(inference_script)

# Save jobs.yaml
jobs = {"training": training_jobs, "inference": inference_jobs}
with open(os.path.join(output_dir, "jobs.yaml"), "w") as f:
yaml.dump(jobs, f)

# Restore the working directory
os.chdir(old_cwd)

Expand Down Expand Up @@ -745,14 +793,8 @@ def train_subprocess(
if job_config.outputs.tensorboard.write_logs:
cli_args.append("--tensorboard")

# Add list of video paths so we can find video even if paths in saved
# labels dataset file are incorrect.
if video_paths:
cli_args.extend(("--video-paths", ",".join(video_paths)))

print(cli_args)

# Run training in a subprocess.
print(cli_args)
proc = subprocess.Popen(cli_args)

# Wait till training is done, calling a callback if given.
Expand Down
0