Post-training optimization using OpenVINO added by junwenwu · Pull Request #312 · mlcommons/GaNDLF · GitHub

Post-training optimization using OpenVINO added #312


Merged: 167 commits merged into master from junwenwu:gandlf_ov_integration on Mar 21, 2022

Changes from all commits (167 commits)
16a589e
OpenVINO example for Gandlf trained classification models
psakamoori Nov 4, 2021
722687d
Update README.md
Nov 4, 2021
3b1c25d
Update README.md
Nov 4, 2021
9c6c81d
Update README.md
Nov 4, 2021
0721949
Update README.md
Nov 4, 2021
639ab2f
Update README.md
Nov 4, 2021
a621848
Update README.md
Nov 4, 2021
a63a189
Update README.md
Nov 4, 2021
acd44e3
Update README.md
Nov 4, 2021
b5b4741
Update the README file
psakamoori Nov 4, 2021
bf67644
Update the README file
psakamoori Nov 4, 2021
5e9b107
Update the README file
psakamoori Nov 4, 2021
5d0bfc5
Update the README file
psakamoori Nov 4, 2021
4b47a1f
Update the ov conversion script
Nov 4, 2021
13c3931
Update README.md
Nov 4, 2021
3b8ac9c
Update the ov conversion script
Nov 4, 2021
b3268fc
Update the ov conversion script
Nov 5, 2021
4d81369
Update the README
Nov 5, 2021
9f219d4
Update the README
Nov 5, 2021
2bfc199
Update the README
Nov 5, 2021
5a06479
Update the README
Nov 5, 2021
c42654b
Update the README
Nov 8, 2021
cbf8599
Update the onnx conversion scripts
Nov 8, 2021
4fa3a2b
Update the nncf script
Nov 9, 2021
4c267bf
Update the onnx conversion scripts
Nov 9, 2021
ce588e9
Code refactoring
Nov 11, 2021
963f568
Minor modifications after functionality testing for 3D Resunet
karkadad Jan 6, 2022
197416a
Merge pull request #1 from junwenwu/3dresunet_tests
Jan 19, 2022
df8ec39
Add files via upload
karkadad Feb 14, 2022
5d46112
Update perf_infer_icx.sh
karkadad Feb 14, 2022
65157ac
Update perf_infer_icx.sh
karkadad Feb 14, 2022
a09cebd
OpenVINO Integration with GaNDLF
Feb 15, 2022
df2e1c6
Merge remote-tracking branch 'upstream/master'
psakamoori Feb 15, 2022
711d9f4
Sync up with upstream for PR
Feb 15, 2022
2a4c896
Reformatting code
Feb 15, 2022
ebfb364
Reformatting code
Feb 15, 2022
eac39f4
Reformatting code
Feb 15, 2022
c9273fe
Update setup.py
Feb 15, 2022
098e239
Reformatting code
Feb 15, 2022
ade199c
Reformatting code
Feb 15, 2022
c326aed
Reformatting code
Feb 15, 2022
f2b16b1
Fix inference_loop
Feb 15, 2022
9c60568
Fix issues from CODACY
Feb 15, 2022
37fbad7
Fix issues from CODACY
Feb 15, 2022
3823812
Fix issues from CODACY
Feb 15, 2022
5f6ecd8
Remove shell from subprocess call.
Feb 15, 2022
d8a5e27
Reformatting code.
Feb 15, 2022
bef0495
Fix dependency on OpenVINO env variables
psakamoori Feb 17, 2022
135abd0
Reformatting
Feb 17, 2022
fce7648
Reformatting
Feb 17, 2022
b959153
Reformatting
Feb 17, 2022
65a1de1
Fix dependency and packages
Feb 17, 2022
8d66986
Merge branch 'master' into gandlf_ov_integration
sarthakpati Feb 17, 2022
6be5791
Reformatting code.
Feb 17, 2022
bae8729
Update ov model exclusion
Feb 18, 2022
20a63b1
Fixed logic
Feb 19, 2022
84569c3
Merge branch 'gandlf_ov_integration' of https://github.com/junwenwu/G…
psakamoori Feb 19, 2022
e98187f
Resolve conflict
Feb 19, 2022
57b7de3
Resolve conflict
Feb 19, 2022
163804f
Merge branch 'master' into gandlf_ov_integration
Feb 19, 2022
3980c39
Reformatting
Feb 19, 2022
53e1f21
Merge branch 'gandlf_ov_integration' of https://github.com/junwenwu/G…
psakamoori Feb 19, 2022
dbb117d
Resolve dependency
Feb 19, 2022
f7cb345
Resolve dependency
Feb 19, 2022
7a3f8e5
Merge branch 'master' into gandlf_ov_integration
sarthakpati Feb 21, 2022
f82d264
Resolve dependency
Feb 22, 2022
4852e71
Merge branch 'gandlf_ov_integration' of https://github.com/junwenwu/G…
psakamoori Feb 22, 2022
5b88de5
Catch OpenVINO import error.
Feb 22, 2022
d543ada
Fix a typo.
Feb 22, 2022
14b72a6
Fix a typo.
Feb 22, 2022
97e8bb6
Remove import StatusCode.
Feb 22, 2022
909e881
Resolve dependency
Feb 22, 2022
ecdf26f
Resolve dependency
Feb 22, 2022
2a274db
Resolve dependency
Feb 22, 2022
d6ad82b
Resolve dependency
Feb 22, 2022
f394ccb
Resolve dependency
Feb 22, 2022
86f1e76
Resolve dependency
Feb 22, 2022
a4957a5
Update ov model exclusion
Feb 22, 2022
50817bf
Update ov model exclusion
Feb 22, 2022
d78fa83
Update ov model exclusion
Feb 22, 2022
8d20916
Reformatting
Feb 22, 2022
a8d5313
Update setup doc
Feb 22, 2022
e3c6a2c
Update ov model exclusion
Feb 22, 2022
a4611c2
Reformatting
Feb 22, 2022
b892b36
Update ov model exclusion
Feb 22, 2022
e4eda6d
Reformatting
Feb 22, 2022
22df1cd
Reformatting
Feb 22, 2022
1bc63fc
Merge branch 'master' into gandlf_ov_integration
sarthakpati Mar 1, 2022
0671ae2
Merge branch 'gandlf_ov_integration' of https://github.com/junwenwu/G…
sarthakpati Mar 3, 2022
e0a7165
Merge branch 'master' into gandlf_ov_integration
sarthakpati Mar 4, 2022
dce44c3
Update the openvino model saving
Mar 9, 2022
abb00b7
Reformat
Mar 9, 2022
00a5ced
Reformat
Mar 9, 2022
78c240d
Reformat
Mar 9, 2022
8976149
Merge branch 'master' of https://github.com/CBICA/GaNDLF into openvin…
sarthakpati Mar 9, 2022
6debd1f
Merge branch 'gandlf_ov_integration' of https://github.com/junwenwu/G…
sarthakpati Mar 9, 2022
1b55902
Update GANDLF/utils/modelio.py
Mar 9, 2022
5048df3
Resolving conflicts and fix testing
psakamoori Mar 10, 2022
dfced5e
Resolve conflicts
Mar 10, 2022
f0d6603
Resolve conflicts
Mar 10, 2022
02fbb5a
Resolve conflicts
Mar 10, 2022
183fe4c
Resolve reformatting
Mar 10, 2022
6097f60
Resolve reformatting
Mar 10, 2022
b4added
Resolve reformatting
Mar 10, 2022
c759529
Merge remote-tracking branch 'upstream/master' into gandlf_ov_integra…
psakamoori Mar 10, 2022
fff58dc
Resolve conflicts
Mar 10, 2022
43643ac
Resolve conflicts
Mar 10, 2022
42c6178
Reformat
Mar 10, 2022
51db388
Reformat
Mar 10, 2022
e7d1cc7
Resolve conflicts
Mar 10, 2022
5cbf623
Merge branch 'gandlf_ov_integration' of https://github.com/junwenwu/G…
sarthakpati Mar 10, 2022
c4d78e0
hash update
sarthakpati Mar 10, 2022
07bb8cc
added a comment
sarthakpati Mar 10, 2022
535f633
added onnx
sarthakpati Mar 10, 2022
5c3d564
removed from print commands
sarthakpati Mar 10, 2022
df0e6ff
moving model to cpu
sarthakpati Mar 10, 2022
c7706a0
t push origin gandlf_ov_integrationMerge branch 'sarthakpati-openvino…
psakamoori Mar 10, 2022
481181b
Merge branch 'master' into gandlf_ov_integration
sarthakpati Mar 10, 2022
9f180fc
bug fix
sarthakpati Mar 10, 2022
17f0566
Fixing average_pool
Mar 10, 2022
3fc1821
Update model save and test.
Mar 11, 2022
6b59629
Resolve reformatting
Mar 11, 2022
0c5f809
Resolve reformatting
Mar 11, 2022
e28a4c0
Resolve reformatting
Mar 11, 2022
a921357
Merge branch 'gandlf_ov_integration' of https://github.com/junwenwu/G…
sarthakpati Mar 11, 2022
d169508
removed print
sarthakpati Mar 11, 2022
704555e
moving parameter parsing to parseConfig
sarthakpati Mar 11, 2022
28676fa
moving the final optimization after epoch completion, since the previ…
sarthakpati Mar 11, 2022
8dfe73d
API updated to use params dict so that future extensions are easier w…
sarthakpati Mar 11, 2022
512ed22
using new logic, and changed optimization call to end of epoch
sarthakpati Mar 11, 2022
abc294b
revert to resunet
sarthakpati Mar 11, 2022
c8c013f
added for 2 tests
sarthakpati Mar 11, 2022
6d5c809
lower patch size for histo test
sarthakpati Mar 11, 2022
547c164
added some test logic
sarthakpati Mar 11, 2022
7d716f9
black .
sarthakpati Mar 11, 2022
a50c772
removed warnings usage in favor of simple print for consistency
sarthakpati Mar 11, 2022
ca667cb
Merge branch 'master' of https://github.com/CBICA/GaNDLF into openvin…
sarthakpati Mar 11, 2022
7866d60
Merge pull request #3 from sarthakpati/openvino_integration
Mar 11, 2022
91d0da5
disable openvino for 3d classification
sarthakpati Mar 11, 2022
e764655
tests are functioning
sarthakpati Mar 11, 2022
6ee8df9
Merge pull request #4 from sarthakpati/openvino_integration
Mar 11, 2022
5f0d618
Merge branch 'master' into gandlf_ov_integration
sarthakpati Mar 11, 2022
bd83f2c
Merge branch 'master' of https://github.com/CBICA/GaNDLF into openvin…
sarthakpati Mar 14, 2022
227554f
Update GANDLF/compute/forward_pass.py
Mar 14, 2022
d896af8
Merge branch 'gandlf_ov_integration' of https://github.com/junwenwu/G…
sarthakpati Mar 14, 2022
7c2ac30
black .
sarthakpati Mar 14, 2022
b2ee049
Reformatting
Mar 14, 2022
68af85c
Merge pull request #5 from sarthakpati/openvino_integration
Mar 14, 2022
37f2d82
fixed test
sarthakpati Mar 14, 2022
0826b33
Merge pull request #6 from sarthakpati/openvino_integration
Mar 14, 2022
8021fa1
this is needed to ensure that output doesn't stay in cpu in case cuda…
sarthakpati Mar 14, 2022
31f4933
do model optimization if onnx export is true
sarthakpati Mar 14, 2022
42a2083
updated logic for optimizer
sarthakpati Mar 14, 2022
05c5d7b
added some print statements and specific optimization tests for 3d cl…
sarthakpati Mar 14, 2022
c0b39a0
syntax fix
sarthakpati Mar 14, 2022
36eca2e
black .
sarthakpati Mar 14, 2022
627e71f
Merge branch 'master' into gandlf_ov_integration
Geeks-Sid Mar 15, 2022
cf39321
Merge branch 'master' of https://github.com/CBICA/GaNDLF into openvin…
sarthakpati Mar 15, 2022
a52a0fd
message updated
sarthakpati Mar 15, 2022
ea9f440
Merge pull request #7 from sarthakpati/openvino_integration
Mar 16, 2022
13c417d
Merge branch 'master' into gandlf_ov_integration
sarthakpati Mar 18, 2022
e6fbf5e
OV installation added in docker images
sarthakpati Mar 18, 2022
eac887d
Merge branch 'master' of https://github.com/sarthakpati/GaNDLF into o…
sarthakpati Mar 18, 2022
edce4ed
Merge branch 'master' of https://github.com/CBICA/GaNDLF into openvin…
sarthakpati Mar 18, 2022
169fe17
Merge branch 'master' of https://github.com/CBICA/GaNDLF into openvin…
sarthakpati Mar 19, 2022
37389b3
Merge branch 'master' into gandlf_ov_integration
Geeks-Sid Mar 20, 2022
737835d
Merge pull request #8 from sarthakpati/openvino_integration
Mar 21, 2022
1 change: 1 addition & 0 deletions .github/workflows/python-test.yml
@@ -23,6 +23,7 @@ jobs:
- name: Install dependencies and package
run: |
python -m pip install --upgrade pip
python -m pip install openvino-dev
$CONDA/bin/conda install -c conda-forge libvips -y
pip3 install torch==1.8.2+cpu torchvision==0.9.2+cpu torchaudio==0.8.2 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
pip install -e .
1 change: 1 addition & 0 deletions Dockerfile-CPU
@@ -9,6 +9,7 @@ RUN python3.7 -m pip install --upgrade pip
RUN python3.7 -m pip install torch==1.10.0+cpu torchvision==0.11.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
COPY . /GaNDLF
WORKDIR /GaNDLF
RUN python3.7 -m pip install --upgrade pip && python3.7 -m pip install openvino-dev
RUN python3.7 -m pip install -e .
# Entrypoint forces all commands given via "docker run" to go through python, CMD forces the default entrypoint script argument to be gandlf_run
# If a user calls "docker run gandlf:[tag] gandlf_anonymize", it will resolve to running "python gandlf_anonymize" instead.
1 change: 1 addition & 0 deletions Dockerfile-CUDA10.2
@@ -9,6 +9,7 @@ LABEL version=1.0
RUN python3 -m pip install --upgrade pip
COPY . /GaNDLF
WORKDIR /GaNDLF
RUN python3 -m pip install --upgrade pip && python3 -m pip install openvino-dev
RUN python3 -m pip install -e .
# Entrypoint forces all commands given via "docker run" to go through python, CMD forces the default entrypoint script argument to be gandlf_run
# If a user calls "docker run gandlf:[tag] gandlf_anonymize", it will resolve to running "python gandlf_anonymize" instead.
1 change: 1 addition & 0 deletions Dockerfile-CUDA11.3
@@ -12,6 +12,7 @@ RUN python3.7 -m pip install --upgrade pip
RUN python3.7 -m pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
COPY . /GaNDLF
WORKDIR /GaNDLF
RUN python3.7 -m pip install --upgrade pip && python3.7 -m pip install openvino-dev
RUN python3.7 -m pip install -e .
# Entrypoint forces all commands given via "docker run" to go through python, CMD forces the default entrypoint script argument to be gandlf_run
# If a user calls "docker run gandlf:[tag] gandlf_anonymize", it will resolve to running "python gandlf_anonymize" instead.
1 change: 1 addition & 0 deletions Dockerfile-ROCm
@@ -9,6 +9,7 @@ LABEL version=1.0
RUN python3 -m pip install --upgrade pip
COPY . /GaNDLF
WORKDIR /GaNDLF
RUN python3 -m pip install --upgrade pip && python3 -m pip install openvino-dev
RUN python3 -m pip install -e .
# Entrypoint forces all commands given via "docker run" to go through python, CMD forces the default entrypoint script argument to be gandlf_run
# If a user calls "docker run gandlf:[tag] gandlf_anonymize", it will resolve to running "python gandlf_anonymize" instead.
32 changes: 24 additions & 8 deletions GANDLF/compute/forward_pass.py
@@ -24,7 +24,7 @@ def validate_network(

Parameters
----------
model : torch.model
model : if parameters["model"]["type"] == torch, this is a torch.model, otherwise this is OV exec_net
The model to process the input image with, it should support appropriate dimensions.
valid_dataloader : torch.DataLoader
The dataloader for the validation epoch
@@ -76,11 +76,13 @@
pathlib.Path(current_output_dir).mkdir(parents=True, exist_ok=True)

# Set the model to valid
model.eval()
if params["model"]["type"] == "torch":
model.eval()

# # putting stuff in individual arrays for correlation analysis
# all_targets = []
# all_predics = []
if params["medcam_enabled"]:
if params["medcam_enabled"] and params["model"]["type"] == "torch":
model.enable_medcam()
params["medcam_enabled"] = True

@@ -158,12 +160,23 @@
## special case for 2D
if image.shape[-1] == 1:
image = torch.squeeze(image, -1)
pred_output += model(image)
if params["model"]["type"] == "torch":
pred_output += model(image)
elif params["model"]["type"] == "openvino":
pred_output += torch.from_numpy(
model.infer(
inputs={params["model"]["IO"][0]: image.cpu().numpy()}
)[params["model"]["IO"][1]]
)
else:
raise Exception(
"Model type not supported. Please only use 'torch' or 'openvino'."
)

pred_output = pred_output.cpu() / params["q_samples_per_volume"]
pred_output /= params["scaling_factor"]
# all_predics.append(pred_output.double())
# all_targets.append(valuesToPredict.double())
print(f"pred_output.shape: {pred_output.shape}")

if is_inference and is_classification:
logits_list.append(pred_output)
@@ -242,7 +255,10 @@ def validate_network(
flush=True,
)

result = step(model, image, label, params)
if is_inference:
result = step(model, image, label, params, train=False)
else:
result = step(model, image, label, params, train=True)

# get the current attention map and add it to its aggregator
if params["medcam_enabled"]:
@@ -326,7 +342,7 @@ def validate_network(
)

# get the final attention map and save it
if params["medcam_enabled"]:
if params["medcam_enabled"] and params["model"]["type"] == "torch":
attention_map = attention_map_aggregator.get_output_tensor()
for i, n in enumerate(attention_map):
model.save_attention_map(
@@ -383,7 +399,7 @@
to_print,
)

if params["medcam_enabled"]:
if params["medcam_enabled"] and params["model"]["type"] == "torch":
model.disable_medcam()
params["medcam_enabled"] = False

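The OpenVINO branch above keeps the rest of `validate_network` unchanged by converting tensors to numpy for `exec_net.infer` and wrapping the result back into torch. A self-contained sketch of that round-trip, using a stand-in executable network (`FakeExecNet` and the blob names are illustrative, not part of the PR):

```python
import numpy as np
import torch


class FakeExecNet:
    """Stand-in for an OpenVINO ExecutableNetwork, used only for this sketch."""

    def infer(self, inputs):
        data = next(iter(inputs.values()))
        # Pretend the network emits two logits per sample.
        return {"output": np.zeros((data.shape[0], 2), dtype=np.float32)}


params = {"model": {"IO": ["input", "output"]}}
exec_net = FakeExecNet()
image = torch.rand(1, 3, 32, 32)

# Same pattern as the diff: torch -> numpy on CPU -> infer -> back to torch.
pred_output = torch.from_numpy(
    exec_net.infer(inputs={params["model"]["IO"][0]: image.cpu().numpy()})[
        params["model"]["IO"][1]
    ]
)
print(pred_output.shape)  # torch.Size([1, 2])
```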
82 changes: 64 additions & 18 deletions GANDLF/compute/inference_loop.py
@@ -15,8 +15,12 @@
import tiffslide as openslide

from GANDLF.data.ImagesFromDataFrame import ImagesFromDataFrame
from GANDLF.utils import populate_channel_keys_in_params, send_model_to_device
from GANDLF.models import global_models_dict
from GANDLF.utils import (
populate_channel_keys_in_params,
send_model_to_device,
load_ov_model,
)
from GANDLF.data.inference_dataloader_histopath import InferTumorSegDataset


@@ -31,6 +35,7 @@ def inference_loop(inferenceDataFromPickle, device, parameters, outputDir):
outputDir (str): The output directory.
"""
# Defining our model here according to parameters mentioned in the configuration file
print("Current model type : ", parameters["model"]["type"])
print("Number of dims : ", parameters["model"]["dimension"])
if "num_channels" in parameters["model"]:
print("Number of channels : ", parameters["model"]["num_channels"])
@@ -47,17 +52,46 @@ def inference_loop(inferenceDataFromPickle, device, parameters, outputDir):
)
inference_loader = DataLoader(inferenceDataForTorch, batch_size=1)

# Loading the weights into the model
main_dict = outputDir
if os.path.isdir(outputDir):
file_to_check = os.path.join(
outputDir, str(parameters["model"]["architecture"]) + "_best.pth.tar"
if parameters["model"]["type"] == "torch":
# Loading the weights into the model
main_dict = outputDir
if os.path.isdir(outputDir):
file_to_check = os.path.join(
outputDir, str(parameters["model"]["architecture"]) + "_best.pth.tar"
)
if not os.path.isfile(file_to_check):
raise ValueError(
"The specified model was not found: {0}.".format(file_to_check)
)

main_dict = torch.load(file_to_check, map_location=torch.device(device))
model.load_state_dict(main_dict["model_state_dict"])
elif parameters["model"]["type"].lower() == "openvino":
# Loading the executable OpenVINO model
main_dict = outputDir
if os.path.isdir(outputDir):
xml_to_check = os.path.join(
outputDir, str(parameters["model"]["architecture"]) + "_best.xml"
)
bin_to_check = os.path.join(
outputDir, str(parameters["model"]["architecture"]) + "_best.bin"
)
if not os.path.isfile(xml_to_check):
raise ValueError(
"The specified model IR was not found: {0}.".format(xml_to_check)
)
if not os.path.isfile(bin_to_check):
raise ValueError(
"The model specified model weights was not found: {0}.".format(
bin_to_check
)
)
model, input_blob, output_blob = load_ov_model(xml_to_check, device.upper())
parameters["model"]["IO"] = [input_blob, output_blob]
else:
raise ValueError(
"The model type is not recognized: ", parameters["model"]["type"]
)
if not os.path.isfile(file_to_check):
raise ValueError("The model specified model was not found:", file_to_check)

main_dict = torch.load(file_to_check, map_location=torch.device(device))
model.load_state_dict(main_dict["model_state_dict"])

if not (os.environ.get("HOSTNAME") is None):
print("\nHostname :" + str(os.environ.get("HOSTNAME")), flush=True)
@@ -67,9 +101,10 @@ def inference_loop(inferenceDataFromPickle, device, parameters, outputDir):
parameters["save_output"] = True

print("Data Samples: ", len(inference_loader.dataset), flush=True)
model, parameters["model"]["amp"], parameters["device"] = send_model_to_device(
model, parameters["model"]["amp"], device, optimizer=None
)
if parameters["model"]["type"] == "torch":
model, parameters["model"]["amp"], parameters["device"] = send_model_to_device(
model, parameters["model"]["amp"], device, optimizer=None
)

print("Using device:", parameters["device"], flush=True)

@@ -129,12 +164,23 @@ def inference_loop(inferenceDataFromPickle, device, parameters, outputDir):
)
for image_patches, (x_coords, y_coords) in tqdm(dataloader):
x_coords, y_coords = y_coords.numpy(), x_coords.numpy()
if parameters["model"]["amp"]:
with autocast():
if parameters["model"]["type"] == "torch":
if parameters["model"]["amp"]:
with autocast():
output = model(
image_patches.float().to(parameters["device"])
)
else:
output = model(image_patches.float().to(parameters["device"]))
output = output.detach().cpu().numpy()
else:
output = model(image_patches.float().to(parameters["device"]))
output = output.detach().cpu().numpy()
output = model.infer(
inputs={
parameters["model"]["IO"][0]: image_patches.float()
.cpu()
.numpy()
}
)[parameters["model"]["IO"][1]]
for i in range(int(output.shape[0])):
count_map[
x_coords[i] : x_coords[i] + patch_size[0],
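`load_ov_model` is imported from `GANDLF.utils` above, but its body is not part of this excerpt. A rough sketch of what such a loader could look like with the pre-2022.1 Inference Engine Python API (the actual GaNDLF helper may differ, e.g. in plugin configuration or multi-output handling):

```python
from openvino.inference_engine import IECore


def load_ov_model(xml_path, device="CPU"):
    """Load an OpenVINO IR (.xml plus sibling .bin) and return the executable
    network together with its input and output blob names. Sketch only."""
    ie = IECore()
    # The weights file is found automatically when it sits next to the .xml.
    net = ie.read_network(model=xml_path)
    input_blob = next(iter(net.input_info))
    output_blob = next(iter(net.outputs))
    exec_net = ie.load_network(network=net, device_name=device)
    return exec_net, input_blob, output_blob
```

The returned `[input_blob, output_blob]` pair matches what the diff stores in `parameters["model"]["IO"]` and later passes to `model.infer`.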
18 changes: 13 additions & 5 deletions GANDLF/compute/step.py
@@ -3,7 +3,7 @@
from .loss_and_metric import get_loss_and_metrics


def step(model, image, label, params):
def step(model, image, label, params, train=True):
"""
Function that steps the model for a single batch

@@ -60,11 +60,19 @@ def step(model, image, label, params):
if len(label.shape) > 1:
label = torch.squeeze(label, -1)

if params["model"]["amp"]:
with torch.cuda.amp.autocast():
output = model(image)
if train == False and params["model"]["type"].lower() == "openvino":
output = torch.from_numpy(
model.infer(inputs={params["model"]["IO"][0]: image.cpu().numpy()})[
params["model"]["IO"][1]
]
)
output = output.to(params["device"])
else:
output = model(image)
if params["model"]["amp"]:
with torch.cuda.amp.autocast():
output = model(image)
else:
output = model(image)

if "medcam_enabled" in params and params["medcam_enabled"]:
output, attention_map = output
49 changes: 45 additions & 4 deletions GANDLF/compute/training_loop.py
@@ -358,17 +358,16 @@ def training_loop(

# if previous model file is present, load it up
if os.path.exists(best_model_path):
print("Previous model found. Loading it up.")
try:
main_dict = load_model(best_model_path)
version_check(params["version"], version_to_check=main_dict["version"])
model.load_state_dict(main_dict["model_state_dict"])
start_epoch = main_dict["epoch"]
optimizer.load_state_dict(main_dict["optimizer_state_dict"])
best_loss = main_dict["loss"]
print("Previous model loaded successfully.")
except IOError:
raise IOError("Previous model could not be loaded, error: ")
print("Previous model successfully loaded.")
except RuntimeWarning:
RuntimeWarning("Previous model could not be loaded, initializing model")

print("Using device:", device, flush=True)

@@ -460,15 +459,21 @@
best_loss = epoch_valid_loss
best_train_idx = epoch
patience = 0

model.eval()
save_model(
{
"epoch": best_train_idx,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": best_loss,
},
model,
params,
best_model_path,
onnx_export=False,
)
model.train()
first_model_saved = True

print("Current Best epoch: ", best_train_idx)
@@ -491,6 +496,42 @@
flush=True,
)

# once the training is done, optimize the best model
if os.path.exists(best_model_path):

onnx_export = True
if params["model"]["architecture"] in ["sdnet", "brain_age"]:
onnx_export = False
elif (
"onnx_export" in params["model"] and params["model"]["onnx_export"] == False
):
onnx_export = False

if onnx_export:
print("Optimizing best model.")

try:
main_dict = load_model(best_model_path)
version_check(params["version"], version_to_check=main_dict["version"])
model.load_state_dict(main_dict["model_state_dict"])
best_epoch = main_dict["epoch"]
optimizer.load_state_dict(main_dict["optimizer_state_dict"])
best_loss = main_dict["loss"]
save_model(
{
"epoch": best_epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": best_loss,
},
model,
params,
best_model_path,
onnx_export,
)
except Exception as e:
print("Best model could not be loaded, error: ", e)


if __name__ == "__main__":

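The post-training optimization itself happens inside `save_model` (in `GANDLF/utils/modelio.py`, not shown in this excerpt) when `onnx_export` is true. The usual chain it presumably wraps, ONNX export followed by the Model Optimizer from `openvino-dev`, is sketched below; shapes, opset version and flags are assumptions, not the PR's exact values:

```python
import subprocess

import torch


def export_and_optimize(model, dummy_input, onnx_path, ir_output_dir):
    """Export a trained torch model to ONNX, then convert it to OpenVINO IR.

    Sketch only: GaNDLF's save_model may use different input shapes, opset
    versions and Model Optimizer options than the ones assumed here.
    """
    model.eval()
    torch.onnx.export(
        model,
        dummy_input,  # e.g. torch.rand(1, num_channels, *patch_size)
        onnx_path,
        opset_version=11,
        input_names=["input"],
        output_names=["output"],
    )
    # openvino-dev installs the Model Optimizer as the `mo` console script;
    # called without shell=True, in line with the "Remove shell from
    # subprocess call" commit in this PR.
    subprocess.run(
        ["mo", "--input_model", onnx_path, "--output_dir", ir_output_dir],
        check=True,
    )
```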
4 changes: 4 additions & 0 deletions GANDLF/inference_manager.py
@@ -38,6 +38,10 @@ def InferenceManager(dataframe, outputDir, parameters, device):
class_list = None
is_classification = parameters["problem_type"] == "classification"

# initialize model type for processing: if not defined, default to torch
if not ("type" in parameters["model"]):
parameters["model"]["type"] = "torch"

for fold_dir in fold_dirs:
parameters["current_fold_dir"] = fold_dir
inference_loop(
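Because of this default, existing configurations keep running inference through PyTorch; selecting the OpenVINO IR only requires setting the new key. A trimmed, hypothetical parameters fragment:

```python
# Hypothetical, trimmed GaNDLF parameters fragment; all other keys elided.
parameters = {
    "model": {
        "architecture": "resunet",
        "dimension": 3,
        "type": "openvino",  # omit (or use "torch") to keep the PyTorch path
    },
}
```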
2 changes: 1 addition & 1 deletion GANDLF/models/seg_modules/average_pool.py
@@ -29,6 +29,6 @@ def forward(self, x):
if isinstance(B, int):
return F.avg_pool3d(x, (W, H, D)).view(B, C)
else:
return F.avg_pool2d(x, (W.item(), H.item(), D.item())).view(
return F.avg_pool3d(x, (W.item(), H.item(), D.item())).view(
B.item(), C.item()
)
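The one-token change above (`avg_pool2d` → `avg_pool3d` in the dynamic-shape branch) matters because the input at that point is a 5-D tensor and the kernel has three spatial extents, which `F.avg_pool2d` rejects. A quick shape check under those assumptions:

```python
import torch
import torch.nn.functional as F

# 5-D activation map from a 3-D network: (batch, channels, dim1, dim2, dim3).
x = torch.rand(2, 16, 8, 8, 4)

# avg_pool3d accepts the 3-element kernel and collapses all spatial dims.
pooled = F.avg_pool3d(x, (8, 8, 4)).view(2, 16)
print(pooled.shape)  # torch.Size([2, 16])

# avg_pool2d expects 4-D input and a 2-element kernel, so it would raise here.
```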