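"""Gradio demo: run a custom YOLOv5 model ('joint_all_multi') on an uploaded
image and return the annotated image plus the detections as JSON."""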
import gradio as gr
import torch
import numpy as np
import yaml

from yolov5.models.common import DetectMultiBackend
from yolov5.utils.general import check_img_size, non_max_suppression, scale_coords, xyxy2xywh
from yolov5.utils.plots import Annotator, colors
from yolov5.utils.torch_utils import select_device
from yolov5.utils.augmentations import letterbox

# Model setup: CPU inference in FP32
device = select_device('cpu')
half = False  # FP16 only makes sense on CUDA
weights = 'yolov5/joint_all_multi.pt'
model = DetectMultiBackend(weights, device=device, dnn=False, data=None)
stride, names, pt = model.stride, model.names, model.pt
bs = 1  # batch size
imgsz = check_img_size((640, 640), s=stride)  # verify size is a multiple of stride
model.warmup(imgsz=(1 if pt else bs, 3, *imgsz), half=half)  # warm up once at startup

# Inference and rendering settings
conf_thres = 0.1  # NMS confidence threshold
iou_thres = 0.1  # NMS IoU threshold
hide_labels = False
hide_conf = True  # show class names without confidence scores
line_thickness = 1  # bounding-box line width (pixels)

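# Human-readable class names for the JSON output, read from the dataset YAML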
with open('yolov5/joint_all_multi.yaml', 'r') as f:
    LABELS = yaml.safe_load(f)['names']

def joint_detection(img0):
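    """Run YOLOv5 detection on a single image.

    Args:
        img0: HWC uint8 RGB image, as supplied by Gradio's image input.

    Returns:
        A tuple of (annotated image, dict mapping class label to a list of
        normalized xywh boxes).
    """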
    # Letterbox-pad and resize to the model input shape
    img = letterbox(img0, imgsz, stride=stride, auto=pt)[0]
    img = img.transpose((2, 0, 1))  # HWC to CHW (Gradio already supplies RGB, so no BGR flip is needed)
    img = np.ascontiguousarray(img)

    im = torch.from_numpy(img).to(device)
    im = im.half() if half else im.float()  # uint8 to fp16/32
    im /= 255  # 0-255 to 0.0-1.0
    if len(im.shape) == 3:
        im = im[None]  # add a batch dimension

    # Inference
    pred = model(im, augment=False, visualize=False)

    # Non-max suppression (no class filter, class-aware)
    pred = non_max_suppression(pred, conf_thres, iou_thres, classes=None, agnostic=False, max_det=1000)


    # Process predictions
    for det in pred:  # single image (batch size is 1)
        im0 = img0.copy()
        gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
        annotator = Annotator(im0, line_width=line_thickness, example=str(names))
        if len(det):
            # Rescale boxes from img_size to im0 size
            det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()

            # Draw boxes and (optionally) labels on the image copy
            for *xyxy, conf, cls in reversed(det):
                c = int(cls)  # integer class
                label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
                annotator.box_label(xyxy, label, color=colors(c, True))

        # Collect detections in a JSON-serializable dict keyed by class label
        content = {}
        for *xyxy, conf, cls in reversed(det):
            xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
            x, y, width, height = xywh
            current_label = LABELS[int(cls.item())]
            content.setdefault(current_label, []).append(
                {'x': x, 'y': y, 'width': width, 'height': height})

        # Return the annotated image and the structured detections
        im0 = annotator.result()
        return im0, content

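# Gradio UI: image in; annotated image and JSON detections out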
iface = gr.Interface(fn=joint_detection, inputs="image", outputs=["image", "json"])
iface.launch(server_name="0.0.0.0", server_port=7860)