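"""Gradio demo: detect anime faces with a YOLOv5 model and crop them.

The detector weights (Carzit/yolo5x_anime) are downloaded from the Hugging Face
Hub; the actual cropping is delegated to crop.crop from this repository.
"""
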
import os
from pathlib import Path

import torch

from models.experimental import attempt_load
from utils.datasets import LoadImages
from utils.general import non_max_suppression, scale_coords, xyxy2xywh
from utils.torch_utils import select_device

import gradio as gr
import huggingface_hub

from crop import crop

class FaceCrop:
    def __init__(self, out_folder='outputs'):
        self.device = select_device()
        self.half = self.device.type != 'cpu'  # half precision only supported on CUDA
        self.out_folder = out_folder  # where cropped faces are written ('outputs' is an assumed default)
        os.makedirs(self.out_folder, exist_ok=True)
        self.results = {}
    
    def load_dataset(self, source):
        self.source = source
        self.dataset = LoadImages(source)
        print(f'Successfully loaded {source}')
    
    def load_model(self, model):
        self.model = attempt_load(model, map_location=self.device)
        if self.half:
            self.model.half()
        print(f'Successfully loaded model weights from {model}')
    
    def set_crop_config(self, target_size, mode=0, face_ratio=3, threshold=1.5):
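        """Store the cropping options used by process().

        mode: 0 = Auto, 1 = No Scale, 2 = Full Screen, 3 = Fixed Face Ratio
        (the same values offered by the Crop Mode dropdown in the UI).
        """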
        self.target_size = target_size
        self.mode = mode
        self.face_ratio = face_ratio
        self.threshold = threshold
    
    def info(self):
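        """Print every non-callable, non-dunder attribute currently set on the pipeline."""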
        attributes = dir(self)
        for attribute in attributes:
            if not attribute.startswith('__') and not callable(getattr(self, attribute)):
                value = getattr(self, attribute)
                print(attribute, " = ", value)
    
    def process(self):
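        """Run detection on the loaded dataset and crop each detected face.

        Cropped images are written to self.out_folder; the values returned by
        crop() are collected in self.results, keyed by detection index.
        """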
        self.results = {}  # discard results from any previous run
        for path, img, im0s, vid_cap in self.dataset:
            img = torch.from_numpy(img).to(self.device)
            img = img.half() if self.half else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            if img.ndimension() == 3:
                img = img.unsqueeze(0)

            # Inference
            pred = self.model(img, augment=False)[0]

            # Apply NMS
            pred = non_max_suppression(pred)

            # Process detections
            for i, det in enumerate(pred):  # detections per image

                p, im0 = path, im0s

                in_path = str(Path(self.source) / Path(p).name)
                gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh

                if det is not None and len(det):
                    # Rescale boxes from img_size to im0 size
                    det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                    # Crop every sufficiently confident detection
                    ind = 0
                    for *xyxy, conf, cls in det:
                        if conf > 0.6:
                            out_path = os.path.join(str(Path(self.out_folder)), Path(p).name.replace('.', '_' + str(ind) + '.'))
                            
                            # Normalized xywh (centre x, centre y, width, height) for crop()
                            x, y, w, h = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
                            self.results[ind] = crop(in_path, (x, y), out_path, mode=self.mode, size=self.target_size,
                                                     box=(w, h), face_ratio=self.face_ratio, shreshold=self.threshold)

                            ind += 1
                            
def run(img, mode, width, height):
    # img arrives as a file path (gr.Image(type="filepath") below), so it can be fed to LoadImages
    face_crop_pipeline.load_dataset(img)
    face_crop_pipeline.set_crop_config(mode=mode, target_size=(width, height))
    face_crop_pipeline.process()
    # Return the first cropped face, or None if nothing was detected
    return face_crop_pipeline.results[0] if face_crop_pipeline.results else None

if __name__ == '__main__':
    model_path = huggingface_hub.hf_hub_download("Carzit/yolo5x_anime", "yolo5x_anime.pt")
    face_crop_pipeline = FaceCrop()
    face_crop_pipeline.load_model(model_path)
    

    app = gr.Blocks()
    with app:
        gr.Markdown("# Anime Face Crop\n\n"
                    "![visitor badge](https://visitor-badge.glitch.me/badge?page_id=skytnt.animeseg)\n\n"
                    "demo for [https://github.com/SkyTNT/anime-segmentation/](https://github.com/SkyTNT/anime-segmentation/)")
        with gr.Row():
            input_img = gr.Image(label="input image", type="filepath")  # hand run() a path, not a numpy array
            output_img = gr.Image(label="result", image_mode="RGB")
            crop_mode = gr.Dropdown([0, 1, 2, 3], label="Crop Mode", info="0:Auto; 1:No Scale; 2:Full Screen; 3:Fixed Face Ratio")
            tgt_width = gr.Slider(10, 2048, value=512, label="Width")
            tgt_height = gr.Slider(10, 2048, value=512, label="Height")

            run_btn = gr.Button(variant="primary")

        run_btn.click(run, [input_img, crop_mode, tgt_width, tgt_height], [output_img])
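    # Launch the Gradio server (serves at http://127.0.0.1:7860 by default)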
    app.launch()