|
import os |
|
from pathlib import Path |
|
|
|
from PIL import Image |
|
import torch |
|
import torch.backends.cudnn as cudnn |
|
from numpy import random |
|
|
|
from models.experimental import attempt_load |
|
from utils.datasets import LoadStreams, LoadImages |
|
from utils.general import ( |
|
check_img_size, non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, plot_one_box, strip_optimizer) |
|
from utils.torch_utils import select_device, load_classifier, time_synchronized |
|
|
|
import gradio as gr |
|
import huggingface_hub |
|
|
|
from crop import crop |
|
|
|
class FaceCrop:
    """Detect anime faces in images with a YOLOv5 model and crop them.

    Typical usage::

        pipeline = FaceCrop()
        pipeline.load_model(weights_path)
        pipeline.load_dataset(image_path)
        pipeline.set_crop_config((512, 512))
        pipeline.process()

    Crop outputs are written under ``out_folder`` and the values returned
    by ``crop()`` are collected in ``results`` keyed by detection index.
    """

    def __init__(self):
        # Pick CUDA when available, otherwise CPU.
        self.device = select_device()
        # Half precision is only usable on CUDA devices.
        self.half = self.device.type != 'cpu'
        # detection index -> value returned by crop() for the last run
        self.results = {}
        # Fix: out_folder was never initialized anywhere, so process()
        # raised AttributeError at the os.path.join below. Default to the
        # current working directory.
        self.out_folder = ''

    def load_dataset(self, source):
        """Point the pipeline at ``source`` (an image file or directory path)."""
        self.source = source
        self.dataset = LoadImages(source)
        print(f'Successfully load {source}')

    def load_model(self, model):
        """Load YOLOv5 weights from path ``model`` onto the selected device."""
        self.model = attempt_load(model, map_location=self.device)
        if self.half:
            self.model.half()
        print(f'Successfully load model weights from {model}')

    def set_crop_config(self, target_size, mode=0, face_ratio=3, threshold=1.5):
        """Store cropping parameters.

        target_size: (width, height) of the produced crop.
        mode: 0 auto, 1 no scale, 2 full screen, 3 fixed face ratio.
        face_ratio, threshold: forwarded verbatim to crop().
        """
        self.target_size = target_size
        self.mode = mode
        self.face_ratio = face_ratio
        self.threshold = threshold

    def info(self):
        """Print every non-callable, non-dunder attribute (debug helper)."""
        for attribute in dir(self):
            if not attribute.startswith('__') and not callable(getattr(self, attribute)):
                value = getattr(self, attribute)
                print(attribute, " = ", value)

    def process(self):
        """Run face detection over the loaded dataset and crop each confident hit."""
        # Fix: reset results so repeated process() calls do not keep stale
        # entries from a previously processed image.
        self.results = {}
        for path, img, im0s, vid_cap in self.dataset:
            img = torch.from_numpy(img).to(self.device)
            img = img.half() if self.half else img.float()
            img /= 255.0  # uint8 [0, 255] -> float [0.0, 1.0]
            if img.ndimension() == 3:
                img = img.unsqueeze(0)  # add batch dimension

            # Forward pass (no test-time augmentation).
            pred = self.model(img, augment=False)[0]
            pred = non_max_suppression(pred)

            for i, det in enumerate(pred):
                p, s, im0 = path, '', im0s
                in_path = str(Path(self.source) / Path(p).name)
                s += '%gx%g ' % img.shape[2:]
                # Normalization gain in w,h,w,h order so xywh becomes relative [0, 1].
                gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]

                if det is not None and len(det):
                    # Rescale boxes from inference size back to original image size.
                    det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                    ind = 0
                    for *xyxy, conf, cls in det:
                        if conf > 0.6:  # keep only confident face detections
                            out_path = os.path.join(str(Path(self.out_folder)), Path(p).name.replace('.', '_' + str(ind) + '.'))
                            # Normalized center (x, y) and box size (w, h).
                            x, y, w, h = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
                            # NOTE: 'shreshold' matches the (misspelled) keyword
                            # declared by crop() — do not "fix" the spelling here.
                            self.results[ind] = crop(in_path, (x, y), out_path, mode=self.mode, size=self.target_size, box=(w, h), face_ratio=self.face_ratio, shreshold=self.threshold)
                            # Increment only for saved crops so results keys are
                            # contiguous from 0 (run() reads results[0]).
                            ind += 1
|
|
def run(img, mode, width, height):
    """Gradio callback: detect and crop the first face in ``img``.

    img: path to the input image on disk (assumes the gr.Image component is
         configured to deliver a filesystem path — TODO confirm; LoadImages
         and crop() operate on paths, not arrays).
    mode: crop mode selected in the UI dropdown (0-3).
    width, height: target crop size from the sliders.
    Returns the crop result stored for detection index 0; raises KeyError
    if no face was detected above the confidence threshold.
    """
    # Fix: the input image was never handed to the pipeline, so process()
    # had no dataset to iterate.
    face_crop_pipeline.load_dataset(img)
    face_crop_pipeline.set_crop_config(mode=mode, target_size=(width, height))
    # Fix: 'face_crop_pipeline.process' was a bare attribute access and never
    # executed — it must be called.
    face_crop_pipeline.process()
    return face_crop_pipeline.results[0]
|
|
|
if __name__ == '__main__':
    # Download the anime-face YOLOv5 weights from the Hugging Face hub.
    model_path = huggingface_hub.hf_hub_download("Carzit/yolo5x_anime", "yolo5x_anime.pt")
    face_crop_pipeline = FaceCrop()
    face_crop_pipeline.load_model(model_path)

    app = gr.Blocks()
    with app:
        # NOTE(review): the visitor badge and demo link below appear to be
        # copy-pasted from the anime-segmentation space — verify they should
        # point at this project before release.
        gr.Markdown("# Anime Face Crop\n\n"
                    "![visitor badge](https://visitor-badge.glitch.me/badge?page_id=skytnt.animeseg)\n\n"
                    "demo for [https://github.com/SkyTNT/anime-segmentation/](https://github.com/SkyTNT/anime-segmentation/)")
        with gr.Row():
            # Fix: type="filepath" makes gradio hand the callback a path on
            # disk; the default (numpy array) cannot be consumed by the
            # path-based LoadImages/crop pipeline.
            input_img = gr.Image(label="input image", type="filepath")
            output_img = gr.Image(label="result", image_mode="RGB")
            crop_mode = gr.Dropdown([0, 1, 2, 3], label="Crop Mode", info="0:Auto; 1:No Scale; 2:Full Screen; 3:Fixed Face Ratio")
            tgt_width = gr.Slider(10, 2048, value=512, label="Width")
            tgt_height = gr.Slider(10, 2048, value=512, label="Height")

        run_btn = gr.Button(variant="primary")

        run_btn.click(run, [input_img, crop_mode, tgt_width, tgt_height], [output_img])
    app.launch()
|
|
|
|
|
|