File size: 4,647 Bytes
4bec64b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import os
from pathlib import Path
from PIL import Image
import torch
import torch.backends.cudnn as cudnn
from numpy import random
from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import (
check_img_size, non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, plot_one_box, strip_optimizer)
from utils.torch_utils import select_device, load_classifier, time_synchronized
import gradio as gr
import huggingface_hub
from crop import crop
class FaceCrop:
    """Detect anime faces in images with a YOLOv5 model and crop them.

    Typical usage: load_model() -> load_dataset() -> set_crop_config() ->
    process(); cropped results are collected in ``self.results``.
    """

    def __init__(self, out_folder='.'):
        """Select the inference device and prepare result storage.

        Args:
            out_folder: Directory where cropped face images are written.
                Defaults to the current directory.
        """
        self.device = select_device()
        # Half precision is only supported on CUDA devices.
        self.half = self.device.type != 'cpu'
        self.results = {}
        # BUGFIX: process() reads self.out_folder, but the original class
        # never assigned it anywhere -> AttributeError at runtime.
        self.out_folder = out_folder

    def load_dataset(self, source):
        """Point the pipeline at an image file or directory to process."""
        self.source = source
        self.dataset = LoadImages(source)
        print(f'Successfully load {source}')

    def load_model(self, model):
        """Load YOLOv5 weights from *model* and move them to the device."""
        self.model = attempt_load(model, map_location=self.device)
        if self.half:
            self.model.half()
        print(f'Successfully load model weights from {model}')

    def set_crop_config(self, target_size, mode=0, face_ratio=3, threshold=1.5, out_folder=None):
        """Configure how detected faces are cropped.

        Args:
            target_size: (width, height) of the output crop in pixels.
            mode: 0 auto, 1 no scale, 2 full screen, 3 fixed face ratio
                (per the UI dropdown in the launcher below).
            face_ratio: Desired face-to-frame ratio (used by mode 3).
            threshold: Scale threshold forwarded to crop().
            out_folder: Optional output-directory override; None (the
                default) keeps the current value, so existing callers
                are unaffected.
        """
        self.target_size = target_size
        self.mode = mode
        self.face_ratio = face_ratio
        self.threshold = threshold
        if out_folder is not None:
            self.out_folder = out_folder

    def info(self):
        """Print every non-callable, non-dunder attribute (debug helper)."""
        for attribute in dir(self):
            if not attribute.startswith('__') and not callable(getattr(self, attribute)):
                print(attribute, " = ", getattr(self, attribute))

    def process(self):
        """Run detection over the dataset and crop each confident face.

        Results of crop() are stored in ``self.results`` keyed by the
        per-image detection index.
        """
        for path, img, im0s, vid_cap in self.dataset:
            img = torch.from_numpy(img).to(self.device)
            img = img.half() if self.half else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            if img.ndimension() == 3:
                img = img.unsqueeze(0)
            # Inference
            pred = self.model(img, augment=False)[0]
            # Apply NMS
            pred = non_max_suppression(pred)
            # Process detections
            for i, det in enumerate(pred):  # detections per image
                p, s, im0 = path, '', im0s
                in_path = str(Path(self.source) / Path(p).name)
                s += '%gx%g ' % img.shape[2:]  # print string
                gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
                if det is not None and len(det):
                    # Rescale boxes from img_size to im0 size
                    det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
                    # Write results
                    ind = 0
                    for *xyxy, conf, cls in det:
                        if conf > 0.6:  # keep only confident detections
                            # e.g. "img.png" -> "img_0.png" in out_folder
                            out_path = os.path.join(str(Path(self.out_folder)), Path(p).name.replace('.', '_' + str(ind) + '.'))
                            # Normalized center (x, y) and box (w, h).
                            x, y, w, h = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
                            # NOTE: 'shreshold' is crop()'s own (misspelled)
                            # keyword name; do not "fix" it here.
                            self.results[ind] = crop(in_path, (x, y), out_path, mode=self.mode, size=self.target_size, box=(w, h), face_ratio=self.face_ratio, shreshold=self.threshold)
                            # Increment only for written crops so result keys
                            # and output filenames stay contiguous.
                            ind += 1
def run(img, mode, width, height):
    """Gradio click callback: configure the pipeline, run it, return a crop.

    Args:
        img: Uploaded image from the UI. NOTE(review): currently unused —
            the pipeline only processes what was passed to load_dataset();
            confirm whether the upload should be saved and loaded here.
        mode: Crop-mode index (0-3) from the dropdown.
        width: Target crop width in pixels.
        height: Target crop height in pixels.

    Returns:
        The first cropped result produced by the pipeline.
    """
    face_crop_pipeline.set_crop_config(mode=mode, target_size=(width, height))
    # BUGFIX: the original referenced the bound method without calling it
    # ("face_crop_pipeline.process"), so detection never actually ran.
    face_crop_pipeline.process()
    return face_crop_pipeline.results[0]
if __name__ == '__main__':
    # Download the anime-face YOLOv5 weights from the Hugging Face Hub.
    model_path = huggingface_hub.hf_hub_download("Carzit/yolo5x_anime", "yolo5x_anime.pt")
    # Module-level pipeline instance; run() above closes over this name.
    face_crop_pipeline = FaceCrop()
    face_crop_pipeline.load_model(model_path)
    # NOTE(review): load_dataset() is never called, so the image uploaded
    # through the UI never reaches the pipeline — verify intended wiring.
    app = gr.Blocks()
    with app:
        # NOTE(review): badge and demo links below reference the
        # anime-segmentation project, not this face-crop app — likely
        # copied from another Space; confirm/update.
        gr.Markdown("# Anime Face Crop\n\n"
                    "![visitor badge](https://visitor-badge.glitch.me/badge?page_id=skytnt.animeseg)\n\n"
                    "demo for [https://github.com/SkyTNT/anime-segmentation/](https://github.com/SkyTNT/anime-segmentation/)")
        with gr.Row():
            input_img = gr.Image(label="input image")
            output_img = gr.Image(label="result", image_mode="RGB")
        # Crop-mode indices match FaceCrop.set_crop_config's `mode` values.
        crop_mode = gr.Dropdown([0, 1, 2, 3], label="Crop Mode", info="0:Auto; 1:No Scale; 2:Full Screen; 3:Fixed Face Ratio")
        tgt_width = gr.Slider(10, 2048, value=512, label="Width")
        tgt_height = gr.Slider(10, 2048, value=512, label="Height")
        run_btn = gr.Button(variant="primary")
        # Wire the button: inputs map positionally onto run()'s parameters.
        run_btn.click(run, [input_img, crop_mode, tgt_width, tgt_height], [output_img])
    app.launch()
|