AttributeError: 'SegmentationTTAWrapper' object has no attribute 'predict' #23
Open
@chefkrym

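```python
import os  # needed for os.path.join below; missing from the original post
import torch
import ttach as tta
import timm
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import tifffile as tiff  # assumed source of `tiff`; not imported in the original post
import segmentation_models_pytorch as smp  # assumed source of `smp`; not imported in the original post

# ENCODER, ENCODER_WEIGHTS, DEVICE and get_preprocessing() are used below
# but are not defined in this snippet.

# Load the full pickled model (not just a state_dict).
model = torch.load('E:/PhD_Projects/egmentation models/new model weights/UNet_mitb2_thresh0.3.pth')

# Wrap the model so predictions are averaged over the D4 flips and rotations.
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode="mean")

image_dir = 'E:/PhD_Projects/segmentation models/patches'
image_filename_2 = 'image__02_02.tif'
image_path = os.path.join(image_dir, image_filename_2)
image = tiff.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

preprocessing_fn_inference = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
preprocessing_inference = get_preprocessing(preprocessing_fn_inference)
sample = preprocessing_inference(image=image)
image = sample['image']

x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
pr_mask = tta_model.predict(x_tensor)  # this call raises the AttributeError in the title
pr_mask = (pr_mask.squeeze().cpu().numpy().round())
pr_mask = (pr_mask.astype('float') * 255.0/16)
#pr_mask = (pr_mask.astype('float') * 255.0/16).astype('uint8')

# =============================================================================

plt.imshow(pr_mask)
plt.show()
```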

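Calling `tta_model.predict(x_tensor)` raises `AttributeError: 'SegmentationTTAWrapper' object has no attribute 'predict'`. Can anyone help me with this prediction problem? Thank you. @qubvel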
