From 5acbe6e2892b1fcf0683807b5d761971b498fea9 Mon Sep 17 00:00:00 2001
From: Yang Hong
Date: Wed, 30 Mar 2022 17:10:46 +0800
Subject: [PATCH 1/5] apply transformation to flow grid and flow value

---
 model/augment_consist_loss.py | 195 ++++++++++++++++++++++++----------
 1 file changed, 140 insertions(+), 55 deletions(-)

diff --git a/model/augment_consist_loss.py b/model/augment_consist_loss.py
index 0464d06..65f4933 100644
--- a/model/augment_consist_loss.py
+++ b/model/augment_consist_loss.py
@@ -1,11 +1,9 @@
-from cv2 import ROTATE_90_CLOCKWISE
-from numpy import imag
 import torch
+import cv2
 import random
 import numpy as np
 import torch.nn.functional as F
 from torchvision.transforms.functional import hflip, vflip, rotate
-import cv2


 def visualize_flow(flow, save_name):
@@ -24,7 +22,7 @@ def visualize_flow(flow, save_name):


 # img0, img1, gt are 4D tensors of (B, 3, 256, 448). gt are the middle frames.
-def random_shift(img0, img1, gt, shift_sigmas=(16, 10)):
+def random_shift(img0, img1, gt, flow, flow_teacher, shift_sigmas=(16, 10)):
     B, C, H, W = img0.shape
     u_shift_sigma, v_shift_sigma = shift_sigmas
     # 90% of dx and dy are within [-2*u_shift_sigma, 2*u_shift_sigma]
@@ -123,10 +121,17 @@ def random_shift(img0, img1, gt, shift_sigmas=(16, 10)):
     # mask for the middle frame. Both directions have the same mask.
     mask = torch.zeros(mask_shape, device=img0.device, dtype=bool)
     mask[:, :, TM:BM, LM:RM] = True
-    return img0a, img1a, gta, mask, dxy
+
+    # merge flow_handler with aug_handler
+    # s enumerates all scales.
+    flow_a = []
+    for s in np.arange(len(flow)):
+        flow_a.append(flow[s] + dxy)
+    flow_teacher_a = flow_teacher + dxy
+    return img0a, img1a, gta, flow_a, flow_teacher_a, mask, dxy


-def _hflip(img0, img1, gt):
+def _hflip(img0, img1, gt, flow, flow_teacher):
     # B, C, H, W
     img0a = hflip(img0.clone())
     img1a = hflip(img1.clone())
@@ -136,14 +141,18 @@ def _hflip(img0, img1, gt):
     mask = torch.ones(mask_shape, device=img0.device, dtype=bool)
     sxy = torch.tensor([ -1, 1, -1, 1], dtype=float, device=img0.device)
     sxy = sxy.view(1, 4, 1, 1)
-    # temp0 = img0a[0].permute(1, 2, 0).cpu().numpy()
-    # cv2.imwrite('flip0.png', temp0*255)
-    # temp1 = img0[0].permute(1, 2, 0).cpu().numpy()
-    # cv2.imwrite('0.png', temp1*255)
-    return img0a, img1a, gta, mask, sxy
+    # merge flow_handler with aug_handler
+    # s enumerates all scales.
+    flow_a = []
+    for s in np.arange(len(flow)):
+        temp = hflip(flow[s])
+        flow_a.append(temp * sxy)
+    temp = hflip(flow_teacher)
+    flow_teacher_a = temp * sxy
+    return img0a, img1a, gta, flow_a, flow_teacher_a, mask


-def _vflip(img0, img1, gt):
+def _vflip(img0, img1, gt, flow, flow_teacher):
     # B, C, H, W
     img0a = vflip(img0.clone())
     img1a = vflip(img1.clone())
@@ -153,17 +162,42 @@ def _vflip(img0, img1, gt):
     mask = torch.ones(mask_shape, device=img0.device, dtype=bool)
     sxy = torch.tensor([ 1, -1, 1, -1], dtype=float, device=img0.device)
     sxy = sxy.view(1, 4, 1, 1)
-    return img0a, img1a, gta, mask, sxy
+    # merge flow_handler with aug_handler
+    # s enumerates all scales.
+    flow_a = []
+    for s in np.arange(len(flow)):
+        temp = vflip(flow[s])
+        flow_a.append(temp * sxy)
+    temp = vflip(flow_teacher)
+    flow_teacher_a = temp * sxy
+    return img0a, img1a, gta, flow_a, flow_teacher_a, mask


-def random_flip(img0, img1, gt, shift_sigmas=None):
+def random_flip(img0, img1, gt, flow, flow_teacher, shift_sigmas=None):
     if random.random() > 0.5:
-        img0a, img1a, gta, smask, sxy = _hflip(img0, img1, gt)
+        img0a, img1a, gta, flow_a, flow_teacher_a, smask = _hflip(img0, img1, gt, flow, flow_teacher)
     else:
-        img0a, img1a, gta, smask, sxy = _vflip(img0, img1, gt)
+        img0a, img1a, gta, flow_a, flow_teacher_a, smask = _vflip(img0, img1, gt, flow, flow_teacher)
-    return img0a, img1a, gta, smask, sxy
+    return img0a, img1a, gta, flow_a, flow_teacher_a, smask, 0
+
+
+def rotater(flow, R):
+    """
+    flow: B, C, H, W (16, 4, 224, 224) tensor
+    R: (2, 2) rotation matrix, tensor
+    """
+    flow_fst, flow_sec = torch.split(flow, 2, dim=1)
+    # flow map left multiply by rotation matrix R
+    flow_fst_rot = torch.einsum('jc, bjhw -> bchw', R, flow_fst)
+    flow_sec_rot = torch.einsum('jc, bjhw -> bchw', R, flow_sec)
+    flow_rot = torch.cat((flow_fst_rot, flow_sec_rot), dim=1)
+    # visualize_flow(flow_fst[0].permute(1, 2, 0), 'flow.png')
+    # visualize_flow(flow_fst_rot[0].permute(1, 2, 0), 'flow_rotate.png')
+    # print(flow_fst[0, :, 112, 112])
+    # print(flow_fst_rot[0, :, 112, 112])
+    return flow_rot


-def random_rotate(img0, img1, gt, shift_sigmas=None):
+def random_rotate(img0, img1, gt, flow, flow_teacher, shift_sigmas=None):
     if random.random() < 1/3.:
         angle = 90
     elif random.random() < 2/3.:
         angle = 180
     else:
@@ -183,48 +217,99 @@ def random_rotate(img0, img1, gt, shift_sigmas=None):
     mask_shape = list(img0.shape)
     mask_shape[1] = 4   # For 4 flow channels of two directions (2 for each direction).
     mask = torch.ones(mask_shape, device=img0.device, dtype=bool)
-    return img0a, img1a, gta, mask, torch.from_numpy(R).to(img0.device)
-
+    R = torch.from_numpy(R).to(img0.device)
+    # merge flow_handler with aug_handler
+    # s enumerates all scales.
+    flow_a = []
+    for s in np.arange(len(flow)):
+        temp = rotate(flow[s], angle=angle)
+        temp = rotater(temp, R)
+        flow_a.append(temp)
+    flow_teacher_a = rotate(flow_teacher, angle=angle)
+    flow_teacher_a = rotater(flow_teacher_a, R)
+    return img0a, img1a, gta, flow_a, flow_teacher_a, mask, 0

-def adder(a, b):
-    return a + b

-def multiplier(a, b):
-    return a * b

+def polygons_to_mask(polys, height, width):
+    """
+    Convert polygons to binary masks.
+    Args:
+        polys: a list of nx2 float array. Each array contains many (x, y) coordinates.
+    Returns:
+        a binary matrix of (height, width)
+    """
+    polys = [p.flatten().tolist() for p in polys]
+    assert len(polys) > 0, "Polygons are empty!"
-def rotater(flow, R):
-    # flow: B, C, H, W (16, 4, 224, 224) tensor
-    # R: (2, 2) rotation matrix, tensor
-    flow_fst, flow_sec = torch.split(flow, 2, dim=1)
-    # flow map left multiply by rotation matrix R
-    flow_fst_rot = torch.einsum('jc, bjhw -> bchw', R, flow_fst)
-    flow_sec_rot = torch.einsum('jc, bjhw -> bchw', R, flow_sec)
-    flow_rot = torch.cat((flow_fst_rot, flow_sec_rot), dim=1)
-    # visualize_flow(flow_fst[0].permute(1, 2, 0), 'flow.png')
-    # visualize_flow(flow_fst_rot[0].permute(1, 2, 0), 'flow_rotate_90.png')
-    # print(flow_fst[0, :, 0, 0])
-    # print(flow_fst_rot[0, :, 0, 0])
-    return flow_rot
+    import pycocotools.mask as cocomask
+    rles = cocomask.frPyObjects(polys, height, width)
+    rle = cocomask.merge(rles)
+    return cocomask.decode(rle).astype(bool)


-def calculate_consist_loss(img0, img1, gt, flow, flow_teacher, model, shift_sigmas, aug_handler, flow_handler):
-    img0a, img1a, gta, smask, dxy = aug_handler(img0, img1, gt, shift_sigmas)
+def random_affine(img0, img1, gt, flow, flow_teacher, shift_sigmas=None):
+    B, C, H, W = img0.shape
+    # OpenCV uses 3 points to generate an affine matrix.
+    # https://docs.opencv.org/3.4/d4/d61/tutorial_warp_affine.html
+    W_shift_ratio_low1 = np.random.uniform(0, 0.3)
+    W_shift_ratio_low2 = np.random.uniform(0, 0.3)
+    W_shift_ratio_high = np.random.uniform(0.7, 1)
+    H_shift_ratio_low1 = np.random.uniform(0, 0.3)
+    H_shift_ratio_low2 = np.random.uniform(0, 0.3)
+    H_shift_ratio_high = np.random.uniform(0.7, 1)
+    srcTri = np.array([[0, 0], [W - 1, 0], [0, H - 1]]).astype(np.float32)   # 4th pt: [W-1, H-1]
+    dstTri = np.array([[W*W_shift_ratio_low1, H*H_shift_ratio_low1], \
+                       [W*W_shift_ratio_high, H*H_shift_ratio_low2], \
+                       [W*W_shift_ratio_low2, H*H_shift_ratio_high]]).astype(np.float32)
+    warp_mat = cv2.getAffineTransform(srcTri, dstTri).astype(np.float32)   # (2, 3) affine transformation matrix
+    # transform 4 vertices to get new vertices of affine transformed image
+    # new vertices are used to generate mask
+    fourth_pt = np.matmul(warp_mat, np.array([W-1, H-1, 1]))
+    fourth_pt[0] = np.clip(fourth_pt[0], 0, W)
+    fourth_pt[1] = np.clip(fourth_pt[1], 0, H)
+    polygon = np.ndarray((4, 2), dtype=np.float32)
+    polygon[:2, :] = dstTri[:2, :]
+    polygon[2, :] = fourth_pt
+    polygon[3, :] = dstTri[-1, :]
+    mask = polygons_to_mask([np.array(polygon, np.float32)], H, W)
+    # cv2.imwrite('mask.png', mask.astype(float)*255)
+    mask = torch.from_numpy(mask).to(img0.device).view(1, 1, H, W)
+    mask = mask.repeat(B, 4, 1, 1)

-    if dxy is not None:
-        imgsa = torch.cat((img0a, img1a), 1)
-        flow2, mask2, merged_img_list2, flow_teacher2, merged_teacher2, loss_distill2 = model(torch.cat((imgsa, gta), 1), scale_list=[4, 2, 1])
-        loss_consist_stu = 0
-        # s enumerates all scales.
-        loss_on_scales = np.arange(len(flow))
-        for s in loss_on_scales:
-            loss_consist_stu += torch.abs(flow_handler(flow[s].clone(), dxy) - flow2[s])[smask].mean()
+    aff0 = np.empty((B, H, W, C), dtype=np.float32)
+    aff1 = np.empty((B, H, W, C), dtype=np.float32)
+    aff2 = np.empty((B, H, W, C), dtype=np.float32)
+    img0_copy = img0.permute(0, 2, 3, 1).cpu().numpy()   # B, H, W, C
+    img1_copy = img1.permute(0, 2, 3, 1).cpu().numpy()
+    gt_copy = gt.permute(0, 2, 3, 1).cpu().numpy()
+    for i in range(B):
+        aff0[i] = cv2.warpAffine(img0_copy[i], warp_mat, (H, W))
+        aff1[i] = cv2.warpAffine(img1_copy[i], warp_mat, (H, W))
+        aff2[i] = cv2.warpAffine(gt_copy[i], warp_mat, (H, W))
+    img0a = torch.from_numpy(aff0).permute(0, 3, 1, 2).to(img0.device)
+    img1a = torch.from_numpy(aff1).permute(0, 3, 1, 2).to(img0.device)
+    gta = torch.from_numpy(aff2).permute(0, 3, 1, 2).to(img0.device)
+    # cv2.imwrite('affine0.png', img0a[0].permute(1, 2, 0).cpu().numpy()*255)
+    R = torch.from_numpy(warp_mat[:, :2]).to(img0.device)
+    flow_all_a = []
+    flow_all = flow + [flow_teacher]
+    for s in np.arange(len(flow_all)):
+        aff4 = np.empty((B, H, W, 4), dtype=np.float32)
+        flow_copy = flow_all[s].clone().detach().permute(0, 2, 3, 1).cpu().numpy()   # B, H, W, 4
+        for i in range(B):
+            aff4[i] = cv2.warpAffine(flow_copy[i], warp_mat, (H, W))
+        aff4 = torch.from_numpy(aff4).permute(0, 3, 1, 2).to(img0.device)
+        # affine transform to flow value
+        temp = rotater(aff4, R)
+        flow_all_a.append(temp)
+    flow_a = flow_all_a[:-1]
+    flow_teacher_a = flow_all_a[-1]
+    # visualize_flow(flow[0][0].permute(1, 2, 0), 'flow.png')
+    # visualize_flow(flow_a[0][0].permute(1, 2, 0), 'flow_affine.png')
+    return img0a, img1a, gta, flow_a, flow_teacher_a, mask, 0
-        loss_consist_tea = torch.abs(flow_handler(flow_teacher.clone(), dxy) - flow_teacher2)[smask].mean()
-        loss_consist = (loss_consist_stu / len(loss_on_scales) + loss_consist_tea) / 2
-        mean_shift = dxy.abs().mean().item()
-    else:
-        loss_consist = 0
-        mean_shift = 0
-        loss_distill2 = 0
-    return loss_consist, loss_distill2, mean_shift
\ No newline at end of file
+
+
+def calculate_consist_loss(img0, img1, gt, flow, flow_teacher, shift_sigmas, aug_handler):
+    assert(aug_handler is not None)
+    img0a, img1a, gta, flow_a, flow_teacher_a, smask, dxy = aug_handler(img0, img1, gt, flow, flow_teacher, shift_sigmas)
+    return img0a, img1a, gta, flow_a, flow_teacher_a, smask, dxy
\ No newline at end of file
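A quick sanity check of the idea behind PATCH 1/5 — an augmentation has to be applied to both the flow grid and the flow values. The snippet below is illustrative only and not part of the patch (tensor sizes are made up): for a horizontal flip, the grid is mirrored with hflip and the x-components change sign, which is the hflip(flow[s]) * sxy pattern used in _hflip above.

# Illustrative only -- not part of the patch. Toy check that flipping a flow
# map both mirrors the grid and negates the horizontal components.
import torch
from torchvision.transforms.functional import hflip

B, H, W = 1, 4, 6                                   # made-up toy sizes
flow = torch.randn(B, 4, H, W)                      # 4 channels, 2 per direction, as in the patch
sxy = torch.tensor([-1., 1., -1., 1.]).view(1, 4, 1, 1)
flow_flipped = hflip(flow) * sxy                    # same pattern as _hflip

x = 2                                               # a pixel at column x lands at column W-1-x
assert torch.allclose(flow_flipped[:, 0, :, W - 1 - x], -flow[:, 0, :, x])   # dx flips sign
assert torch.allclose(flow_flipped[:, 1, :, W - 1 - x],  flow[:, 1, :, x])   # dy unchanged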
From 06b8e65b1a718b2e406e809100cc693b6145f09a Mon Sep 17 00:00:00 2001
From: Yang Hong
Date: Wed, 30 Mar 2022 17:17:33 +0800
Subject: [PATCH 2/5] merge flow_handler with aug_handler

---
 model/RIFE.py | 32 ++++++++++++++++++++++++--------
 1 file changed, 24 insertions(+), 8 deletions(-)

diff --git a/model/RIFE.py b/model/RIFE.py
index 99b8a4b..0daa423 100644
--- a/model/RIFE.py
+++ b/model/RIFE.py
@@ -27,6 +27,7 @@ def __init__(self, local_rank=-1, use_old_model=False, grad_clip=-1,
                  shift_sigmas=(16,10),
                  cons_flip_prob=0,
                  cons_rot_prob=0,
+                 cons_affine_prob=0,
                  consist_loss_weight=0.05,
                  debug=False):
         #if arbitrary == True:
@@ -61,6 +62,7 @@ def __init__(self, local_rank=-1, use_old_model=False, grad_clip=-1,
         self.shift_sigmas = shift_sigmas
         self.cons_flip_prob = cons_flip_prob
         self.cons_rot_prob = cons_rot_prob
+        self.cons_affine_prob = cons_affine_prob
         self.consist_loss_weight = consist_loss_weight

     def train(self):
@@ -113,24 +115,38 @@ def update(self, imgs, gt, learning_rate=0, weight_decay=0, mul=1, training=True
                 self.eval()
             flow, mask, merged_img_list, flow_teacher, merged_teacher, loss_distill = self.flownet(torch.cat((imgs, gt), 1), scale_list=[4, 2, 1])
-            args = dict(img0=img0, img1=img1, gt=gt, flow=flow, flow_teacher=flow_teacher,
-                        model=self.flownet, shift_sigmas=self.shift_sigmas)
+            args = dict(img0=img0, img1=img1, gt=gt, flow=flow, flow_teacher=flow_teacher, shift_sigmas=self.shift_sigmas)
             if self.cons_shift_prob > 0 and random.random() < self.cons_shift_prob:
                 args["aug_handler"] = random_shift
-                args["flow_handler"] = adder
-                loss_consist, loss_distill2, mean_shift = calculate_consist_loss(**args)
             elif self.cons_flip_prob > 0 and random.random() < self.cons_flip_prob:
                 args["aug_handler"] = random_flip
-                args["flow_handler"] = multiplier
-                loss_consist, loss_distill2, mean_shift = calculate_consist_loss(**args)
             elif self.cons_rot_prob > 0 and random.random() < self.cons_rot_prob:
                 args["aug_handler"] = random_rotate
-                args["flow_handler"] = rotater
-                loss_consist, loss_distill2, mean_shift = calculate_consist_loss(**args)
+            elif self.cons_affine_prob > 0 and random.random() < self.cons_affine_prob:
+                args["aug_handler"] = random_affine
+            else:
+                args["aug_handler"] = None
+
+            if args["aug_handler"] is not None:
+                img0a, img1a, gta, flow_a, flow_teacher_a, smask, dxy = calculate_consist_loss(**args)
+                imgsa = torch.cat((img0a, img1a), 1)
+                flow2, mask2, merged_img_list2, flow_teacher2, merged_teacher2, loss_distill2 = self.flownet(torch.cat((imgsa, gta), 1), scale_list=[4, 2, 1])
+                loss_consist_stu = 0
+                # s enumerates all scales.
+                loss_on_scales = np.arange(len(flow))
+                for s in loss_on_scales:
+                    loss_consist_stu += torch.abs(flow_a[s] - flow2[s])[smask].mean()
+                loss_consist_tea = torch.abs(flow_teacher_a - flow_teacher2)[smask].mean()
+                loss_consist = (loss_consist_stu / len(loss_on_scales) + loss_consist_tea) / 2
+                if isinstance(dxy, int):
+                    mean_shift = dxy
+                else:
+                    mean_shift = dxy.abs().mean().item()
             else:
                 loss_consist = 0
                 mean_shift = 0
                 loss_distill2 = 0
+
             only_calc_final_loss = True
             if only_calc_final_loss:
                 stu_pred = merged_img_list[2]
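For reference, the consistency term assembled in RIFE.update above reduces, per scale, to a masked L1 difference between the transformed flow of the original pair (flow_a[s]) and the flow predicted on the augmented pair (flow2[s]). A minimal standalone sketch with dummy tensors standing in for the flownet outputs (shapes and noise level are arbitrary):

# Illustrative only: per-scale student consistency term, with random tensors
# in place of the network outputs.
import torch

B, H, W = 2, 8, 8                                   # arbitrary toy sizes
flow_a_s = torch.randn(B, 4, H, W)                  # transformed flow of the original pair
flow2_s = flow_a_s + 0.1 * torch.randn(B, 4, H, W)  # flow predicted on the augmented pair
smask = torch.ones(B, 4, H, W, dtype=torch.bool)    # all-valid mask (random_shift would zero a border)

loss_consist_scale = torch.abs(flow_a_s - flow2_s)[smask].mean()
print(float(loss_consist_scale))                    # roughly 0.08, i.e. 0.1 * sqrt(2 / pi)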
From a0515f047547354c7e50756bf2fab87f294cb5f Mon Sep 17 00:00:00 2001
From: Yang Hong
Date: Wed, 30 Mar 2022 17:18:05 +0800
Subject: [PATCH 3/5] add affine prob argument

---
 train.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/train.py b/train.py
index 9f751bb..c0dffbf 100644
--- a/train.py
+++ b/train.py
@@ -200,6 +200,8 @@ def evaluate(model, val_data, epoch, nr_eval, local_rank, writer_val):
                         help='Probability of flipping consistency loss')
     parser.add_argument('--consrotprob', dest='cons_rot_prob', default=0.1, type=float,
                         help='Probability of rotating consistency loss')
+    parser.add_argument('--consaffineprob', dest='cons_affine_prob', default=0.1, type=float,
+                        help='Probability of affine transform for consistency loss')
     parser.add_argument('--shiftsigmas', dest='shift_sigmas', default="16,10", type=str,
                         help='Stds of shifts for shifting consistency loss')
     parser.add_argument('--consweight', dest='consist_loss_weight', default=0.02, type=float,
@@ -232,6 +234,7 @@ def evaluate(model, val_data, epoch, nr_eval, local_rank, writer_val):
                   shift_sigmas=args.shift_sigmas,
                   cons_flip_prob=args.cons_flip_prob,
                   cons_rot_prob=args.cons_rot_prob,
+                  cons_affine_prob=args.cons_affine_prob,
                   consist_loss_weight=args.consist_loss_weight,
                   debug=args.debug)
     if args.cp is not None:

From 54c706db5f3f303756a8eaad22f30b140b4c3656 Mon Sep 17 00:00:00 2001
From: Yang Hong
Date: Wed, 30 Mar 2022 17:18:42 +0800
Subject: [PATCH 4/5] add 'pycocotools' to requirements

---
 requirements.txt | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 29f876e..163eb24 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,4 +7,5 @@ moviepy>=1.0.3
 torchvision>=0.12.0
 imgaug>=0.4.0
 sk-video
-tensorboard
\ No newline at end of file
+tensorboard
+pycocotools
\ No newline at end of file

From 13d27463c0db4c4a443105d1e7cbd15922d7c709 Mon Sep 17 00:00:00 2001
From: Yang Hong
Date: Wed, 30 Mar 2022 17:18:57 +0800
Subject: [PATCH 5/5] update gitignore

---
 .gitignore | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 1fb7f16..02cc65d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,7 @@
 *.py#

 output/*
+*.png
 test/
 .idea/

@@ -10,4 +11,6 @@ test/

 *.zip

-HD_dataset
\ No newline at end of file
+data
+train/
+validate/
\ No newline at end of file
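Circling back to random_affine from PATCH 1/5: flow values are multiplied by the 2x2 linear part of the warp matrix (warp_mat[:, :2]), since the translation part drops out for displacement vectors, and the patch reuses rotater() for this. A self-contained sketch of that value transform — the target triangle below is hand-picked for illustration, whereas the patch samples it randomly:

# Illustrative only: applying the linear part of an affine warp to flow values,
# using the same einsum contraction as rotater().
import cv2
import numpy as np
import torch

H = W = 224
srcTri = np.array([[0, 0], [W - 1, 0], [0, H - 1]], dtype=np.float32)
dstTri = np.array([[10, 5], [W - 20, 8], [15, H - 30]], dtype=np.float32)   # hand-picked example triangle
warp_mat = cv2.getAffineTransform(srcTri, dstTri).astype(np.float32)        # (2, 3)

R = torch.from_numpy(warp_mat[:, :2])               # linear part only, shape (2, 2)
flow = torch.randn(1, 4, H, W)                      # dummy bidirectional flow
fst, sec = torch.split(flow, 2, dim=1)
fst_a = torch.einsum('jc, bjhw -> bchw', R, fst)    # same contraction as rotater()
sec_a = torch.einsum('jc, bjhw -> bchw', R, sec)
flow_a = torch.cat((fst_a, sec_a), dim=1)
print(flow_a.shape)                                 # torch.Size([1, 4, 224, 224])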