"""Scene-text OCR: segment the text region with an smp model, then read it with PARSeq."""
from strhub.data.module import SceneTextDataModule
from strhub.models.utils import load_from_checkpoint
from post import filter_mask
import segmentation_models_pytorch as smp
import albumentations as albu
from torchvision import transforms
from PIL import Image
import torch
import cv2
from time import process_time, perf_counter
import imutils

# Run on the first CUDA GPU when available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Text-recognition model and the input transform matching its configured image size.
model_recog = load_from_checkpoint("weights/parseq/last.ckpt").eval().to(device).float()
img_transform = SceneTextDataModule.get_transform(model_recog.hparams.img_size)

# Segmentation model, stored as a whole pickled module.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted weight files.
# Recent torch releases default to weights_only=True, which may refuse a fully
# pickled model; confirm against the installed torch version.
model = torch.load('weights/best_model.pth', map_location=torch.device(device))
model.eval()

# Size the image is resized to before segmentation (albu.Resize takes height, width).
SHAPE_X = 384
SHAPE_Y = 384

# Encoder-specific input normalisation for the 'resnet50' encoder.
preprocessing_fn = smp.encoders.get_preprocessing_fn('resnet50')

# Built once at import time instead of on every prediction() call.
_seg_transform = albu.Compose([
    albu.Lambda(image=preprocessing_fn),
    albu.Resize(SHAPE_X, SHAPE_Y),
])
_to_tensor = transforms.ToTensor()


def _load_image(image_path):
    """Read *image_path*, convert it to RGB and downscale it to 1/5 of its width.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    image_bgr = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if image_bgr is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    return imutils.resize(image_rgb, width=image_rgb.shape[1] // 5)


def _segment(image_rgb):
    """Run the segmentation model on *image_rgb* and return filter_mask's (result, img_vis)."""
    image_result = _seg_transform(image=image_rgb)["image"]
    tensor = _to_tensor(image_result).unsqueeze(0)
    with torch.no_grad():
        output = model.predict(tensor.float().to(device))
    return filter_mask(output, image_rgb)


def _recognize(crop):
    """Return the text the recognition model reads from the image array *crop*."""
    # NOTE(review): crop comes from filter_mask, which received an RGB image; if crop is
    # also RGB, this swap actually hands BGR to PIL -- TODO confirm against post.filter_mask.
    image = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
    batch = img_transform(Image.fromarray(image)).float().unsqueeze(0).to(device)
    # Pure inference: skip autograd graph construction to save memory and time.
    with torch.no_grad():
        probs = model_recog(batch).softmax(-1)
    labels, _confidences = model_recog.tokenizer.decode(probs)
    return labels[0]


def prediction(image_path):
    """Segment and read the text in the image at *image_path*.

    Returns:
        A tuple ``(img_vis, text, t_process)``: the visualisation produced by
        ``filter_mask``, the recognised string, and the elapsed wall-clock time
        in seconds.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    # perf_counter measures real elapsed time. process_time sums CPU time over all
    # threads and ignores GPU waits; the former "/3" was presumably compensating for that.
    t_start = perf_counter()
    image_original = _load_image(image_path)
    result, img_vis = _segment(image_original)
    text = _recognize(result)
    t_process = perf_counter() - t_start
    print(f'{image_path}: {text}, {t_process} seconds')
    return img_vis, text, t_process