Isaacgonzales commited on
Commit
4a94ff7
1 Parent(s): da65082

update model

Browse files
Files changed (1) hide show
  1. model.py +9 -8
model.py CHANGED
@@ -10,10 +10,12 @@ import torch
10
  import cv2
11
  from time import process_time
12
 
13
- model_recog = load_from_checkpoint("weights/parseq/last.ckpt").eval().to("cpu")
 
 
14
  img_transform = SceneTextDataModule.get_transform(model_recog.hparams.img_size)
15
 
16
- model = torch.load('weights/best_model.pth', map_location=torch.device('cpu'))
17
  model.eval()
18
  model.float()
19
 
@@ -21,18 +23,17 @@ SHAPE_X = 384
21
  SHAPE_Y = 384
22
  preprocessing_fn = smp.encoders.get_preprocessing_fn('resnet50')
23
 
24
- transform_compose = albu.Compose([
25
- albu.Lambda(image=preprocessing_fn), albu.Resize(SHAPE_X, SHAPE_Y)
26
- ])
27
 
28
- transform_tensor = transforms.ToTensor()
29
 
30
  def prediction(image_path):
31
  t_start = process_time()
32
  image = cv2.imread(image_path)
33
  image_original = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
34
-
 
 
35
  image_result = transform_compose(image=image_original)["image"]
 
36
  tensor = transform_tensor(image_result)
37
  tensor = torch.unsqueeze(tensor, 0)
38
  output = model.predict(tensor.float())
@@ -41,7 +42,7 @@ def prediction(image_path):
41
 
42
  image = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
43
  im_pil = Image.fromarray(image)
44
- image = img_transform(im_pil).unsqueeze(0).to("cpu")
45
 
46
  p = model_recog(image).softmax(-1)
47
  pred, p = model_recog.tokenizer.decode(p)
 
10
  import cv2
11
  from time import process_time
12
 
13
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
+
15
+ model_recog = load_from_checkpoint("weights/parseq/last.ckpt").eval().to(device)
16
  img_transform = SceneTextDataModule.get_transform(model_recog.hparams.img_size)
17
 
18
+ model = torch.load('weights/best_model.pth', map_location=torch.device(device))
19
  model.eval()
20
  model.float()
21
 
 
23
  SHAPE_Y = 384
24
  preprocessing_fn = smp.encoders.get_preprocessing_fn('resnet50')
25
 
 
 
 
26
 
 
27
 
28
  def prediction(image_path):
29
  t_start = process_time()
30
  image = cv2.imread(image_path)
31
  image_original = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
32
+ transform_compose = albu.Compose([
33
+ albu.Lambda(image=preprocessing_fn), albu.Resize(SHAPE_X, SHAPE_Y)
34
+ ])
35
  image_result = transform_compose(image=image_original)["image"]
36
+ transform_tensor = transforms.ToTensor()
37
  tensor = transform_tensor(image_result)
38
  tensor = torch.unsqueeze(tensor, 0)
39
  output = model.predict(tensor.float())
 
42
 
43
  image = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
44
  im_pil = Image.fromarray(image)
45
+ image = img_transform(im_pil).unsqueeze(0).to(device)
46
 
47
  p = model_recog(image).softmax(-1)
48
  pred, p = model_recog.tokenizer.decode(p)