# NOTE: removed non-Python extraction residue (file-size line, git-blame hash
# gutter, and line-number gutter) that was accidentally scraped into this file.
from strhub.data.module import SceneTextDataModule
from strhub.models.utils import load_from_checkpoint
from post import filter_mask
import segmentation_models_pytorch as smp
import albumentations as albu
from torchvision import transforms
from PIL import Image
import torch
import cv2
from time import process_time
import imutils
# Run on the first CUDA GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Scene-text recognition model (PARSeq-style strhub checkpoint), frozen for
# inference and forced to float32 on the selected device.
model_recog = load_from_checkpoint("weights/parseq/last.ckpt").eval().to(device).float()
# Preprocessing transform matching the recogniser's expected input size.
img_transform = SceneTextDataModule.get_transform(model_recog.hparams.img_size)

# Segmentation model saved as a full pickled module object.
# NOTE(review): torch.load on a pickled model executes arbitrary code from the
# file — only load weights from a trusted source.
model = torch.load('weights/best_model.pth', map_location=torch.device(device))
model.eval()

# Input resolution fed to the segmentation model.
SHAPE_X = 384
SHAPE_Y = 384
# Normalisation matching the segmentation model's resnet50 encoder
# (ImageNet mean/std as provided by segmentation_models_pytorch).
preprocessing_fn = smp.encoders.get_preprocessing_fn('resnet50')
def prediction(image_path):
    """Segment a text region in an image and recognise the text inside it.

    Pipeline: read image -> downscale 5x -> segmentation model -> mask
    filtering -> PARSeq text recognition.

    Args:
        image_path: Filesystem path to the input image.

    Returns:
        Tuple ``(img_vis, text, t_process)`` where ``img_vis`` is the
        visualisation image produced by ``filter_mask``, ``text`` is the
        recognised string, and ``t_process`` is the measured CPU time
        divided by 3 (see note below).

    Raises:
        FileNotFoundError: If the image cannot be read from ``image_path``.
    """
    t_start = process_time()

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread silently returns None on a missing/unreadable file;
        # fail loudly here instead of crashing later inside cvtColor.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    image_original = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Downscale 5x to keep the segmentation input (and mask filtering) cheap.
    image_original = imutils.resize(image_original, width=image_original.shape[1] // 5)

    transform_compose = albu.Compose([
        albu.Lambda(image=preprocessing_fn), albu.Resize(SHAPE_X, SHAPE_Y)
    ])
    image_result = transform_compose(image=image_original)["image"]
    tensor = transforms.ToTensor()(image_result).unsqueeze(0)

    # Inference only: disable autograd bookkeeping to save memory/time.
    with torch.no_grad():
        output = model.predict(tensor.float().to(device))
        result, img_vis = filter_mask(output, image_original)

        # NOTE(review): `result` is converted BGR->RGB although
        # image_original was already RGB — confirm which colour space
        # filter_mask returns.
        image = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
        im_pil = Image.fromarray(image)
        image = img_transform(im_pil).float().unsqueeze(0).to(device)
        p = model_recog(image).softmax(-1)
        pred, p = model_recog.tokenizer.decode(p)

    t_stop = process_time()
    # NOTE(review): process_time() counts CPU time only (GPU kernels are not
    # included), and the /3 divisor is unexplained — presumably averaging
    # over pipeline stages; confirm before relying on this figure.
    t_process = (t_stop - t_start) / 3
    print(f'{image_path}: {pred[0]}, {t_process} seconds')
    return img_vis, pred[0], t_process