|
import os
import pickle
import shutil
import sys
from collections import defaultdict
from glob import glob

import cv2
import gradio as gr
import numpy as np
import torch
from easydict import EasyDict as edict
from huggingface_hub import hf_hub_download
from matplotlib import animation
from matplotlib import pyplot as plt
from PIL import Image
|
|
|
sys.path.append("./rome/") |
|
sys.path.append('./DECA') |
|
|
|
from rome.infer import Infer |
|
from rome.src.utils.processing import process_black_shape, tensor2image |
|
from rome.src.utils.visuals import mask_errosion |
|
|
|
|
|
# Fetch the two checkpoints from the 'Pie31415/rome' Hub repository; the
# returned values are stored in `args` as `modnet_path` / `model_checkpoint`.
default_modnet_path = hf_hub_download(
    repo_id='Pie31415/rome',
    filename='modnet_photographic_portrait_matting.ckpt',
)
default_model_path = hf_hub_download(
    repo_id='Pie31415/rome',
    filename='rome.pth',
)
|
|
|
|
|
# Configuration object handed to `Infer(args)` further down.  Keys, values and
# their order are kept exactly as before; the comments below only group the
# keys by their name prefix.
args = edict(dict(
    # --- general / checkpoints ---
    save_dir=".",
    save_render=True,
    model_checkpoint=default_model_path,
    modnet_path=default_modnet_path,
    random_seed=0,
    debug=False,
    verbose=False,
    model_image_size=256,
    # --- align_* ---
    align_source=True,
    align_target=False,
    align_scale=1.25,
    # --- mesh / renderer_* ---
    use_mesh_deformations=False,
    subdivide_mesh=False,
    renderer_sigma=1e-08,
    renderer_zfar=100.0,
    renderer_type="soft_mesh",
    renderer_texture_type="texture_uv",
    renderer_normalized_alphas=False,
    # --- data locations ---
    deca_path="DECA",
    rome_data_dir="rome/data",
    # --- autoenc_* ---
    autoenc_cat_alphas=False,
    autoenc_align_inputs=False,
    autoenc_use_warp=False,
    autoenc_num_channels=64,
    autoenc_max_channels=512,
    autoenc_num_groups=4,
    autoenc_num_bottleneck_groups=0,
    autoenc_num_blocks=2,
    autoenc_num_layers=4,
    autoenc_block_type="bottleneck",
    neural_texture_channels=8,
    num_harmonic_encoding_funcs=6,
    # --- unet_* ---
    unet_num_channels=64,
    unet_max_channels=512,
    unet_num_groups=4,
    unet_num_blocks=1,
    unet_num_layers=2,
    unet_block_type="conv",
    unet_skip_connection_type="cat",
    unet_use_normals_cond=True,
    unet_use_vertex_cond=False,
    unet_use_uvs_cond=False,
    unet_pred_mask=False,
    use_separate_seg_unet=True,
    # --- layer types ---
    norm_layer_type="gn",
    activation_type="relu",
    conv_layer_type="ws_conv",
    deform_norm_layer_type="gn",
    deform_activation_type="relu",
    deform_conv_layer_type="ws_conv",
    # --- segmentation / deformation switches ---
    unet_seg_weight=0.0,
    unet_seg_type="bce_with_logits",
    deform_face_tightness=0.0001,
    use_whole_segmentation=False,
    mask_hair_for_neck=False,
    use_hair_from_avatar=False,
    use_scalp_deforms=True,
    use_neck_deforms=True,
    use_basis_deformer=False,
    use_unet_deformer=True,
    # --- basis ---
    pretrained_encoder_basis_path="",
    pretrained_vertex_basis_path="",
    num_basis=50,
    basis_init="pca",
    num_vertex=5023,
    train_basis=True,
    # --- auxiliary model paths ---
    path_to_deca="DECA",
    path_to_linear_hair_model="data/linear_hair.pth",
    path_to_mobile_model="data/disp_model.pth",
    n_scalp=60,
    use_distill=False,
    use_mobile_version=False,
    deformer_path="data/rome.pth",
    output_unet_deformer_feats=32,
    use_deca_details=False,
    use_flametex=False,
    upsample_type="nearest",
    num_frequencies=6,
    deform_face_scale_coef=0.0,
    # NOTE(review): hard-coded; presumably requires a CUDA-capable host.
    device="cuda",
))
|
|
|
|
|
# Download the two DECA assets from the Hub and stage them under ./DECA/data.
generic_model_path = hf_hub_download('Pie31415/rome', 'generic_model.pkl')
deca_model_path = hf_hub_download('Pie31415/rome', 'deca_model.tar')

# Make sure the staging directory exists before writing into it.
os.makedirs('./DECA/data', exist_ok=True)

# The pickle is read with encoding='latin1' and written back out with the
# running interpreter's pickler (presumably to convert a Python 2 pickle;
# confirm against the DECA loader).
# NOTE(security): pickle.load can execute arbitrary code. This is only
# acceptable because the file comes from a repository the app trusts.
with open(generic_model_path, 'rb') as f:
    ss = pickle.load(f, encoding='latin1')

with open('./DECA/data/generic_model.pkl', 'wb') as out:
    pickle.dump(ss, out)

# Byte-for-byte copy of the tarball. shutil.copyfile replaces the former
# "iterate binary lines" loop, which also shadowed the builtin `input`.
shutil.copyfile(deca_model_path, './DECA/data/deca_model.tar')
|
|
|
|
|
# Build the inference object once at import time from `args`; both
# `image_inference` and `video_inference` call `infer.evaluate` on it.
infer = Infer(args)
|
|
|
def image_inference(
    source_img: "Image.Image" = None,
    driver_img: "Image.Image" = None
):
    """Run one source/driver pair through ROME and return a comparison strip.

    The annotations are lazy strings on purpose: the UI passes PIL images
    (`gr.Image(type="pil")`), not Gradio components, and a string annotation
    is not evaluated when the function is defined.

    Args:
        source_img: Image the avatar is built from.
        driver_img: Image passed to `infer.evaluate` as the driver.

    Returns:
        The output of `tensor2image` for the four panels concatenated along
        dim 2 (source, target, masked render, predicted shape image), with
        the last axis reversed.
    """
    out = infer.evaluate(source_img, driver_img, crop_center=False)
    data_dict = out['source_information']['data_dict']

    panels = [
        data_dict['source_img'][0].cpu(),
        data_dict['target_img'][0].cpu(),
        out['render_masked'].cpu(),
        out['pred_target_shape_img'][0].cpu(),
    ]
    res = tensor2image(torch.cat(panels, dim=2))

    # Reverse the last (channel) axis, presumably BGR -> RGB for display.
    return res[..., ::-1]
|
|
|
def extract_frames(driver_vid):
    """Decode every frame of a video into a list of RGB PIL images.

    Args:
        driver_vid: Video source handed to `cv2.VideoCapture`.

    Returns:
        list of `PIL.Image.Image`, one per decoded frame, in read order.
        The list is empty when no frame could be read.
    """
    image_frames = []
    vid = cv2.VideoCapture(driver_vid)

    try:
        while True:
            success, img = vid.read()
            if not success:
                break

            # Frames are converted BGR -> RGB before wrapping them for PIL.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            image_frames.append(Image.fromarray(img))
    finally:
        # Always free the capture handle, even if decoding raises.
        vid.release()

    return image_frames
|
|
|
def video_inference(source_img, driver_vid):
    """Animate the avatar of `source_img` with frames taken from `driver_vid`.

    Frames come from `extract_frames(driver_vid)`. The source is processed on
    the first evaluated frame, and its `source_information` is reused for the
    remaining frames.

    Args:
        source_img: Source portrait forwarded to `infer.evaluate` (the UI
            supplies a PIL image).
        driver_vid: Driver video, passed unchanged to `extract_frames`.

    Returns:
        A `matplotlib.animation.FuncAnimation` that cycles through the
        rendered frames.

    Raises:
        ValueError: If no frame could be decoded from `driver_vid`.
    """
    image_frames = extract_frames(driver_vid)
    if not image_frames:
        raise ValueError("No frames could be read from the driver video.")

    resulted_imgs = defaultdict(list)
    mask_hard_threshold = 0.5

    # Sub-sampling kept from the original code: only the first 1/20 of the
    # frames is considered, and every 4th of those is rendered.  max(1, ...)
    # guarantees at least one frame for clips shorter than 20 frames.
    # NOTE(review): confirm whether such aggressive sub-sampling is intended.
    frame_limit = max(1, len(image_frames) // 20)

    source_information = None
    for i in range(0, frame_limit, 4):
        if source_information is None:
            # First frame: let the model process the source image itself.
            new_out = infer.evaluate(source_img, image_frames[i], crop_center=False)
            source_information = new_out.get('source_information')
        else:
            new_out = infer.evaluate(source_img, image_frames[i],
                                     source_information_for_reuse=source_information)

        # Binarise the predicted mask, then post-process it with mask_errosion.
        mask_pred = (new_out['pred_target_unet_mask'].cpu() > mask_hard_threshold).float()
        mask_pred = mask_errosion(mask_pred[0].float().numpy() * 255)
        # Keep the render inside the mask and fill the outside with 1.
        render = new_out['pred_target_img'].cpu() * (mask_pred) + (1 - mask_pred)

        normals = process_black_shape(((new_out['pred_target_normal'][0].cpu() + 1) / 2 * mask_pred + (1 - mask_pred)))
        normals[normals == 0.5] = 1.

        resulted_imgs['res_normal'].append(tensor2image(normals))
        resulted_imgs['res_mesh_images'].append(tensor2image(new_out['pred_target_shape_img'][0]))
        resulted_imgs['res_renders'].append(tensor2image(render[0]))

    video = np.array(resulted_imgs['res_renders'])

    fig = plt.figure()
    # The last axis is reversed for display (presumably BGR -> RGB).
    im = plt.imshow(video[0, :, :, ::-1])
    plt.axis('off')
    plt.close()

    def init():
        im.set_data(video[0, :, :, ::-1])

    def animate(i):
        im.set_data(video[i, :, :, ::-1])
        return im

    anim = animation.FuncAnimation(fig, animate, init_func=init,
                                   frames=video.shape[0], interval=30)

    # NOTE(review): the UI wires this return value to a `gr.Image` output;
    # confirm that component can display a FuncAnimation, or export the
    # animation to a file and return that instead.
    return anim
|
|
|
# Gradio front-end: one tab for image-driven and one for video-driven
# inference, plus cached examples for the image tab.
with gr.Blocks() as demo:
    gr.Markdown(value="# **<p align='center'>ROME: Realistic one-shot mesh-based head avatars</p>**")

    gr.Markdown(
        value="""
        <p style='text-align: center'>
        Create a personal avatar from just a single image using ROME.
        <br> <a href='https://arxiv.org/abs/2206.08343' target='_blank'>Paper</a> | <a href='https://samsunglabs.github.io/rome' target='_blank'>Project Page</a> | <a href='https://github.com/SamsungLabs/rome' target='_blank'>Github</a>
        </p>
        """
    )

    # --- Tab 1: single driver image ---
    with gr.Tab(label="Image Inference"):
        with gr.Row():
            source_img = gr.Image(label="source image", type="pil", show_label=True)
            driver_img = gr.Image(label="driver image", type="pil", show_label=True)
        image_output = gr.Image()
        image_button = gr.Button(value="Predict")

    # --- Tab 2: driver video ---
    with gr.Tab(label="Video Inference"):
        with gr.Row():
            source_img2 = gr.Image(label="source image", type="pil", show_label=True)
            driver_vid = gr.Video(label="driver video")
        video_output = gr.Image()
        video_button = gr.Button(value="Predict")

    # Source/driver pairs offered as clickable examples for the image tab.
    example_pairs = [
        ["./examples/lincoln.jpg", "./examples/taras2.jpg"],
        ["./examples/lincoln.jpg", "./examples/taras1.jpg"],
    ]
    gr.Examples(
        examples=example_pairs,
        inputs=[source_img, driver_img],
        outputs=[image_output],
        fn=image_inference,
        cache_examples=True,
    )

    # Wire each "Predict" button to its inference callback.
    image_button.click(
        fn=image_inference,
        inputs=[source_img, driver_img],
        outputs=image_output,
    )
    video_button.click(
        fn=video_inference,
        inputs=[source_img2, driver_vid],
        outputs=video_output,
    )

demo.launch()