Fixes issues related to using consecutive slices in MTLRS and SKMTEA #8

Open · TimPaquaij wants to merge 1 commit into wdika/atommic:main
2 changes: 1 addition & 1 deletion .gitignore

@@ -2,7 +2,7 @@
 __pycache__/
 *.py[cod]
 *$py.class
-
+**.DS_Store
 # C extensions
 *.so
 
4 changes: 3 additions & 1 deletion atommic/collections/common/data/mri_loader.py

@@ -226,7 +226,9 @@ def __init__(  # noqa: MC0001
         self.examples = [ex for ex in self.examples if ex[2]["encoding_size"][1] in num_cols]

         self.indices_to_log = np.random.choice(
-            len(self.examples), int(log_images_rate * len(self.examples)), replace=False  # type: ignore
+            [example[1] for example in self.examples],
+            int(log_images_rate * len(self.examples)),  # type: ignore
+            replace=False,
         )

     def _retrieve_metadata(self, fname: Union[str, Path]) -> Tuple[Dict, int]:
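The draw above now samples from the slice indices stored in the examples list rather than from positions `0..len(self.examples)-1`, which is what the `dataslice in self.indices_to_log` membership test in the SKM-TEA RS loader below expects. A minimal sketch of the difference, assuming (as the tuple layout suggests) that `example[1]` holds the slice index; file names and numbers are hypothetical:

```python
import numpy as np

# Hypothetical example list mimicking the loader's (fname, dataslice, metadata) tuples.
examples = [("vol1.h5", s, {}) for s in range(10, 20)]
log_images_rate = 0.3

# Before: positions into the list were drawn, but membership is later tested
# against slice indices (10..19 here), so logging could silently never trigger.
positions = np.random.choice(len(examples), int(log_images_rate * len(examples)), replace=False)
print(any(s in positions for _, s, _ in examples))  # False: positions 0..9 never match slices 10..19

# After: the slice indices themselves are drawn, so `dataslice in indices_to_log` can fire.
indices_to_log = np.random.choice(
    [example[1] for example in examples],
    int(log_images_rate * len(examples)),
    replace=False,
)
print(any(s in indices_to_log for _, s, _ in examples))  # True
```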
48 changes: 29 additions & 19 deletions atommic/collections/multitask/rs/data/mrirs_loader.py

@@ -416,19 +416,22 @@
             kspace = kspace[:, :, 0, :] + kspace[:, :, 1, :]
         elif not is_none(dataset_format) and dataset_format == "skm-tea-echo1+echo2-mc":
             kspace = np.concatenate([kspace[:, :, 0, :], kspace[:, :, 1, :]], axis=-1)
+        elif not is_none(dataset_format) and dataset_format == "skm-tea-echo1-echo2":
+            kspace = kspace
[Check failure: Code scanning / CodeQL, Redundant assignment (Error). This assignment assigns a variable to itself.]
         else:
             warnings.warn(
                 f"Dataset format {dataset_format} is either not supported or set to None. "
                 "Using by default only the first echo."
             )
             kspace = kspace[:, :, 0, :]

-        kspace = kspace[48:-48, 40:-40]
-
         sensitivity_map = self.get_consecutive_slices(hf, "maps", dataslice).astype(np.complex64)
         sensitivity_map = sensitivity_map[..., 0]

-        sensitivity_map = sensitivity_map[48:-48, 40:-40]
+        if self.consecutive_slices > 1:
+            sensitivity_map = sensitivity_map[:, 48:-48, 40:-40]
+            kspace = kspace[:, 48:-48, 40:-40]
+        else:
+            sensitivity_map = sensitivity_map[48:-48, 40:-40]
+            kspace = kspace[48:-48, 40:-40]

         if masking == "custom":
             mask = np.array([])
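`get_consecutive_slices` prepends a slice axis when `consecutive_slices > 1`, so the hard-coded SKM-TEA center crop has to skip that axis, as the branch above does. A minimal sketch of the shape bookkeeping; the 512×512 grid and coil count are illustrative only:

```python
import numpy as np

consecutive_slices = 3
# (slices, x, y, coils) when consecutive slices are stacked; (x, y, coils) otherwise.
kspace = np.zeros((consecutive_slices, 512, 512, 8), dtype=np.complex64)

if consecutive_slices > 1:
    cropped = kspace[:, 48:-48, 40:-40]  # crop x and y, keep the leading slice axis
else:
    cropped = kspace[48:-48, 40:-40]     # crop x and y directly

print(cropped.shape)  # (3, 416, 432, 8)
```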
@@ -470,22 +473,17 @@
             # combine Lateral Meniscus and Medial Meniscus
             medial_meniscus = lateral_meniscus + medial_meniscus

-            if self.consecutive_slices > 1:
-                segmentation_labels_dim = 1
-            else:
-                segmentation_labels_dim = 0
-
             # stack the labels in the last dimension
             segmentation_labels = np.stack(
                 [patellar_cartilage, femoral_cartilage, tibial_cartilage, medial_meniscus],
-                axis=segmentation_labels_dim,
+                axis=-1,
             )

             # TODO: This is hardcoded on the SKM-TEA side, how to generalize this?
             # We need to crop the segmentation labels in the frequency domain to reduce the FOV.
-            segmentation_labels = np.fft.fftshift(np.fft.fft2(segmentation_labels))
-            segmentation_labels = segmentation_labels[:, 48:-48, 40:-40]
-            segmentation_labels = np.fft.ifft2(np.fft.ifftshift(segmentation_labels)).real
+            segmentation_labels = np.fft.fftshift(np.fft.fft2(segmentation_labels, axes=(-3, -2)))
+            segmentation_labels = segmentation_labels[..., 48:-48, 40:-40, :]
+            segmentation_labels = np.fft.ifft2(np.fft.ifftshift(segmentation_labels), axes=(-3, -2)).real

             imspace = np.empty([])
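Cropping between a forward and inverse 2D FFT resamples the labels onto the reduced FOV rather than cutting image pixels, and passing `axes=(-3, -2)` keeps the transform on the spatial axes whether or not a leading slice axis is present. A minimal standalone sketch; shapes are illustrative, and the shifts are given explicit axes here so the slice and class axes are provably untouched:

```python
import numpy as np

# (slices, x, y, classes): the leading slice axis and trailing class axis
# must pass through the 2D FFT and the crop unchanged.
labels = np.random.rand(3, 512, 512, 4)

spectrum = np.fft.fftshift(np.fft.fft2(labels, axes=(-3, -2)), axes=(-3, -2))
spectrum = spectrum[..., 48:-48, 40:-40, :]  # crop only x and y
resampled = np.fft.ifft2(np.fft.ifftshift(spectrum, axes=(-3, -2)), axes=(-3, -2)).real

print(resampled.shape)  # (3, 416, 432, 4)
```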

@@ -499,12 +497,24 @@
             metadata["noise"] = 1.0

         attrs.update(metadata)
-
-        kspace = np.transpose(kspace, (2, 0, 1))
-        sensitivity_map = np.transpose(sensitivity_map.squeeze(), (2, 0, 1))
-
+        if not is_none(dataset_format) and dataset_format == "skm-tea-echo1-echo2":
+            if self.consecutive_slices > 1:
+                segmentation_labels = np.transpose(segmentation_labels, (0, 3, 1, 2))
+                kspace = np.transpose(kspace, (3, 0, 4, 1, 2))
+                sensitivity_map = np.transpose(sensitivity_map, (4, 0, 3, 1, 2))
+            else:
+                segmentation_labels = np.transpose(segmentation_labels, (2, 0, 1))
+                kspace = np.transpose(kspace, (2, 3, 0, 1))
+                sensitivity_map = np.transpose(sensitivity_map, (3, 2, 0, 1))
+        elif self.consecutive_slices > 1 and not is_none(dataset_format) and dataset_format != "skm-tea-echo1-echo2":
+            segmentation_labels = np.transpose(segmentation_labels, (0, 3, 1, 2))
+            kspace = np.transpose(kspace, (0, 3, 1, 2))
+            sensitivity_map = np.transpose(sensitivity_map.squeeze(), (0, 3, 1, 2))
+        else:
+            segmentation_labels = np.transpose(segmentation_labels, (2, 0, 1))
+            kspace = np.transpose(kspace, (2, 0, 1))
+            sensitivity_map = np.transpose(sensitivity_map.squeeze(), (2, 0, 1))
         attrs["log_image"] = bool(dataslice in self.indices_to_log)

         return (
             (
                 kspace,
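The transposes normalize every case to a channels-first layout before the tensors are returned. A minimal sketch of the two new `skm-tea-echo1-echo2` cases, assuming the per-slice on-disk layout is (x, y, echoes, coils), as the earlier `kspace[:, :, 0, :]` indexing suggests; sizes are illustrative:

```python
import numpy as np

# Single slice: (x, y, echoes, coils) -> (echoes, coils, x, y)
kspace = np.zeros((416, 432, 2, 8), dtype=np.complex64)
print(np.transpose(kspace, (2, 3, 0, 1)).shape)  # (2, 8, 416, 432)

# Consecutive slices: (slices, x, y, echoes, coils) -> (echoes, slices, coils, x, y)
kspace = np.zeros((3, 416, 432, 2, 8), dtype=np.complex64)
print(np.transpose(kspace, (3, 0, 4, 1, 2)).shape)  # (2, 3, 8, 416, 432)
```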
142 changes: 100 additions & 42 deletions atommic/collections/multitask/rs/nn/base.py

@@ -67,7 +67,7 @@
         # Initialize the dimensionality of the data. It can be 2D or 2.5D -> meaning 2D with > 1 slices or 3D.
         self.dimensionality = cfg_dict.get("dimensionality", 2)
         self.consecutive_slices = cfg_dict.get("consecutive_slices", 1)
-
+        self.num_echoes = cfg_dict.get("num_echoes", 1)
         # Initialize the coil combination method. It can be either "SENSE" or "RSS" (root-sum-of-squares) or
         # "RSS-complex" (root-sum-of-squares of the complex-valued data).
         self.coil_combination_method = cfg_dict.get("coil_combination_method", "SENSE")
@@ -601,9 +601,6 @@
         If self.accumulate_loss is True, returns an accumulative result of all intermediate losses.
         Otherwise, returns the loss of the last intermediate loss.
         """
-        if self.consecutive_slices > 1:
-            batch_size, slices = target_segmentation.shape[:2]
-            target_segmentation = target_segmentation.reshape(batch_size * slices, *target_segmentation.shape[2:])

         segmentation_loss = self.process_segmentation_loss(target_segmentation, predictions_segmentation, attrs)
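The deleted reshape is redundant here because folding consecutive slices into the batch dimension now happens once in the training/test steps (see the hunks further down), and a mean-reduced per-pixel loss is invariant to that folding anyway. A minimal sketch of the invariance with an MSE stand-in; shapes are illustrative:

```python
import torch
import torch.nn.functional as F

batch_size, slices, classes, n_x, n_y = 2, 3, 4, 8, 8
target = torch.rand(batch_size, slices, classes, n_x, n_y)
pred = torch.rand(batch_size, slices, classes, n_x, n_y)

# Folding slices into the batch dimension does not change a mean-reduced loss.
loss_folded = F.mse_loss(
    pred.reshape(batch_size * slices, classes, n_x, n_y),
    target.reshape(batch_size * slices, classes, n_x, n_y),
)
loss_unfolded = F.mse_loss(pred, target)
print(torch.allclose(loss_folded, loss_unfolded))  # True
```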

@@ -675,27 +672,31 @@
         if isinstance(predictions_segmentation, list):
             while isinstance(predictions_segmentation, list):
                 predictions_segmentation = predictions_segmentation[-1]

         if self.consecutive_slices > 1:
             # reshape the target and prediction to [batch_size, self.consecutive_slices, nr_classes, n_x, n_y]
-            batch_size = target_segmentation.shape[0] // self.consecutive_slices
-            target_segmentation = target_segmentation.reshape(
-                batch_size, self.consecutive_slices, *target_segmentation.shape[1:]
-            )
-            target_reconstruction = target_reconstruction.reshape(
-                batch_size, self.consecutive_slices, *target_reconstruction.shape[2:]
-            )
+            batch_size = int(target_segmentation.shape[0] / self.consecutive_slices)
             predictions_segmentation = predictions_segmentation.reshape(
-                batch_size, self.consecutive_slices, *predictions_segmentation.shape[2:]
+                batch_size, self.consecutive_slices, *predictions_segmentation.shape[1:]
             )
-            predictions_reconstruction = predictions_reconstruction.reshape(
-                batch_size, self.consecutive_slices, *predictions_reconstruction.shape[1:]
+            target_segmentation = target_segmentation.reshape(
+                batch_size, self.consecutive_slices, *target_segmentation.shape[1:]
             )

             target_segmentation = target_segmentation[:, self.consecutive_slices // 2]
             target_reconstruction = target_reconstruction[:, self.consecutive_slices // 2]
             predictions_segmentation = predictions_segmentation[:, self.consecutive_slices // 2]
             predictions_reconstruction = predictions_reconstruction[:, self.consecutive_slices // 2]

+        if self.num_echoes > 1:
+            # find the batch size
+            batch_size = target_reconstruction.shape[0] / self.num_echoes
+            # reshape to [batch_size, num_echoes, n_x, n_y]
+            target_reconstruction = target_reconstruction.reshape(
+                (int(batch_size), self.num_echoes, *target_reconstruction.shape[1:])
+            )
+            predictions_reconstruction = predictions_reconstruction.reshape(
+                (int(batch_size), self.num_echoes, *predictions_reconstruction.shape[1:])
+            )
         fname = attrs["fname"]
         slice_idx = attrs["slice_idx"]
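For validation, only the middle slice of each consecutive-slice stack is kept for metrics and logging. A minimal sketch of the selection; shapes are illustrative:

```python
import torch

batch_size, consecutive_slices, classes, n_x, n_y = 2, 3, 4, 8, 8
# Segmentation predictions arrive flattened as [batch_size * consecutive_slices, classes, n_x, n_y].
pred = torch.rand(batch_size * consecutive_slices, classes, n_x, n_y)

pred = pred.reshape(batch_size, consecutive_slices, classes, n_x, n_y)
center = pred[:, consecutive_slices // 2]  # keep the middle slice of each stack
print(center.shape)  # torch.Size([2, 4, 8, 8])
```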

@@ -734,11 +735,6 @@
                     batch_idx=_batch_idx_,
                 )

-            output_predictions_reconstruction = output_predictions_reconstruction.detach().cpu()
-            output_target_reconstruction = output_target_reconstruction.detach().cpu()
-            output_target_segmentation = output_target_segmentation.detach().cpu()
-            output_predictions_segmentation = output_predictions_segmentation.detach().cpu()
-
             # Normalize target and predictions to [0, 1] for logging.
             if torch.is_complex(output_target_reconstruction) and output_target_reconstruction.shape[-1] != 2:
                 output_target_reconstruction = torch.view_as_real(output_target_reconstruction)
@@ -747,7 +743,6 @@
                 output_target_reconstruction = output_target_reconstruction / torch.max(
                     torch.abs(output_target_reconstruction)
                 )
-                output_target_reconstruction = output_target_reconstruction.detach().cpu()

             if (
                 torch.is_complex(output_predictions_reconstruction)
@@ -759,7 +754,11 @@
                 output_predictions_reconstruction = output_predictions_reconstruction / torch.max(
                     torch.abs(output_predictions_reconstruction)
                 )
-                output_predictions_reconstruction = output_predictions_reconstruction.detach().cpu()

+            output_predictions_reconstruction = output_predictions_reconstruction.detach().cpu().float()
+            output_target_reconstruction = output_target_reconstruction.detach().cpu().float()
+            output_target_segmentation = output_target_segmentation.detach().cpu().float()
+            output_predictions_segmentation = output_predictions_segmentation.detach().cpu().float()

             # Log target and predictions, if log_image is True for this slice.
             if attrs["log_image"][_batch_idx_]:
@@ -772,17 +771,33 @@
                     )

                 if self.use_reconstruction_module:
-                    self.log_image(
-                        f"{key}/a/reconstruction/target/predictions/error",
-                        torch.cat(
-                            [
-                                output_target_reconstruction,
-                                output_predictions_reconstruction,
-                                torch.abs(output_target_reconstruction - output_predictions_reconstruction),
-                            ],
-                            dim=-1,
-                        ),
-                    )
+                    if self.num_echoes > 1:
+                        for i in range(output_target_reconstruction.shape[0]):
+                            self.log_image(
+                                f"{key}/a/reconstruction_abs/target echo: {i+1}/predictions echo: {i+1}/error echo: {i+1}",
+                                torch.cat(
+                                    [
+                                        output_target_reconstruction[i],
+                                        output_predictions_reconstruction[i],
+                                        torch.abs(
+                                            output_target_reconstruction[i] - output_predictions_reconstruction[i]
+                                        ),
+                                    ],
+                                    dim=-1,
+                                ),
+                            )
+                    else:
+                        self.log_image(
+                            f"{key}/a/reconstruction_abs/target/predictions/error",
+                            torch.cat(
+                                [
+                                    output_target_reconstruction,
+                                    output_predictions_reconstruction,
+                                    torch.abs(output_target_reconstruction - output_predictions_reconstruction),
+                                ],
+                                dim=-1,
+                            ),
+                        )

                 # concatenate the segmentation classes for logging
                 target_segmentation_class = torch.cat(
@@ -1120,7 +1135,16 @@
             self.coil_combination_method,
             self.coil_dim,
         )
-
+        if self.num_echoes > 1:
+            # stack the echoes along the batch dimension
+            kspace = kspace.view(-1, *kspace.shape[2:])
+            y = y.view(-1, *y.shape[2:])
+            mask = mask.view(-1, *mask.shape[2:])
+            initial_prediction_reconstruction = initial_prediction_reconstruction.view(
+                -1, *initial_prediction_reconstruction.shape[2:]
+            )
+            target_reconstruction = target_reconstruction.view(-1, *target_reconstruction.shape[2:])
+            sensitivity_maps = torch.repeat_interleave(sensitivity_maps, repeats=kspace.shape[0], dim=0).squeeze(1)
         # Model forward pass
         predictions_reconstruction, predictions_segmentation = self.forward(
             y,
@@ -1130,6 +1154,19 @@
             target_reconstruction,
             attrs["noise"],
         )
+        if self.consecutive_slices > 1:
+            ## reshape the target and prediction segmentation to [batch_size * consecutive_slices, nr_classes, n_x, n_y]
+            batch_size, slices = target_segmentation.shape[:2]
+            target_segmentation = target_segmentation.reshape(batch_size * slices, *target_segmentation.shape[2:])
+            if isinstance(predictions_segmentation, list):
+                for i, prediction_segmentation in enumerate(predictions_segmentation):
+                    predictions_segmentation[i] = prediction_segmentation.reshape(
+                        batch_size * slices, *prediction_segmentation.shape[2:]
+                    )
+            else:
+                predictions_segmentation = predictions_segmentation.reshape(
+                    batch_size * slices, *predictions_segmentation.shape[2:]
+                )

         if not is_none(self.segmentation_classes_thresholds):
             for class_idx, thres in enumerate(self.segmentation_classes_thresholds):
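With `num_echoes > 1`, the echo dimension is folded into the batch dimension before the forward pass and the (echo-independent) coil sensitivities are tiled to match; the `.squeeze(1)` in the hunk above additionally drops a singleton axis specific to this pipeline. A minimal sketch, assuming inputs shaped [batch, echoes, coils, n_x, n_y, 2] as the `view(-1, *shape[2:])` calls imply; sizes are illustrative:

```python
import torch

batch, echoes, coils, n_x, n_y = 1, 2, 8, 416, 432
kspace = torch.rand(batch, echoes, coils, n_x, n_y, 2)
sensitivity_maps = torch.rand(batch, coils, n_x, n_y, 2)

# Fold echoes into the batch dimension: [batch * echoes, coils, n_x, n_y, 2].
kspace = kspace.view(-1, *kspace.shape[2:])

# Tile the sensitivities so every echo in the flattened batch gets a copy.
sensitivity_maps = torch.repeat_interleave(sensitivity_maps, repeats=kspace.shape[0], dim=0)

print(kspace.shape)            # torch.Size([2, 8, 416, 432, 2])
print(sensitivity_maps.shape)  # torch.Size([2, 8, 416, 432, 2])
```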
@@ -1482,6 +1519,26 @@
         while isinstance(predictions_reconstruction, list):
             predictions_reconstruction = predictions_reconstruction[-1]

+        if self.consecutive_slices > 1:
+            # reshape the target and prediction to [batch_size, self.consecutive_slices, nr_classes, n_x, n_y]
+            batch_size = int(target_segmentation.shape[0] / self.consecutive_slices)
+            predictions_segmentation = predictions_segmentation.reshape(
+                batch_size, self.consecutive_slices, *predictions_segmentation.shape[1:]
+            )
+            predictions_segmentation = predictions_segmentation[:, self.consecutive_slices // 2]
+            predictions_reconstruction = predictions_reconstruction[:, self.consecutive_slices // 2]
+
+        if self.num_echoes > 1:
+            # find the batch size
+            batch_size = target_reconstruction.shape[0] / self.num_echoes
+            # reshape to [batch_size, num_echoes, n_x, n_y]
+            target_reconstruction = target_reconstruction.reshape(
+                (int(batch_size), self.num_echoes, *target_reconstruction.shape[1:])
+            )
[Check notice: Code scanning / CodeQL, Unused local variable (Note). Variable target_reconstruction is not used.]
+            predictions_reconstruction = predictions_reconstruction.reshape(
+                (int(batch_size), self.num_echoes, *predictions_reconstruction.shape[1:])
+            )

         # If "16" or "16-mixed" fp is used, ensure complex type will be supported when saving the predictions.
         predictions_reconstruction = (
             torch.view_as_complex(torch.view_as_real(predictions_reconstruction).type(torch.float32))
@@ -1670,10 +1727,10 @@
             for fname in segmentations:
                 segmentations[fname] = np.stack([out for _, out in sorted(segmentations[fname])])

-            if self.consecutive_slices > 1:
-                # iterate over the slices and always keep the middle slice
-                for fname in segmentations:
-                    segmentations[fname] = segmentations[fname][:, self.consecutive_slices // 2]
+            # if self.consecutive_slices > 1:
+            #     # iterate over the slices and always keep the middle slice
+            #     for fname in segmentations:
+            #         segmentations[fname] = segmentations[fname][:, self.consecutive_slices // 2]  # TODO remove, is already done in the test_step to minimize memory load
[Comment on lines +1730 to +1732. Check notice: Code scanning / CodeQL, Commented-out code (Note). This comment appears to contain commented-out code.]

         if self.use_reconstruction_module:
             reconstructions = defaultdict(list)

             for fname in reconstructions:
                 reconstructions[fname] = np.stack([out for _, out in sorted(reconstructions[fname])])

-            if self.consecutive_slices > 1:
-                # iterate over the slices and always keep the middle slice
-                for fname in reconstructions:
-                    reconstructions[fname] = reconstructions[fname][:, self.consecutive_slices // 2]
+            # if self.consecutive_slices > 1:  # TODO remove, is already done in the test_step to minimize memory load
+            #     # iterate over the slices and always keep the middle slice
+            #     for fname in reconstructions:
+            #         reconstructions[fname] = reconstructions[fname][:, self.consecutive_slices // 2]
[Comment on lines +1746 to +1747. Check notice: Code scanning / CodeQL, Commented-out code (Note). This comment appears to contain commented-out code.]
         else:
             reconstructions = None

@@ -1752,6 +1809,7 @@
             "skm-tea-echo2",
             "skm-tea-echo1+echo2",
             "skm-tea-echo1+echo2-mc",
+            "skm-tea-echo1-echo2",
         ):
             dataloader = mrirs_loader.SKMTEARSMRIDataset
         else:
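The new `skm-tea-echo1-echo2` format keeps both echoes as an explicit dimension, so it is meant to be paired with the `num_echoes: 2` option introduced above. A hypothetical config fragment; the key names come from the `cfg_dict.get(...)` calls in this diff, but where these keys live in a full atommic YAML is not shown here:

```python
# Hypothetical cfg_dict fragment for a multitask RS model on SKM-TEA,
# keeping both echoes and using 2.5D consecutive slices.
cfg_dict = {
    "dataset_format": "skm-tea-echo1-echo2",  # new format enabled by this PR
    "num_echoes": 2,                          # read via cfg_dict.get("num_echoes", 1)
    "consecutive_slices": 3,                  # read via cfg_dict.get("consecutive_slices", 1)
    "dimensionality": 2,
    "coil_combination_method": "SENSE",
}
```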