diff --git a/.gitignore b/.gitignore
index 6500cd24..d13c4505 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,7 +2,7 @@
 __pycache__/
 *.py[cod]
 *$py.class
-
+**.DS_Store
 # C extensions
 *.so
diff --git a/atommic/collections/common/data/mri_loader.py b/atommic/collections/common/data/mri_loader.py
index fac98ee1..b6b80ab0 100644
--- a/atommic/collections/common/data/mri_loader.py
+++ b/atommic/collections/common/data/mri_loader.py
@@ -226,7 +226,9 @@ def __init__(  # noqa: MC0001
         self.examples = [ex for ex in self.examples if ex[2]["encoding_size"][1] in num_cols]
 
         self.indices_to_log = np.random.choice(
-            len(self.examples), int(log_images_rate * len(self.examples)), replace=False  # type: ignore
+            [example[1] for example in self.examples],
+            int(log_images_rate * len(self.examples)),  # type: ignore
+            replace=False,
         )
 
     def _retrieve_metadata(self, fname: Union[str, Path]) -> Tuple[Dict, int]:
diff --git a/atommic/collections/multitask/rs/data/mrirs_loader.py b/atommic/collections/multitask/rs/data/mrirs_loader.py
index 5a8c80e5..be388fbc 100644
--- a/atommic/collections/multitask/rs/data/mrirs_loader.py
+++ b/atommic/collections/multitask/rs/data/mrirs_loader.py
@@ -416,6 +416,8 @@ def __getitem__(self, i: int):  # noqa: MC0001
                 kspace = kspace[:, :, 0, :] + kspace[:, :, 1, :]
             elif not is_none(dataset_format) and dataset_format == "skm-tea-echo1+echo2-mc":
                 kspace = np.concatenate([kspace[:, :, 0, :], kspace[:, :, 1, :]], axis=-1)
+            elif not is_none(dataset_format) and dataset_format == "skm-tea-echo1-echo2":
+                kspace = kspace  # keep both echoes in their own dimension
             else:
                 warnings.warn(
                     f"Dataset format {dataset_format} is either not supported or set to None. "
@@ -423,12 +425,13 @@ def __getitem__(self, i: int):  # noqa: MC0001
                 )
                 kspace = kspace[:, :, 0, :]
 
-        kspace = kspace[48:-48, 40:-40]
-
         sensitivity_map = self.get_consecutive_slices(hf, "maps", dataslice).astype(np.complex64)
         sensitivity_map = sensitivity_map[..., 0]
-
-        sensitivity_map = sensitivity_map[48:-48, 40:-40]
+        if self.consecutive_slices > 1:
+            sensitivity_map = sensitivity_map[:, 48:-48, 40:-40]
+            kspace = kspace[:, 48:-48, 40:-40]
+        else:
+            sensitivity_map = sensitivity_map[48:-48, 40:-40]
+            kspace = kspace[48:-48, 40:-40]
 
         if masking == "custom":
             mask = np.array([])
@@ -470,22 +473,17 @@ def __getitem__(self, i: int):  # noqa: MC0001
             # combine Lateral Meniscus and Medial Meniscus
             medial_meniscus = lateral_meniscus + medial_meniscus
 
-            if self.consecutive_slices > 1:
-                segmentation_labels_dim = 1
-            else:
-                segmentation_labels_dim = 0
-
             # stack the labels in the last dimension
             segmentation_labels = np.stack(
                 [patellar_cartilage, femoral_cartilage, tibial_cartilage, medial_meniscus],
-                axis=segmentation_labels_dim,
+                axis=-1,
             )
 
             # TODO: This is hardcoded on the SKM-TEA side, how to generalize this?
             # We need to crop the segmentation labels in the frequency domain to reduce the FOV.
-            segmentation_labels = np.fft.fftshift(np.fft.fft2(segmentation_labels))
-            segmentation_labels = segmentation_labels[:, 48:-48, 40:-40]
-            segmentation_labels = np.fft.ifft2(np.fft.ifftshift(segmentation_labels)).real
+            segmentation_labels = np.fft.fftshift(np.fft.fft2(segmentation_labels, axes=(-3, -2)))
+            segmentation_labels = segmentation_labels[..., 48:-48, 40:-40, :]
+            segmentation_labels = np.fft.ifft2(np.fft.ifftshift(segmentation_labels), axes=(-3, -2)).real
 
         imspace = np.empty([])
 
@@ -499,12 +497,24 @@ def __getitem__(self, i: int):  # noqa: MC0001
             metadata["noise"] = 1.0
 
         attrs.update(metadata)
-
-        kspace = np.transpose(kspace, (2, 0, 1))
-        sensitivity_map = np.transpose(sensitivity_map.squeeze(), (2, 0, 1))
-
+        if not is_none(dataset_format) and dataset_format == "skm-tea-echo1-echo2":
+            if self.consecutive_slices > 1:
+                segmentation_labels = np.transpose(segmentation_labels, (0, 3, 1, 2))
+                kspace = np.transpose(kspace, (3, 0, 4, 1, 2))
+                sensitivity_map = np.transpose(sensitivity_map, (4, 0, 3, 1, 2))
+            else:
+                segmentation_labels = np.transpose(segmentation_labels, (2, 0, 1))
+                kspace = np.transpose(kspace, (2, 3, 0, 1))
+                sensitivity_map = np.transpose(sensitivity_map, (3, 2, 0, 1))
+        elif self.consecutive_slices > 1 and not is_none(dataset_format) and dataset_format != "skm-tea-echo1-echo2":
+            segmentation_labels = np.transpose(segmentation_labels, (0, 3, 1, 2))
+            kspace = np.transpose(kspace, (0, 3, 1, 2))
+            sensitivity_map = np.transpose(sensitivity_map.squeeze(), (0, 3, 1, 2))
+        else:
+            segmentation_labels = np.transpose(segmentation_labels, (2, 0, 1))
+            kspace = np.transpose(kspace, (2, 0, 1))
+            sensitivity_map = np.transpose(sensitivity_map.squeeze(), (2, 0, 1))
 
         attrs["log_image"] = bool(dataslice in self.indices_to_log)
-
         return (
             (
                 kspace,
diff --git a/atommic/collections/multitask/rs/nn/base.py b/atommic/collections/multitask/rs/nn/base.py
index 5a0b506a..771e3213 100644
--- a/atommic/collections/multitask/rs/nn/base.py
+++ b/atommic/collections/multitask/rs/nn/base.py
@@ -67,7 +67,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):  # noqa: MC0001
         # Initialize the dimensionality of the data. It can be 2D or 2.5D -> meaning 2D with > 1 slices or 3D.
         self.dimensionality = cfg_dict.get("dimensionality", 2)
         self.consecutive_slices = cfg_dict.get("consecutive_slices", 1)
-
+        self.num_echoes = cfg_dict.get("num_echoes", 1)
         # Initialize the coil combination method. It can be either "SENSE" or "RSS" (root-sum-of-squares) or
         # "RSS-complex" (root-sum-of-squares of the complex-valued data).
         self.coil_combination_method = cfg_dict.get("coil_combination_method", "SENSE")
@@ -601,9 +601,6 @@ def __compute_loss__(
             If self.accumulate_loss is True, returns an accumulative result of all intermediate losses.
             Otherwise, returns the loss of the last intermediate loss.
""" - if self.consecutive_slices > 1: - batch_size, slices = target_segmentation.shape[:2] - target_segmentation = target_segmentation.reshape(batch_size * slices, *target_segmentation.shape[2:]) segmentation_loss = self.process_segmentation_loss(target_segmentation, predictions_segmentation, attrs) @@ -675,27 +672,31 @@ def __compute_and_log_metrics_and_outputs__( # noqa: MC0001 if isinstance(predictions_segmentation, list): while isinstance(predictions_segmentation, list): predictions_segmentation = predictions_segmentation[-1] - if self.consecutive_slices > 1: # reshape the target and prediction to [batch_size, self.consecutive_slices, nr_classes, n_x, n_y] - batch_size = target_segmentation.shape[0] // self.consecutive_slices - target_segmentation = target_segmentation.reshape( - batch_size, self.consecutive_slices, *target_segmentation.shape[1:] - ) - target_reconstruction = target_reconstruction.reshape( - batch_size, self.consecutive_slices, *target_reconstruction.shape[2:] - ) + batch_size = int(target_segmentation.shape[0] / self.consecutive_slices) predictions_segmentation = predictions_segmentation.reshape( - batch_size, self.consecutive_slices, *predictions_segmentation.shape[2:] + batch_size, self.consecutive_slices, *predictions_segmentation.shape[1:] ) - predictions_reconstruction = predictions_reconstruction.reshape( - batch_size, self.consecutive_slices, *predictions_reconstruction.shape[1:] + target_segmentation = target_segmentation.reshape( + batch_size, self.consecutive_slices, *target_segmentation.shape[1:] ) + target_segmentation = target_segmentation[:, self.consecutive_slices // 2] target_reconstruction = target_reconstruction[:, self.consecutive_slices // 2] predictions_segmentation = predictions_segmentation[:, self.consecutive_slices // 2] predictions_reconstruction = predictions_reconstruction[:, self.consecutive_slices // 2] + if self.num_echoes > 1: + # find the batch size + batch_size = target_reconstruction.shape[0] / self.num_echoes + # reshape to [batch_size, num_echoes, n_x, n_y] + target_reconstruction = target_reconstruction.reshape( + (int(batch_size), self.num_echoes, *target_reconstruction.shape[1:]) + ) + predictions_reconstruction = predictions_reconstruction.reshape( + (int(batch_size), self.num_echoes, *predictions_reconstruction.shape[1:]) + ) fname = attrs["fname"] slice_idx = attrs["slice_idx"] @@ -734,11 +735,6 @@ def __compute_and_log_metrics_and_outputs__( # noqa: MC0001 batch_idx=_batch_idx_, ) - output_predictions_reconstruction = output_predictions_reconstruction.detach().cpu() - output_target_reconstruction = output_target_reconstruction.detach().cpu() - output_target_segmentation = output_target_segmentation.detach().cpu() - output_predictions_segmentation = output_predictions_segmentation.detach().cpu() - # Normalize target and predictions to [0, 1] for logging. 
             if torch.is_complex(output_target_reconstruction) and output_target_reconstruction.shape[-1] != 2:
                 output_target_reconstruction = torch.view_as_real(output_target_reconstruction)
@@ -747,7 +743,6 @@ def __compute_and_log_metrics_and_outputs__(  # noqa: MC0001
                 output_target_reconstruction = output_target_reconstruction / torch.max(
                     torch.abs(output_target_reconstruction)
                 )
-                output_target_reconstruction = output_target_reconstruction.detach().cpu()
 
             if (
                 torch.is_complex(output_predictions_reconstruction)
                 and output_predictions_reconstruction.shape[-1] != 2
             ):
@@ -759,7 +754,11 @@ def __compute_and_log_metrics_and_outputs__(  # noqa: MC0001
                 output_predictions_reconstruction = output_predictions_reconstruction / torch.max(
                     torch.abs(output_predictions_reconstruction)
                 )
-                output_predictions_reconstruction = output_predictions_reconstruction.detach().cpu()
+
+            output_predictions_reconstruction = output_predictions_reconstruction.detach().cpu().float()
+            output_target_reconstruction = output_target_reconstruction.detach().cpu().float()
+            output_target_segmentation = output_target_segmentation.detach().cpu().float()
+            output_predictions_segmentation = output_predictions_segmentation.detach().cpu().float()
 
             # Log target and predictions, if log_image is True for this slice.
             if attrs["log_image"][_batch_idx_]:
@@ -772,17 +771,33 @@ def __compute_and_log_metrics_and_outputs__(  # noqa: MC0001
                 )
 
                 if self.use_reconstruction_module:
-                    self.log_image(
-                        f"{key}/a/reconstruction/target/predictions/error",
-                        torch.cat(
-                            [
-                                output_target_reconstruction,
-                                output_predictions_reconstruction,
-                                torch.abs(output_target_reconstruction - output_predictions_reconstruction),
-                            ],
-                            dim=-1,
-                        ),
-                    )
+                    if self.num_echoes > 1:
+                        for i in range(output_target_reconstruction.shape[0]):
+                            self.log_image(
+                                f"{key}/a/reconstruction_abs/target echo: {i+1}/predictions echo: {i+1}/error echo: {i+1}",
+                                torch.cat(
+                                    [
+                                        output_target_reconstruction[i],
+                                        output_predictions_reconstruction[i],
+                                        torch.abs(
+                                            output_target_reconstruction[i] - output_predictions_reconstruction[i]
+                                        ),
+                                    ],
+                                    dim=-1,
+                                ),
+                            )
+                    else:
+                        self.log_image(
+                            f"{key}/a/reconstruction_abs/target/predictions/error",
+                            torch.cat(
+                                [
+                                    output_target_reconstruction,
+                                    output_predictions_reconstruction,
+                                    torch.abs(output_target_reconstruction - output_predictions_reconstruction),
+                                ],
+                                dim=-1,
+                            ),
+                        )
 
                 # concatenate the segmentation classes for logging
                 target_segmentation_class = torch.cat(
@@ -1120,7 +1135,16 @@ def inference_step(  # noqa: MC0001
                 self.coil_combination_method,
                 self.coil_dim,
             )
-
+        if self.num_echoes > 1:
+            # stack the echoes along the batch dimension
+            kspace = kspace.view(-1, *kspace.shape[2:])
+            y = y.view(-1, *y.shape[2:])
+            mask = mask.view(-1, *mask.shape[2:])
+            initial_prediction_reconstruction = initial_prediction_reconstruction.view(
+                -1, *initial_prediction_reconstruction.shape[2:]
+            )
+            target_reconstruction = target_reconstruction.view(-1, *target_reconstruction.shape[2:])
+            sensitivity_maps = torch.repeat_interleave(sensitivity_maps, repeats=kspace.shape[0], dim=0).squeeze(1)
         # Model forward pass
         predictions_reconstruction, predictions_segmentation = self.forward(
             y,
@@ -1130,6 +1154,19 @@ def inference_step(  # noqa: MC0001
             sensitivity_maps,
             mask,
             initial_prediction_reconstruction,
             target_reconstruction,
             attrs["noise"],
         )
+        if self.consecutive_slices > 1:
+            # reshape the target and prediction segmentation to [batch_size * consecutive_slices, nr_classes, n_x, n_y]
+            batch_size, slices = target_segmentation.shape[:2]
+            target_segmentation = target_segmentation.reshape(batch_size * slices, *target_segmentation.shape[2:])
+            if isinstance(predictions_segmentation, list):
+                for i, prediction_segmentation in enumerate(predictions_segmentation):
+                    predictions_segmentation[i] = prediction_segmentation.reshape(
+                        batch_size * slices, *prediction_segmentation.shape[2:]
+                    )
+            else:
+                predictions_segmentation = predictions_segmentation.reshape(
+                    batch_size * slices, *predictions_segmentation.shape[2:]
+                )
 
         if not is_none(self.segmentation_classes_thresholds):
             for class_idx, thres in enumerate(self.segmentation_classes_thresholds):
@@ -1482,6 +1519,26 @@ def test_step(self, batch: Dict[float, torch.Tensor], batch_idx: int):
             while isinstance(predictions_reconstruction, list):
                 predictions_reconstruction = predictions_reconstruction[-1]
 
+        if self.consecutive_slices > 1:
+            # reshape the target and prediction to [batch_size, self.consecutive_slices, nr_classes, n_x, n_y]
+            batch_size = int(target_segmentation.shape[0] / self.consecutive_slices)
+            predictions_segmentation = predictions_segmentation.reshape(
+                batch_size, self.consecutive_slices, *predictions_segmentation.shape[1:]
+            )
+            predictions_segmentation = predictions_segmentation[:, self.consecutive_slices // 2]
+            predictions_reconstruction = predictions_reconstruction[:, self.consecutive_slices // 2]
+
+        if self.num_echoes > 1:
+            # find the batch size
+            batch_size = target_reconstruction.shape[0] / self.num_echoes
+            # reshape to [batch_size, num_echoes, n_x, n_y]
+            target_reconstruction = target_reconstruction.reshape(
+                (int(batch_size), self.num_echoes, *target_reconstruction.shape[1:])
+            )
+            predictions_reconstruction = predictions_reconstruction.reshape(
+                (int(batch_size), self.num_echoes, *predictions_reconstruction.shape[1:])
+            )
+
         # If "16" or "16-mixed" fp is used, ensure complex type will be supported when saving the predictions.
         predictions_reconstruction = (
             torch.view_as_complex(torch.view_as_real(predictions_reconstruction).type(torch.float32))
@@ -1670,10 +1727,10 @@ def on_test_epoch_end(self):  # noqa: MC0001
         for fname in segmentations:
             segmentations[fname] = np.stack([out for _, out in sorted(segmentations[fname])])
 
-        if self.consecutive_slices > 1:
-            # iterate over the slices and always keep the middle slice
-            for fname in segmentations:
-                segmentations[fname] = segmentations[fname][:, self.consecutive_slices // 2]
+        # if self.consecutive_slices > 1:
+        #     # iterate over the slices and always keep the middle slice
+        #     for fname in segmentations:
+        #         segmentations[fname] = segmentations[fname][:, self.consecutive_slices // 2]  # TODO: remove; already done in test_step to minimize the memory load
 
         if self.use_reconstruction_module:
             reconstructions = defaultdict(list)
@@ -1684,10 +1741,10 @@ def on_test_epoch_end(self):  # noqa: MC0001
             for fname in reconstructions:
                 reconstructions[fname] = np.stack([out for _, out in sorted(reconstructions[fname])])
 
-            if self.consecutive_slices > 1:
-                # iterate over the slices and always keep the middle slice
-                for fname in reconstructions:
-                    reconstructions[fname] = reconstructions[fname][:, self.consecutive_slices // 2]
+            # if self.consecutive_slices > 1:  # TODO: remove; already done in test_step to minimize the memory load
+            #     # iterate over the slices and always keep the middle slice
+            #     for fname in reconstructions:
+            #         reconstructions[fname] = reconstructions[fname][:, self.consecutive_slices // 2]
         else:
             reconstructions = None
@@ -1752,6 +1809,7 @@ def _setup_dataloader_from_config(cfg: DictConfig) -> DataLoader:
             "skm-tea-echo2",
             "skm-tea-echo1+echo2",
             "skm-tea-echo1+echo2-mc",
+            "skm-tea-echo1-echo2",
         ):
             dataloader = mrirs_loader.SKMTEARSMRIDataset
         else:
diff --git a/atommic/collections/multitask/rs/nn/mtlrs.py b/atommic/collections/multitask/rs/nn/mtlrs.py
index 81daa2d4..ab93297e 100644
--- a/atommic/collections/multitask/rs/nn/mtlrs.py
+++ b/atommic/collections/multitask/rs/nn/mtlrs.py
@@ -85,6 +85,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         self.coil_dim = cfg_dict.get("coil_dim", 1)
         self.consecutive_slices = cfg_dict.get("consecutive_slices", 1)
+        self.num_echoes = cfg_dict.get("num_echoes", 1)
 
         self.rs_cascades = cfg_dict.get("joint_reconstruction_segmentation_module_cascades", 1)
         self.rs_module = torch.nn.ModuleList(
@@ -102,6 +103,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
                     consecutive_slices=self.consecutive_slices,
                     coil_combination_method=cfg_dict.get("coil_combination_method", "SENSE"),
                     normalize_segmentation_output=cfg_dict.get("normalize_segmentation_output", True),
+                    num_echoes=self.num_echoes,
                 )
                 for _ in range(self.rs_cascades)
             ]
@@ -170,9 +172,6 @@ def forward(
                 if f != 0
             ]
 
-            if self.consecutive_slices > 1:
-                hx = [x.unsqueeze(1) for x in hx]
-
             # Check if the concatenated hidden states are the same size as the hidden state of the RNN
             if hidden_states[0].shape[self.coil_dim] != hx[0].shape[self.coil_dim]:
                 prev_hidden_states = hidden_states
@@ -189,7 +188,6 @@ def forward(
             hx = [hx[i] + hidden_states[i] for i in range(len(hx))]
 
         init_reconstruction_pred = torch.view_as_real(init_reconstruction_pred)
-
         return pred_reconstructions, pred_segmentation
 
     def process_reconstruction_loss(  # noqa: MC0001
@@ -284,19 +282,31 @@ def compute_reconstruction_loss(t, p, s):
                 return loss_func(t, p)
 
-        if self.accumulate_predictions:
+        if self.reconstruction_module_accumulate_predictions:
             rs_cascades_weights = torch.logspace(-1, 0, steps=len(prediction)).to(target.device)
             rs_cascades_loss = []
             for rs_cascade_pred in prediction:
                 cascades_weights = torch.logspace(-1, 0, steps=len(rs_cascade_pred)).to(target.device)
                 cascades_loss = []
                 for cascade_pred in rs_cascade_pred:
-                    time_steps_weights = torch.logspace(-1, 0, steps=self.time_steps).to(target.device)
-                    time_steps_loss = [
-                        compute_reconstruction_loss(target, time_step_pred, sensitivity_maps)
-                        for time_step_pred in cascade_pred
-                    ]
-                    cascade_loss = sum(x * w for x, w in zip(time_steps_loss, time_steps_weights)) / self.time_steps
+                    time_steps_weights = torch.logspace(-1, 0, steps=len(cascade_pred)).to(target.device)
+                    if self.consecutive_slices > 1:
+                        time_steps_loss = [
+                            compute_reconstruction_loss(
+                                target.reshape(target.shape[0] * target.shape[1], *target.shape[2:]),
+                                time_step_pred.reshape(
+                                    time_step_pred.shape[0] * time_step_pred.shape[1], *time_step_pred.shape[2:]
+                                ),
+                                sensitivity_maps,
+                            )
+                            for time_step_pred in cascade_pred
+                        ]
+                    else:
+                        time_steps_loss = [
+                            compute_reconstruction_loss(target, time_step_pred, sensitivity_maps)
+                            for time_step_pred in cascade_pred
+                        ]
+                    cascade_loss = sum(x * w for x, w in zip(time_steps_loss, time_steps_weights)) / len(cascade_pred)
                     cascades_loss.append(cascade_loss)
                 rs_cascade_loss = sum(x * w for x, w in zip(cascades_loss, cascades_weights)) / len(rs_cascade_pred)
                 rs_cascades_loss.append(rs_cascade_loss)
@@ -304,5 +314,16 @@ def compute_reconstruction_loss(t, p, s):
         else:
             # keep the last prediction of the last cascade of the last rs cascade
             prediction = prediction[-1][-1][-1]
-            loss = compute_reconstruction_loss(target, prediction, sensitivity_maps)
+            if self.consecutive_slices > 1:
+                loss = compute_reconstruction_loss(
+                    target.reshape(target.shape[0] * target.shape[1], *target.shape[2:]),
+                    prediction.reshape(prediction.shape[0] * prediction.shape[1], *prediction.shape[2:]),
+                    sensitivity_maps,
+                )
+            else:
+                loss = compute_reconstruction_loss(
+                    target,
+                    prediction,
+                    sensitivity_maps,
+                )
         return loss
diff --git a/atommic/collections/multitask/rs/nn/mtlrs_base/mtlrs_block.py b/atommic/collections/multitask/rs/nn/mtlrs_base/mtlrs_block.py
index 800999c2..aa34442a 100644
--- a/atommic/collections/multitask/rs/nn/mtlrs_base/mtlrs_block.py
+++ b/atommic/collections/multitask/rs/nn/mtlrs_base/mtlrs_block.py
@@ -42,6 +42,7 @@ def __init__(
         consecutive_slices: int = 1,
         coil_combination_method: str = "SENSE",
         normalize_segmentation_output: bool = True,
+        num_echoes: int = 1,
     ):
         """Inits :class:`MTLRSBlock`.
@@ -184,7 +185,7 @@ def __init__(
         else:
             raise ValueError(f"Segmentation module {segmentation_module} not implemented.")
         self.segmentation_module = segmentation_module
-
+        self.num_echoes = num_echoes
         self.normalize_segmentation_output = normalize_segmentation_output
 
     def forward(  # noqa: MC0001
@@ -224,34 +225,47 @@ def forward(  # noqa: MC0001
         if self.consecutive_slices > 1 and self.reconstruction_module_dimensionality == 2:
             # Do per slice reconstruction
             pred_reconstruction_slices = []
+            temp_hx = []
             for slice_idx in range(self.consecutive_slices):
                 y_slice = y[:, slice_idx, ...]
                 prediction_slice = y_slice.clone()
                 sensitivity_maps_slice = sensitivity_maps[:, slice_idx, ...]
-                mask_slice = mask[:, 0, ...]
+                if mask.dim() == 1:
+                    mask_slice = mask
+                else:
+                    mask_slice = mask[:, 0, ...]
                 init_reconstruction_pred_slice = init_reconstruction_pred[:, slice_idx, ...]
                 _pred_reconstruction_slice = (
                     None
                     if init_reconstruction_pred_slice is None or init_reconstruction_pred_slice.dim() < 4
                     else init_reconstruction_pred_slice
                 )
+                if isinstance(hx, list):
+                    if not hx[0].shape:
+                        hx_slice = hx
+                    else:
+                        hx_slice = [x[:, slice_idx, ...] for x in hx]
+                else:
+                    hx_slice = hx
                 cascades_predictions = []
                 for i, cascade in enumerate(self.reconstruction_module):
                     # Forward pass through the cascades
-                    prediction_slice, hx = cascade(
+                    prediction_slice, hx_slice = cascade(
                         prediction_slice,
                         y_slice,
                         sensitivity_maps_slice,
                         mask_slice,
-                        _pred_reconstruction_slice,
-                        hx,
+                        _pred_reconstruction_slice if i == 0 else prediction_slice[-1],
+                        hx_slice,
                         sigma,
                         keep_prediction=False if i == 0 else self.keep_prediction,
                     )
                     time_steps_predictions = [torch.view_as_complex(pred) for pred in prediction_slice]
                     cascades_predictions.append(torch.stack(time_steps_predictions, dim=0))
                 pred_reconstruction_slices.append(torch.stack(cascades_predictions, dim=0))
+                temp_hx.append(torch.stack(hx_slice, dim=0))
             preds = torch.stack(pred_reconstruction_slices, dim=3)
+            hx = torch.stack(temp_hx, dim=2)
 
             cascades_predictions = [
                 [
@@ -292,6 +306,8 @@ def forward(  # noqa: MC0001
                 _pred_reconstruction = _pred_reconstruction[-1]
             if _pred_reconstruction.shape[-1] != 2:
                 _pred_reconstruction = torch.view_as_real(_pred_reconstruction)
+            if self.num_echoes > 1:
+                _pred_reconstruction = torch.sum(_pred_reconstruction, dim=0, keepdim=True)
             if self.consecutive_slices > 1 and _pred_reconstruction.dim() == 5:
                 _pred_reconstruction = _pred_reconstruction.reshape(
                     _pred_reconstruction.shape[0] * _pred_reconstruction.shape[1],
@@ -319,10 +335,10 @@ def forward(  # noqa: MC0001
             )
             pred_segmentation = torch.abs(pred_segmentation)
-
         if self.consecutive_slices > 1:
             # get batch size and number of slices from y, because if the reconstruction module is used they will
             # not be saved before
-            pred_segmentation = pred_segmentation.view([y.shape[0], y.shape[1], *pred_segmentation.shape[1:]])
-
+            pred_segmentation = pred_segmentation.view(
+                [int(y.shape[0] / self.num_echoes), y.shape[1], *pred_segmentation.shape[1:]]
+            )
 
         return pred_reconstruction, pred_segmentation, hx  # type: ignore
diff --git a/atommic/collections/reconstruction/data/mri_reconstruction_loader.py b/atommic/collections/reconstruction/data/mri_reconstruction_loader.py
index ef602ee6..69c81b64 100644
--- a/atommic/collections/reconstruction/data/mri_reconstruction_loader.py
+++ b/atommic/collections/reconstruction/data/mri_reconstruction_loader.py
@@ -692,6 +692,8 @@ def __getitem__(self, i: int):  # noqa: MC0001
                 kspace = kspace[:, :, 0, :] + kspace[:, :, 1, :]
             elif not is_none(dataset_format) and dataset_format == "skm-tea-echo1+echo2-mc":
                 kspace = np.concatenate([kspace[:, :, 0, :], kspace[:, :, 1, :]], axis=-1)
+            elif not is_none(dataset_format) and dataset_format == "skm-tea-echo1-echo2":
+                kspace = kspace  # keep both echoes in their own dimension
             else:
                 warnings.warn(
                     f"Dataset format {dataset_format} is either not supported or set to None. "
" diff --git a/atommic/collections/reconstruction/losses/ssim.py b/atommic/collections/reconstruction/losses/ssim.py index 67d6ed9c..964dd960 100644 --- a/atommic/collections/reconstruction/losses/ssim.py +++ b/atommic/collections/reconstruction/losses/ssim.py @@ -56,6 +56,11 @@ def forward(self, X: torch.Tensor, Y: torch.Tensor, data_range: torch.Tensor = N if not isinstance(self.w, torch.Tensor): # type: ignore # pylint: disable=access-member-before-definition raise AssertionError + if X.dim() == 3: + X = X.unsqueeze(1) + if Y.dim() == 3: + Y = Y.unsqueeze(1) + # This is necessary to first assign self.w to CUDA and then in case of fp32 to avoid RuntimeError: Inference # tensors cannot be saved for backward. self.w = self.w.to(Y).clone() # type: ignore diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 95865e4d..ba9f35a4 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,7 +1,7 @@ defusedxml>=0.7.1 einops>=0.5.0 h5py==3.9.0 -huggingface_hub +huggingface_hub<=0.20.3 hydra-core>1.3,<=1.3.2 nibabel==5.1.0 numba diff --git a/tests/collections/multitask/rs/models/test_mtlrs.py b/tests/collections/multitask/rs/models/test_mtlrs.py index aa88845d..675dd980 100644 --- a/tests/collections/multitask/rs/models/test_mtlrs.py +++ b/tests/collections/multitask/rs/models/test_mtlrs.py @@ -222,6 +222,76 @@ "max_steps": -1, }, ), + ( + [2, 3, 32, 16, 2], + { + "use_reconstruction_module": True, + "task_adaption_type": "multi_task_learning", + "joint_reconstruction_segmentation_module_cascades": 5, + "reconstruction_module_recurrent_layer": "IndRNN", + "reconstruction_module_conv_filters": [64, 64, 2], + "reconstruction_module_conv_kernels": [5, 3, 3], + "reconstruction_module_conv_dilations": [1, 2, 1], + "reconstruction_module_conv_bias": [True, True, False], + "reconstruction_module_recurrent_filters": [64, 64, 0], + "reconstruction_module_recurrent_kernels": [1, 1, 0], + "reconstruction_module_recurrent_dilations": [1, 1, 0], + "reconstruction_module_recurrent_bias": [True, True, False], + "reconstruction_module_depth": 2, + "reconstruction_module_conv_dim": 2, + "reconstruction_module_time_steps": 8, + "reconstruction_module_num_cascades": 5, + "reconstruction_module_dimensionality": 2, + "reconstruction_module_accumulate_predictions": True, + "reconstruction_module_no_dc": True, + "reconstruction_module_keep_prediction": True, + "reconstruction_loss": {"l1": 1.0, "ssim": 1.0}, + "segmentation_module": "UNet", + "segmentation_module_input_channels": 2, + "segmentation_module_output_channels": 4, + "segmentation_module_channels": 64, + "segmentation_module_pooling_layers": 4, + "segmentation_module_dropout": 0.0, + "segmentation_loss": {"dice": 1.0}, + "dice_loss_include_background": False, + "dice_loss_to_onehot_y": False, + "dice_loss_sigmoid": True, + "dice_loss_softmax": False, + "dice_loss_other_act": None, + "dice_loss_squared_pred": False, + "dice_loss_jaccard": False, + "dice_loss_reduction": "mean", + "dice_loss_smooth_nr": 1, + "dice_loss_smooth_dr": 1, + "dice_loss_batch": True, + "consecutive_slices": 3, + "coil_combination_method": "SENSE", + "magnitude_input": False, + "use_sens_net": False, + "fft_centered": False, + "fft_normalization": "backward", + "spatial_dims": [-2, -1], + "coil_dim": 2, + "dimensionality": 2, + "num_echoes": 2, + }, + [0.08], + [4], + 2, + 4, + { + "strategy": "ddp", + "accelerator": "cpu", + "num_nodes": 1, + "max_epochs": 20, + "precision": 32, + "enable_checkpointing": False, + "logger": False, + 
"log_every_n_steps": 50, + "check_val_every_n_epoch": -1, + "max_steps": -1, + }, + ), ], ) def test_mtlmrirs(shape, cfg, center_fractions, accelerations, dimensionality, segmentation_classes, trainer): @@ -256,13 +326,13 @@ def test_mtlmrirs(shape, cfg, center_fractions, accelerations, dimensionality, s output = torch.cat(outputs) mask = torch.cat(masks) - coil_dim = cfg.get("coil_dim") consecutive_slices = cfg.get("consecutive_slices") + num_echoes = cfg.get("num_echoes") if consecutive_slices > 1: x = torch.stack([x for _ in range(consecutive_slices)], 1) + mask = torch.stack([mask for _ in range(consecutive_slices)], 1) output = torch.stack([output for _ in range(consecutive_slices)], 1) - cfg = OmegaConf.create(cfg) cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True)) @@ -320,6 +390,8 @@ def test_mtlmrirs(shape, cfg, center_fractions, accelerations, dimensionality, s pred_segmentation = pred_segmentation.reshape( pred_segmentation.shape[0] * pred_segmentation.shape[1], *pred_segmentation.shape[2:] ) + if num_echoes > 1: # Model only makes one segmentation of multiple echoes + pred_segmentation = pred_segmentation.repeat(num_echoes, 1, 1, 1) if pred_segmentation.shape != output.shape: raise AssertionError else: