from time import process_time

import albumentations as albu
import cv2
import imutils
import segmentation_models_pytorch as smp
import torch
from PIL import Image
from strhub.data.module import SceneTextDataModule
from strhub.models.utils import load_from_checkpoint
from torchvision import transforms

from post import filter_mask

# Run inference on the first CUDA device when one is available, otherwise on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Recognition model restored from a checkpoint (the path suggests PARSeq), put in
# eval mode, moved to `device`, and cast to float32.
model_recog = load_from_checkpoint("weights/parseq/last.ckpt").eval().to(device).float()

# Input transform built for the image size stored in the checkpoint's hyperparameters.
img_transform = SceneTextDataModule.get_transform(model_recog.hparams.img_size)

# Model whose `predict` output is handed to `filter_mask` in `prediction`.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
# NOTE(review): PyTorch >= 2.6 defaults to `weights_only=True`, which rejects a
# pickled whole-model object; pass `weights_only=False` there once the file is trusted.
model = torch.load('weights/best_model.pth', map_location=device)
model.eval()

# Target size handed to albu.Resize in `prediction`.
SHAPE_X = 384
SHAPE_Y = 384

# Input normalisation function associated with the 'resnet50' encoder.
preprocessing_fn = smp.encoders.get_preprocessing_fn('resnet50')


def prediction(image_path):
    """Run the two-stage pipeline on one image file and return the decoded text.

    The image is read, downscaled to one fifth of its width, preprocessed and
    passed through ``model.predict``; ``filter_mask`` turns that output into an
    image region, which ``model_recog`` then decodes.

    Args:
        image_path: Path of the image file to process.

    Returns:
        A tuple ``(img_vis, text, t_process)``: the second value returned by
        ``filter_mask``, the first decoded prediction, and the reported
        processing time in seconds.

    Raises:
        OSError: If ``cv2.imread`` cannot read ``image_path``.
    """
    t_start = process_time()

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the path instead of a cryptic cvtColor error.
        raise OSError(f"Could not read image: {image_path!r}")

    image_original = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Shrink to one fifth of the width; clamp to 1 so images narrower than
    # five pixels do not request a zero-width resize.
    target_width = max(1, image_original.shape[1] // 5)
    image_original = imutils.resize(image_original, width=target_width)

    # Encoder-specific normalisation followed by a resize to the fixed input
    # size. albu.Resize takes (height, width); both constants are equal here.
    transform_compose = albu.Compose([
        albu.Lambda(image=preprocessing_fn),
        albu.Resize(SHAPE_X, SHAPE_Y),
    ])
    image_result = transform_compose(image=image_original)["image"]

    # HWC array -> CHW tensor, then add the batch dimension.
    tensor = transforms.ToTensor()(image_result).unsqueeze(0)
    output = model.predict(tensor.float().to(device))

    result, img_vis = filter_mask(output, image_original)

    # NOTE(review): image_original was already converted to RGB above, so this
    # second channel swap may be feeding BGR data to PIL -- confirm against
    # what filter_mask returns before changing it.
    image = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
    im_pil = Image.fromarray(image)
    image = img_transform(im_pil).float().unsqueeze(0).to(device)

    # model_recog is called directly, so disable autograd explicitly to avoid
    # recording the forward pass during inference.
    with torch.no_grad():
        p = model_recog(image).softmax(-1)
    pred, p = model_recog.tokenizer.decode(p)

    t_stop = process_time()
    # NOTE(review): process_time() measures CPU time of the process, not
    # wall-clock time, and the division by 3 is unexplained here -- kept
    # unchanged so the reported figure stays comparable for callers.
    t_process = (t_stop - t_start) / 3
    print(f'{image_path}: {pred[0]}, {t_process} seconds')

    return img_vis, pred[0], t_process