# bonduelle / model.py
# Isaacgonzales's picture
# update parameters
# fa5693e
# raw
# history blame
# 1.87 kB
from strhub.data.module import SceneTextDataModule
from strhub.models.utils import load_from_checkpoint
from post import filter_mask
import segmentation_models_pytorch as smp
import albumentations as albu
from torchvision import transforms
from PIL import Image
import torch
import cv2
from time import process_time
import imutils
# Inference device: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Recognition network restored from its checkpoint, then switched to eval
# mode, moved to `device` and cast with .float().
model_recog = load_from_checkpoint("weights/parseq/last.ckpt")
model_recog = model_recog.eval().to(device).float()

# Input transform built for the recognition network's configured image size.
img_transform = SceneTextDataModule.get_transform(
    model_recog.hparams.img_size
)

# Second network, stored on disk as a whole pickled object and mapped onto
# `device` while loading. prediction() calls its .predict() method.
# NOTE(review): presumably a segmentation_models_pytorch model -- confirm.
model = torch.load('weights/best_model.pth', map_location=device)
model.eval()

# Both sizes handed to albu.Resize inside prediction().
SHAPE_X = SHAPE_Y = 384

# Input preprocessing function matching the 'resnet50' encoder.
preprocessing_fn = smp.encoders.get_preprocessing_fn('resnet50')
def prediction(image_path):
    """Run the two-model pipeline on one image file and report the timing.

    The image is read with OpenCV, converted to RGB, resized to one fifth of
    its original width, preprocessed and fed to ``model``. The model output
    and the resized image go through ``filter_mask``; the first value it
    returns is then given to ``model_recog`` and decoded with its tokenizer.

    Args:
        image_path: Path of the image file to process.

    Returns:
        A tuple ``(img_vis, text, t_process)``: the second value returned by
        ``filter_mask``, the first decoded string, and the reported time
        (see the review note next to the timing computation).

    Raises:
        cv2.error: If OpenCV cannot read an image from ``image_path``.
    """
    t_start = process_time()

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        # Without this check the failure surfaces one line later as an
        # opaque assertion inside cvtColor. cv2.error is kept as the
        # exception type so existing handlers still match.
        raise cv2.error(f"could not read an image from {image_path!r}")

    image_original = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_original = imutils.resize(
        image_original, width=image_original.shape[1] // 5
    )

    # Encoder preprocessing first, then a resize to the fixed input size.
    transform_compose = albu.Compose([
        albu.Lambda(image=preprocessing_fn),
        albu.Resize(SHAPE_X, SHAPE_Y),
    ])
    image_result = transform_compose(image=image_original)["image"]

    # HWC array -> CHW tensor, plus a leading batch dimension of size one.
    tensor = transforms.ToTensor()(image_result)
    tensor = torch.unsqueeze(tensor, 0)
    output = model.predict(tensor.float().to(device))

    result, img_vis = filter_mask(output, image_original)

    # NOTE(review): image_original was already converted to RGB above, so
    # whether this BGR->RGB swap is wanted depends on what filter_mask
    # returns -- confirm against post.filter_mask.
    image = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
    im_pil = Image.fromarray(image)
    image = img_transform(im_pil).float().unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping so no graph is built and
    # kept alive for a forward pass whose gradients are never used.
    with torch.no_grad():
        p = model_recog(image).softmax(-1)
    pred, p = model_recog.tokenizer.decode(p)

    t_stop = process_time()
    # NOTE(review): process_time() measures CPU time of the process, not
    # wall-clock time, and the reason for dividing by 3 is not visible
    # here -- confirm what callers expect before changing it.
    t_process = (t_stop - t_start) / 3
    print(f'{image_path}: {pred[0]}, {t_process} seconds')
    return img_vis, pred[0], t_process