Monitor node losses across training by gitttt-1234 · Pull Request #199 · talmolab/sleap-nn · GitHub
Monitor node losses across training #199

Open
wants to merge 13 commits into base: divya/slumbr_v1_center_padding
13 changes: 12 additions & 1 deletion docs/config_slumbr.yaml
@@ -60,6 +60,16 @@ model_config:
pretrained_backbone_weights:
pretrained_head_weights:
backbone_type: unet
dataset_loss_weights:
0: 1.0
1: 1.0
2: 1.0
keypoint_mining:
online_mining: False
hard_to_easy_ratio: 2.0
min_hard_keypoints: 2
max_hard_keypoints:
loss_scale:
backbone_config:
in_channels: 1
kernel_size: 3
@@ -75,7 +85,8 @@ model_config:
single_instance:
centroid:
bottomup:
centered_instance:
centered_instance:
fitbbox: False # False: centroid crop, True: crop with fit bbox
confmaps:
0:
part_names:
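The `dataset_loss_weights` and `keypoint_mining` keys added above are consumed by the multi-head Lightning modules further down. A minimal sketch of reading them with OmegaConf, assuming this config file path and the top-level `dataset_mapper` key that the training code references, is:

from omegaconf import OmegaConf

config = OmegaConf.load("docs/config_slumbr.yaml")

# Per-dataset loss weights; fall back to 1.0 for every dataset in the mapper.
dataset_loss_weights = OmegaConf.select(
    config, "model_config.dataset_loss_weights", default=None
) or {k: 1.0 for k in config.dataset_mapper}

# Online hard keypoint mining is off unless explicitly enabled.
online_mining = OmegaConf.select(
    config, "model_config.keypoint_mining.online_mining", default=False
)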
8 changes: 6 additions & 2 deletions sleap_nn/inference/topdown.py
@@ -118,6 +118,7 @@ def _generate_crops(self, inputs):
inputs["eff_scale"],
inputs["pad_shifts"],
):
image = image[0] # source imgs: nested tensor

# size matcher
max_h = self.preprocess_config.centered_max_height
@@ -200,6 +201,8 @@ def _generate_fitbbox_crops(self, inputs):
inputs["instances"],
):

image = image[0] # source imgs: nested tensor

# adjust for initial size matching in preprocessing

instances = (
@@ -220,8 +223,6 @@
image = resize_image(image, self.precrop_resize)
instances = instances * self.precrop_resize

n = centroid.shape[0]

# get max bbox size for this batch
max_crop_size = self.preprocess_config.max_crop_size

@@ -256,6 +257,7 @@
]
),
)

bbox_shifts.append(bbox[:2].unsqueeze(dim=0))

cropped_image_match_hw, eff_scale, pad_wh = apply_sizematcher(
@@ -265,6 +267,8 @@
eff_scale_crops.append(eff_scale)
padding_shifts_crops.append(torch.Tensor(pad_wh).unsqueeze(dim=0))

n = len(instance_images)

ex = {}
ex["image"] = torch.cat([image] * n)
ex["centroid_val"] = centroid_val
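The changes above unpack the source frame from the nested-tensor batch (`image = image[0]`) and compute the replication count from the gathered instance crops (`n = len(instance_images)`) instead of from the centroid tensor. A minimal sketch of that replication step, assuming one source frame of shape (1, C, H, W) and a list of per-instance crops (the helper name is illustrative):

import torch

def replicate_frame_per_crop(image, instance_images):
    """Repeat a single source frame once per generated instance crop."""
    n = len(instance_images)  # number of crops produced for this frame
    return {
        "image": torch.cat([image] * n),               # (n, C, H, W)
        "instance_image": torch.cat(instance_images),  # stacked crops
    }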
233 changes: 218 additions & 15 deletions sleap_nn/training/lightning_modules.py
@@ -638,6 +638,11 @@ def __init__(
model_type=self.model_type,
)

self.dataset_loss_weights = OmegaConf.select(
self.config,
"model_config.dataset_loss_weights",
default={k: 1.0 for k in self.config.dataset_mapper},
)

if (
len(self.model_config.head_configs[self.model_type]) > 1
): # TODO: online mining for each dataset
@@ -826,8 +831,14 @@ def __init__(
torch_model=self.forward,
peak_threshold=0.2,
return_confmaps=True,
centered_fitbbox=True,
centered_fitbbox=False,
)
self.part_names = {}
for (
d_num,
cfg,
) in self.config.model_config.head_configs.centered_instance.confmaps.items():
self.part_names[d_num] = cfg.part_names

def on_train_epoch_start(self):
"""Configure the train timer at the beginning of each epoch."""
@@ -846,10 +857,6 @@ def on_train_epoch_start(self):
sample["pad_shifts"] = torch.zeros(
(sample["video_idx"].shape[0], 2)
)
sample["eff_scale_crops"] = torch.ones(sample["video_idx"].shape)
sample["padding_shifts_crops"] = torch.zeros(
(sample["video_idx"].shape[0], 2)
)
for k, v in sample.items():
sample[k] = v.to(device=self.device)
self.inf_layer.output_stride = self.config.model_config.head_configs.centered_instance.confmaps[
@@ -937,7 +944,53 @@ def training_step(self, batch, batch_idx):
output[h_num] = output[h_num].detach()

y_preds = output[d_num]
curr_loss = 1.0 * self.loss_func(y_preds, y)

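# Log the inverse of each node's (channel's) confidence-map MSE per dataset,
# so higher values indicate better-localized keypoints during training.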
for c in range(y.shape[-3]):
l = self.loss_func(y_preds[..., c, :, :], y[..., c, :, :])
self.log(
f"node_inv_losses_dataset:{d_num}_node:`{self.part_names[d_num][c]}`",
1 / l,
prog_bar=True,
on_step=False,
on_epoch=True,
logger=True,
)

if OmegaConf.select(
self.config, "model_config.keypoint_mining.online_mining", default=False
):
mse_loss = (y_preds - y) ** 2
batch_shape = mse_loss.shape
l = torch.sum(mse_loss, dim=(0, 2, 3))
best_loss = torch.min(l)
is_hard_keypoint = (
l / best_loss
) >= self.config.model_config.keypoint_mining.hard_to_easy_ratio
n_hard_keypoints = torch.sum(is_hard_keypoint.to(torch.int32))
if self.config.model_config.keypoint_mining.max_hard_keypoints < 0:
max_hard_keypoints = l.shape[0]
else:
max_hard_keypoints = min(
self.config.model_config.keypoint_mining.max_hard_keypoints,
l.shape[0],
)

k = min(
max(
n_hard_keypoints,
self.config.model_config.keypoint_mining.min_hard_keypoints,
),
max_hard_keypoints,
)
k_vals, k_inds = torch.topk(l, k=k, largest=True, sorted=False)
k_loss = k_vals * self.config.model_config.keypoint_mining.loss_scale
n_elements = batch_shape[0] * batch_shape[2] * batch_shape[3] * k
k_loss = torch.sum(k_loss) / n_elements
mse_loss = k_loss
else:
mse_loss = self.loss_func(y_preds, y)

curr_loss = self.dataset_loss_weights[d_num] * mse_loss
loss += curr_loss

self.manual_backward(curr_loss, retain_graph=True)
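For reference, the online hard keypoint mining branch above reduces to the following standalone computation. This is a sketch, not the module's API: it assumes confidence maps of shape (batch, channels, height, width) and the `hard_to_easy_ratio`, `min_hard_keypoints`, `max_hard_keypoints`, and `loss_scale` values from `model_config.keypoint_mining`; the function name is illustrative.

import torch

def ohkm_loss(y_preds, y, hard_to_easy_ratio=2.0, min_hard_keypoints=2,
              max_hard_keypoints=-1, loss_scale=1.0):
    """Average the MSE over only the hardest keypoint channels."""
    mse = (y_preds - y) ** 2                     # (B, C, H, W)
    per_channel = torch.sum(mse, dim=(0, 2, 3))  # summed loss per keypoint channel
    best = torch.min(per_channel)
    n_hard = int(torch.sum((per_channel / best) >= hard_to_easy_ratio))
    if max_hard_keypoints < 0:
        cap = per_channel.shape[0]
    else:
        cap = min(max_hard_keypoints, per_channel.shape[0])
    k = min(max(n_hard, min_hard_keypoints), cap)
    k_vals, _ = torch.topk(per_channel, k=k, largest=True, sorted=False)
    n_elements = mse.shape[0] * mse.shape[2] * mse.shape[3] * k
    return torch.sum(k_vals * loss_scale) / n_elements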
@@ -973,7 +1026,7 @@ def validation_step(self, batch, batch_idx):
), torch.squeeze(batch[d_num]["confidence_maps"], dim=1)

y_preds = self.model(X)["CenteredInstanceConfmapsHead"][d_num]
curr_loss = 1.0 * nn.MSELoss()(y_preds, y)
curr_loss = self.dataset_loss_weights[d_num] * nn.MSELoss()(y_preds, y)
total_loss += curr_loss

self.log(
@@ -1005,6 +1058,120 @@ def validation_step(self, batch, batch_idx):
)


class TopDownCenteredInstanceFitbboxMultiHeadLightningModule(
TopDownCenteredInstanceMultiHeadLightningModule
):
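"""Lightning Module for TopDownCenteredInstanceFitbboxMultiHeadLightningModule Model."""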
def __init__(
self,
config: OmegaConf,
model_type: str,
backbone_type: str,
):
"""Initialise the configs and the model."""
super().__init__(
config=config,
backbone_type=backbone_type,
model_type=model_type,
)
self.inf_layer = FindInstancePeaks(
torch_model=self.forward,
peak_threshold=0.2,
return_confmaps=True,
centered_fitbbox=True,
)
self.part_names = {}
for (
d_num,
cfg,
) in self.config.model_config.head_configs.centered_instance.confmaps.items():
self.part_names[d_num] = cfg.part_names

def on_train_epoch_start(self):
"""Configure the train timer at the beginning of each epoch."""
# add eval
if self.config.trainer_config.log_inf_epochs is not None:
if (
self.current_epoch > 0
and self.global_rank == 0
and (self.current_epoch % self.config.trainer_config.log_inf_epochs)
== 0
):
img_array = []
for d_num in self.config.dataset_mapper:
sample = next(iter(self.trainer.val_dataloaders[d_num]))
sample["eff_scale"] = torch.ones(sample["video_idx"].shape)
sample["pad_shifts"] = torch.zeros(
(sample["video_idx"].shape[0], 2)
)
# for fit bbox cropping
sample["eff_scale_crops"] = torch.ones(sample["video_idx"].shape)
sample["padding_shifts_crops"] = torch.zeros(
(sample["video_idx"].shape[0], 2)
)
for k, v in sample.items():
sample[k] = v.to(device=self.device)
self.inf_layer.output_stride = self.config.model_config.head_configs.centered_instance.confmaps[
d_num
][
"output_stride"
]
output = self.inf_layer(sample, output_head_skeleton_num=d_num)
batch_idx = 0

# plot predictions on sample image
if self.use_wandb or self.save_ckpt:
peaks = output["pred_instance_peaks"][batch_idx].cpu().numpy()
gt_instances = sample["instance"][batch_idx, 0].cpu().numpy()
img = output["instance_image"][batch_idx, 0].cpu().numpy()
confmaps = output["pred_confmaps"][batch_idx].cpu().numpy()
fig = plot_pred_confmaps_peaks(
img=img,
confmaps=confmaps,
peaks=np.expand_dims(peaks, axis=0),
gt_instances=np.expand_dims(gt_instances, axis=0),
plot_title=f"{self.config.dataset_mapper[d_num]}",
)

if self.save_ckpt:
curr_results_path = (
Path(self.config.trainer_config.save_ckpt_path)
/ "visualizations"
/ f"epoch_{self.current_epoch}"
)
if not Path(curr_results_path).exists():
Path(curr_results_path).mkdir(parents=True, exist_ok=True)
fig.savefig(
(Path(curr_results_path) / f"pred_on_{d_num}").as_posix(),
bbox_inches="tight",
)

if self.use_wandb:
fig.canvas.draw()
img = Image.frombytes(
"RGB",
fig.canvas.get_width_height(),
fig.canvas.tostring_rgb(),
)

img_array.append(wandb.Image(img))

plt.close(fig)

if self.use_wandb and img_array:
# wandb logging metrics in table

wandb_table = wandb.Table(
columns=[
"epoch",
"Predictions on test set",
],
data=[[self.current_epoch, img_array]],
)
wandb.log({"Performance": wandb_table})

self.train_start_time = time.time()
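Relative to the parent class, this subclass enables `centered_fitbbox=True` in the inference layer and supplies the extra `eff_scale_crops` / `padding_shifts_crops` keys that fit-bbox cropping expects. As background for the `fitbbox` config flag (False: centroid crop, True: crop with fit bbox), a fit-bbox crop is derived from the instance keypoints rather than a fixed-size window around the centroid; a minimal sketch, with an illustrative helper name and padding value, is:

import torch

def fit_bbox(points, pad=16.0):
    """Return (x1, y1, x2, y2) tightly enclosing the visible keypoints, plus padding."""
    valid = points[~torch.isnan(points).any(dim=-1)]  # drop NaN (invisible) nodes
    mins = valid.min(dim=0).values - pad
    maxs = valid.max(dim=0).values + pad
    return torch.cat([mins, maxs])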


class SingleInstanceMultiHeadLightningModule(MultiHeadLightningModule):
"""Lightning Module for SingleInstanceMultiHeadLightningModule Model.

@@ -1143,7 +1310,7 @@ def training_step(self, batch, batch_idx):
output[h_num] = output[h_num].detach()

y_preds = output[d_num]
curr_loss = 1.0 * self.loss_func(y_preds, y)
curr_loss = self.dataset_loss_weights[d_num] * self.loss_func(y_preds, y)
loss += curr_loss

self.manual_backward(curr_loss, retain_graph=True)
@@ -1179,7 +1346,7 @@ def validation_step(self, batch, batch_idx):
), torch.squeeze(batch[d_num]["confidence_maps"], dim=1)

y_preds = self.model(X)["SingleInstanceConfmapsHead"][d_num]
curr_loss = 1.0 * nn.MSELoss()(y_preds, y)
curr_loss = self.dataset_loss_weights[d_num] * nn.MSELoss()(y_preds, y)
total_loss += curr_loss

self.log(
@@ -1337,6 +1504,8 @@ def training_step(self, batch, batch_idx):
loss = 0
opt = self.optimizers()
opt.zero_grad()

dataset_losses = {}
for d_num in batch.keys():
batch_data = batch[d_num]
X, y = torch.squeeze(batch_data["image"], dim=1).to(
@@ -1348,10 +1517,44 @@
output = self.model(X)["CentroidConfmapsHead"]

y_preds = output[0]
curr_loss = 1.0 * self.loss_func(y_preds, y)
loss += curr_loss
curr_loss = self.dataset_loss_weights[d_num] * self.loss_func(y_preds, y)
dataset_losses[d_num] = curr_loss

self.manual_backward(curr_loss, retain_graph=True)
self.log(
f"train_loss_on_head_{d_num}",
curr_loss,
prog_bar=True,
on_step=False,
on_epoch=True,
logger=True,
)

# compute dynamic loss weights for each dataset
with torch.no_grad():
total_loss = sum(l.detach() for l in dataset_losses.values())
dynamic_weights = {
d_num: (dataset_losses[d_num].detach() / total_loss).clamp(
min=1e-4
) # avoid zero
for d_num in dataset_losses
}

# apply weights and compute total loss
for d_num, d_loss in dataset_losses.items():
weighted_loss = dynamic_weights[d_num] * d_loss

self.manual_backward(weighted_loss, retain_graph=True)

loss += weighted_loss

self.log(
f"dynamic_weights_for_head_{d_num}",
dynamic_weights[d_num],
prog_bar=True,
on_step=False,
on_epoch=True,
logger=True,
)
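The dynamic weighting above scales each dataset's loss by its share of the total detached loss across datasets (clamped away from zero), so heads that are currently harder contribute proportionally more gradient. A standalone sketch of that computation, with an illustrative function name, is:

import torch

def dynamic_dataset_weights(dataset_losses):
    """Weight each dataset by its fraction of the total (detached) loss."""
    total = sum(l.detach() for l in dataset_losses.values())
    return {d: (l.detach() / total).clamp(min=1e-4) for d, l in dataset_losses.items()}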

self.log(
f"train_loss",
Expand All @@ -1377,7 +1580,7 @@ def validation_step(self, batch, batch_idx):
)

y_preds = self.model(X)["CentroidConfmapsHead"][0]
curr_loss = 1.0 * nn.MSELoss()(y_preds, y)
curr_loss = self.dataset_loss_weights[d_num] * nn.MSELoss()(y_preds, y)
total_loss += curr_loss

self.log(
@@ -1615,7 +1818,7 @@ def training_step(self, batch, batch_idx):
),
"PartAffinityFieldsHead": nn.MSELoss()(output_pafs[d_num], y_paf),
}
curr_loss = 1.0 * sum(
curr_loss = self.dataset_loss_weights[d_num] * sum(
[s * losses[t] for s, t in zip(self.loss_weights, losses)]
)

@@ -1667,7 +1870,7 @@ def validation_step(self, batch, batch_idx):
"PartAffinityFieldsHead": nn.MSELoss()(output_pafs[d_num], y_paf),
}

curr_loss = 1.0 * sum(
curr_loss = self.dataset_loss_weights[d_num] * sum(
[s * losses[t] for s, t in zip(self.loss_weights, losses)]
)
total_loss += curr_loss