Description
`import torch
import ttach as tta
import timm
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2
model = torch.load('E:/PhD_Projects/egmentation models/new model weights/UNet_mitb2_thresh0.3.pth')
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode="mean")
image_dir = 'E:/PhD_Projects/segmentation models/patches'
image_filename_2 = 'image__02_02.tif'
image_path = os.path.join(image_dir, image_filename_2)
image = tiff.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
preprocessing_fn_inference = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
preprocessing_inference=get_preprocessing(preprocessing_fn_inference)
sample = preprocessing_inference(image=image)
image = sample['image']
x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
pr_mask = tta_model.predict(x_tensor)
pr_mask = (pr_mask.squeeze().cpu().numpy().round())
pr_mask = (pr_mask.astype('float') * 255.0/16)
#pr_mask = (pr_mask.astype('float') * 255.0/16).astype('uint8')
=============================================================================
plt.imshow(pr_mask)
plt.show()`
Can anyone help me get a prediction from my trained segmentation model when it is wrapped with ttach's test-time augmentation (the code above)? Thank you. @qubvel