File size: 1,870 Bytes
d02e83e
 
 
 
 
 
 
 
 
 
5a7c59c
659a0c6
4a94ff7
 
8f3d671
d02e83e
 
4a94ff7
d02e83e
8958da5
d02e83e
 
 
8e12dfe
 
d02e83e
 
 
5a7c59c
d02e83e
 
659a0c6
4a94ff7
 
 
da65082
4a94ff7
da65082
d02e83e
bd33281
d02e83e
da65082
d02e83e
 
 
8f3d671
d02e83e
 
 
5a7c59c
fa5693e
5a7c59c
d02e83e
659a0c6
d02e83e
 
d78a11f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58

from strhub.data.module import SceneTextDataModule
from strhub.models.utils import load_from_checkpoint
from post import filter_mask
import segmentation_models_pytorch as smp
import albumentations as albu
from torchvision import transforms
from PIL import Image
import torch
import cv2
from time import process_time
import imutils
# Select GPU 0 when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Text-recognition model (PARSeq) restored from a Lightning checkpoint;
# eval mode + float32 on the chosen device.
model_recog = load_from_checkpoint("weights/parseq/last.ckpt").eval().to(device).float()
# Input transform matching the recognizer's expected image size.
img_transform = SceneTextDataModule.get_transform(model_recog.hparams.img_size)

# Segmentation model serialized as a whole module (torch.load of the full
# object, not a state_dict) — NOTE(review): this unpickles arbitrary code;
# only load weights from a trusted source.
model = torch.load('weights/best_model.pth', map_location=torch.device(device))
model.eval()


# Input resolution fed to the segmentation network.
SHAPE_X = 384
SHAPE_Y = 384
# Normalization matching the resnet50 encoder's pretraining statistics.
preprocessing_fn = smp.encoders.get_preprocessing_fn('resnet50')


def prediction(image_path):
    """Run segmentation then text recognition on a single image file.

    Pipeline: read image -> downscale to 1/5 width -> segmentation model
    (``model.predict``) -> ``filter_mask`` post-processing -> PARSeq text
    recognition (``model_recog``).

    Args:
        image_path: Filesystem path to the input image.

    Returns:
        Tuple ``(img_vis, text, t_process)``:
            img_vis: visualization image produced by ``filter_mask``.
            text: first decoded text prediction.
            t_process: elapsed CPU time in seconds divided by 3 —
                NOTE(review): the ``/3`` looks like a leftover from
                averaging three runs; confirm before relying on it.

    Raises:
        FileNotFoundError: if the image cannot be read from disk.
    """
    t_start = process_time()
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising;
        # fail fast with a clear message instead of crashing in cvtColor.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image_original = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Downscale to one fifth of the original width to speed up segmentation.
    image_original = imutils.resize(image_original, width=image_original.shape[1] // 5)
    transform_compose = albu.Compose([
        albu.Lambda(image=preprocessing_fn), albu.Resize(SHAPE_X, SHAPE_Y)
    ])
    image_result = transform_compose(image=image_original)["image"]
    tensor = transforms.ToTensor()(image_result).unsqueeze(0)

    # Inference only: no_grad avoids building autograd graphs,
    # cutting memory use and overhead.
    with torch.no_grad():
        output = model.predict(tensor.float().to(device))

        result, img_vis = filter_mask(output, image_original)

        # filter_mask output is assumed BGR here (converted before PIL) —
        # TODO(review): confirm its color order.
        image = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
        im_pil = Image.fromarray(image)
        image = img_transform(im_pil).float().unsqueeze(0).to(device)

        p = model_recog(image).softmax(-1)
    pred, p = model_recog.tokenizer.decode(p)
    t_stop = process_time()
    t_process = (t_stop - t_start) / 3
    print(f'{image_path}: {pred[0]}, {t_process} seconds')

    return img_vis, pred[0], t_process