|
import gc
import os
from enum import Enum
from typing import List

import einops
import jax
import jax.numpy as jnp
import numpy as np
import torch  # noqa: F401 -- torch must be installed for the from_pt conversions below

from flax import jax_utils

from transformers import CLIPTokenizer, CLIPFeatureExtractor, FlaxCLIPTextModel
from diffusers import (
    FlaxAutoencoderKL,
    FlaxControlNetModel,
    FlaxDDIMScheduler,
    FlaxUNet2DConditionModel as VanillaFlaxUNet2DConditionModel,
)

from text_to_animation.models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from text_to_animation.pipelines.text_to_video_pipeline_flax import (
    FlaxTextToVideoPipeline,
)

import utils.utils as utils
import utils.gradio_utils as gradio_utils

on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"


def unshard(x):
    # Collapse the leading device axis back into the batch axis:
    # (devices, batch, ...) -> (devices * batch, ...).
    return einops.rearrange(x, "d b ... -> (d b) ...")
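
# Hedged usage sketch (shapes here are illustrative, not from the original code):
#
#   x = jnp.zeros((jax.device_count(), 8, 64, 64))  # per-device batches
#   unshard(x).shape                                # -> (device_count * 8, 64, 64)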
|
|
|
|
|
class ModelType(Enum):
    Text2Video = 1
    ControlNetPose = 2
    StableDiffusion = 3


def replicate_devices(array):
    # Add a leading device axis and repeat the array once per local device,
    # giving the (devices, ...) layout that pmapped pipeline calls expect.
    return jnp.expand_dims(array, 0).repeat(jax.device_count(), 0)
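
# Hedged usage sketch (values are illustrative): every device receives the
# same copy, so on a 4-device host a (77,) array becomes (4, 77).
#
#   tokens = jnp.zeros((77,), dtype=jnp.int32)
#   replicate_devices(tokens).shape  # -> (jax.device_count(), 77)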
|
|
|
|
|
class ControlAnimationModel:
    def __init__(self, dtype, **kwargs):
        self.dtype = dtype
        self.rng = jax.random.PRNGKey(0)
        self.pipe = None
        self.model_type = None

        self.states = {}
        self.model_name = ""

    def set_model(
        self,
        model_id: str,
        **kwargs,
    ):
        # Release any previously loaded pipeline before loading a new one.
        if self.pipe is not None:
            del self.pipe
            self.pipe = None
            gc.collect()

        # The OpenPose ControlNet is always loaded in float16; the remaining
        # modules follow the dtype passed to the constructor.
        controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
            "fusing/stable-diffusion-v1-5-controlnet-openpose",
            from_pt=True,
            dtype=jnp.float16,
        )

        scheduler, scheduler_state = FlaxDDIMScheduler.from_pretrained(
            model_id, subfolder="scheduler", from_pt=True
        )
        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
        feature_extractor = CLIPFeatureExtractor.from_pretrained(
            model_id, subfolder="feature_extractor"
        )
        unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
            model_id, subfolder="unet", from_pt=True, dtype=self.dtype
        )
        unet_vanilla = VanillaFlaxUNet2DConditionModel.from_config(
            model_id, subfolder="unet", from_pt=True, dtype=self.dtype
        )
        vae, vae_params = FlaxAutoencoderKL.from_pretrained(
            model_id, subfolder="vae", from_pt=True, dtype=self.dtype
        )
        text_encoder = FlaxCLIPTextModel.from_pretrained(
            model_id, subfolder="text_encoder", from_pt=True, dtype=self.dtype
        )
        self.pipe = FlaxTextToVideoPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            unet_vanilla=unet_vanilla,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=feature_extractor,
        )
        self.params = {
            "unet": unet_params,
            "vae": vae_params,
            "scheduler": scheduler_state,
            "controlnet": controlnet_params,
            "text_encoder": text_encoder.params,
        }
        # Replicate all weights across local devices for pmapped inference.
        self.p_params = jax_utils.replicate(self.params)
        self.model_name = model_id

    def generate_initial_frames(
        self,
        prompt: str,
        video_path: str,
        n_prompt: str = "",
        num_imgs: int = 4,
        resolution: int = 512,
        model_id: str = "runwayml/stable-diffusion-v1-5",
    ) -> List[np.ndarray]:
        self.set_model(model_id=model_id)

        video_path = gradio_utils.motion_to_video_path(video_path)

        added_prompt = "high quality, best quality, HD, clay stop-motion, claymation, HQ, masterpiece, art, smooth"
        prompts = added_prompt + ", " + prompt

        added_n_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly"
        negative_prompts = added_n_prompt + ", " + n_prompt

        video, fps = utils.prepare_video(
            video_path, resolution, None, self.dtype, False, output_fps=4
        )
        control = utils.pre_process_pose(video, apply_pose_detect=False)

        # Draw an independent seed (and PRNG key) for each candidate frame.
        seeds = [int(seed) for seed in jax.random.randint(self.rng, [num_imgs], 0, 65536)]
        prngs = [jax.random.PRNGKey(seed) for seed in seeds]
        print(f"Starting-frame seeds: {seeds}")
        images = self.pipe.generate_starting_frames(
            params=self.p_params,
            prngs=prngs,
            controlnet_image=control,
            prompt=prompts,
            neg_prompt=negative_prompts,
        )

        return [np.array(images[i]) for i in range(images.shape[0])]
    def generate_video_from_frame(self, controlnet_video, prompt, seed, neg_prompt=""):
        prng_seed = jax.random.PRNGKey(seed)
        len_vid = controlnet_video.shape[0]

        added_prompt = "high quality, best quality, HD, clay stop-motion, claymation, HQ, masterpiece, art, smooth"
        prompts = added_prompt + ", " + prompt

        added_n_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly"
        negative_prompts = added_n_prompt + ", " + neg_prompt

        # Tokenize one prompt per frame, then replicate every input across
        # devices for the pmapped (jit=True) pipeline call.
        prompt_ids = self.pipe.prepare_text_inputs([prompts] * len_vid)
        n_prompt_ids = self.pipe.prepare_text_inputs([negative_prompts] * len_vid)
        prng = replicate_devices(prng_seed)
        image = replicate_devices(controlnet_video)
        prompt_ids = replicate_devices(prompt_ids)
        n_prompt_ids = replicate_devices(n_prompt_ids)
        motion_field_strength_x = replicate_devices(jnp.array(3))
        motion_field_strength_y = replicate_devices(jnp.array(4))
        smooth_bg_strength = replicate_devices(jnp.array(0.8))
        vid = self.pipe(
            image=image,
            prompt_ids=prompt_ids,
            neg_prompt_ids=n_prompt_ids,
            params=self.p_params,
            prng_seed=prng,
            jit=True,
            smooth_bg_strength=smooth_bg_strength,
            motion_field_strength_x=motion_field_strength_x,
            motion_field_strength_y=motion_field_strength_y,
        ).images[0]
        return utils.create_gif(np.array(vid), 4, path=None, watermark=None)
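

# Hedged end-to-end sketch, guarded by __main__ so it never runs on import.
# The dtype, prompt, and motion label below are illustrative assumptions,
# not values taken from the original project.
if __name__ == "__main__":
    model = ControlAnimationModel(dtype=jnp.bfloat16)  # assumed dtype choice
    frames = model.generate_initial_frames(
        prompt="a clay astronaut dancing",  # hypothetical prompt
        video_path="Motion 1",  # hypothetical motion label resolved by gradio_utils
    )
    print(f"Generated {len(frames)} candidate starting frames")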
|
|