generative_recsys / lightning_app.py
rynmurdock's picture
sig
67e8481
raw
history blame
16.9 kB
import torch
# lol
# Side length in pixels; used for the encoded video stream and the video widget.
sidel = 512
# Device the image encoder and pipeline are moved to.
DEVICE = 'cuda'
# Passed to the pipeline as num_inference_steps.
STEPS = 4
# Forwarded as the fourth positional argument of pipe.encode_image.
output_hidden_state = False
# NOTE(review): duplicates DEVICE and is not referenced in the visible code.
device = "cuda"
# dtype used when loading the models and when building embedding tensors.
dtype = torch.float16
import matplotlib.pyplot as plt
import matplotlib
# Select the TkAgg backend for matplotlib.
# NOTE(review): the visible code only calls plt.hist/plt.savefig/plt.close, never
# plt.show; a non-interactive backend such as 'Agg' may be a better fit when this
# runs on a headless server -- confirm before changing.
matplotlib.use('TkAgg')
from sklearn.linear_model import LinearRegression
from sfast.compilers.diffusion_pipeline_compiler import (compile, compile_unet,
CompilationConfig)
# Start from the compiler's default settings and adjust them below.
config = CompilationConfig.Default()

# Triton is optional: enable it only when the package can be imported.
try:
    import triton
except ImportError:
    print('Triton not installed, skip')
else:
    config.enable_triton = True

# Switches applied regardless of whether Triton is available.
config.enable_cuda_graph = True
config.enable_jit = True
config.enable_jit_freeze = True
config.enable_cnn_optimization = True
config.preserve_parameters = False
config.prefer_lowp_gemm = True
import imageio
import gradio as gr
import numpy as np
from sklearn.svm import SVC
from sklearn.inspection import permutation_importance
from sklearn import preprocessing
import pandas as pd
import random
import time
from PIL import Image
from safety_checker_improved import maybe_nsfw
# Inference only: turn autograd off globally.
torch.set_grad_enabled(False)
# Allow TF32 for CUDA matmuls and for cuDNN.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# TODO put back?
# import spaces
# Unique prompts taken from the second column of the CSV. The isinstance()
# filter keeps only string cells (non-string values, e.g. missing cells, are
# skipped). Order is whatever the intermediate set yields.
prompt_list = [
    p
    for p in set(pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())
    if isinstance(p, str)
]
# Wall-clock timestamp taken when the module is loaded.
start_time = time.time()
####################### Setup Model
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, ConsistencyDecoderVAE, AutoencoderTiny
from hyper_tile import split_attention, flush
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image
from transformers import CLIPVisionModelWithProjection
import uuid
import av
def write_video(file_name, images, fps=10):
    """Encode a sequence of RGB frames into an H.264 video file.

    Args:
        file_name: Destination path handed to ``av.open`` in write mode.
        images: Iterable of frames convertible with ``np.array``; each frame
            is rounded, cast to ``uint8`` and interpreted as ``rgb24``.
        fps: Frame rate of the output stream.

    The stream is always ``sidel`` x ``sidel`` pixels in ``yuv420p``.
    """
    print('Saving')
    container = av.open(file_name, mode="w")
    try:
        stream = container.add_stream("h264", rate=fps)
        stream.width = sidel
        stream.height = sidel
        stream.pix_fmt = "yuv420p"

        for img in images:
            pixels = np.round(np.array(img)).astype(np.uint8)
            frame = av.VideoFrame.from_ndarray(pixels, format="rgb24")
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush stream
        for packet in stream.encode():
            container.mux(packet)
    finally:
        # Close the file even when encoding fails, so the handle is not leaked.
        container.close()
    print('Saved')
# Mapping that is currently empty; the commented entries below look like
# alternative base checkpoints that were tried -- presumably kept for reference.
bases = {
#"basem": "emilianJR/epiCRealism"
#SG161222/Realistic_Vision_V6.0_B1_noVAE
#runwayml/stable-diffusion-v1-5
#frankjoshua/realisticVisionV51_v51VAE
#Lykon/dreamshaper-7
}
# CLIP vision encoder from the IP-Adapter repo, loaded in `dtype` and moved to DEVICE.
image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=dtype).to(DEVICE)
# Tiny autoencoder (TAESD) handed to the pipeline as its VAE.
vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=dtype)
# vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
# vae = compile_unet(vae, config=config)
#adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
#pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype)
#pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
#pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
#pipe.set_adapters(["lcm-lora"], [1])
#pipe.fuse_lora()
# AnimateDiff pipeline on the epiCRealism checkpoint, using the encoder and VAE created above.
pipe = AnimateDiffPipeline.from_pretrained('emilianJR/epiCRealism', torch_dtype=dtype, image_encoder=image_encoder, vae=vae)
# Replace the scheduler with an Euler one built from the existing scheduler config
# (trailing timestep spacing, linear beta schedule).
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
# Download the 4-step AnimateDiff-Lightning checkpoint and load it into the UNet;
# strict=False tolerates keys that do not match.
repo = "ByteDance/AnimateDiff-Lightning"
ckpt = f"animatediff_lightning_4step_diffusers.safetensors"
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device='cpu'), strict=False)
# Load IP-Adapter weights and set the adapter scale to 0.8.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin", map_location='cpu')
pipe.set_ip_adapter_scale(.8)
# pipe.unet.fuse_qkv_projections()
#pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
# Wrap the pipeline with the stable-fast compiler using `config`, then move it to DEVICE.
pipe = compile(pipe, config=config)
pipe.to(device=DEVICE)
# THIS WOULD NEED PATCHING TODO
# Run the pipeline once at import time with an all-zero image embedding. The
# middle frame of that clip is embedded and kept as `leave_im_emb`, which
# next_image() later adds to the training set as an extra label-0 sample.
with split_attention(pipe.vae, tile_size=128, swap_size=2, disable=False, aspect_ratio=1):
    # ! Change the tile_size and disable to see their effects
    with split_attention(pipe.unet, tile_size=128, swap_size=2, disable=False, aspect_ratio=1):
        im_embs = torch.zeros(1, 1, 1, 1024, device=DEVICE, dtype=dtype)
        output = pipe(prompt='a person', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[im_embs], num_inference_steps=STEPS)
        leave_im_emb, _ = pipe.encode_image(
            output.frames[0][len(output.frames[0])//2], DEVICE, 1, output_hidden_state
        )
        # Sanity check on the number of frames the pipeline produced.
        assert len(output.frames[0]) == 16

# Tensor.to() returns a new tensor rather than moving in place, so the result
# has to be rebound for the embedding to actually live on the CPU.
leave_im_emb = leave_im_emb.to('cpu')
# TODO put back
# @spaces.GPU()
def generate(prompt, in_im_embs=None, base='basem'):
    """Generate one clip for *prompt* and embed its middle frame.

    Args:
        prompt: Text prompt passed to the pipeline.
        in_im_embs: Optional image-embedding tensor used as IP-Adapter
            conditioning; two leading singleton dimensions are added before
            it is passed on. When ``None`` an all-zero embedding is used.
        base: Unused; kept so existing callers keep working.

    Returns:
        ``(path, im_emb)``. ``path`` is the mp4 written under ``/tmp``, or
        ``None`` when ``maybe_nsfw`` flags the middle frame. ``im_emb`` is the
        embedding of the middle frame, always moved to the CPU.
    """
    if in_im_embs is None:
        in_im_embs = torch.zeros(1, 1, 1, 1024, device=DEVICE, dtype=dtype)
    else:
        in_im_embs = in_im_embs.to(DEVICE).unsqueeze(0).unsqueeze(0)

    with split_attention(pipe.unet, tile_size=128, swap_size=2, disable=False, aspect_ratio=1):
        # ! Change the tile_size and disable to see their effects
        with split_attention(pipe.vae, tile_size=128, disable=False, aspect_ratio=1):
            output = pipe(prompt=prompt, guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
            frames = output.frames[0]
            middle_frame = frames[len(frames) // 2]
            im_emb, _ = pipe.encode_image(
                middle_frame, DEVICE, 1, output_hidden_state
            )

    # Move to the CPU once, so both return paths hand back a CPU tensor and no
    # GPU tensor ends up stored in the caller's state.
    im_emb = im_emb.to('cpu')

    if maybe_nsfw(middle_frame):
        gr.Warning("NSFW content detected.")
        # TODO could return an automatic dislike of auto dislike on the backend for neither as well; just would need refactoring.
        return None, im_emb

    # Debug plot of the embedding values, overwritten on every call.
    plt.close('all')
    plt.hist(np.array(im_emb).flatten(), bins=5)
    plt.savefig('real_im_emb_plot.jpg')

    name = uuid.uuid4().hex
    path = f"/tmp/{name}.mp4"
    write_video(path, frames)
    return path, im_emb
#######################
# TODO add to state instead of shared across all
# Module-level call counter: next_image() increments it on every call and uses
# its parity to choose between the fixed prompt and a random one when roaming.
glob_idx = 0
def next_image(embs, ys, calibrate_prompts):
    """Generate the next clip, from a calibration prompt or by roaming.

    While ``calibrate_prompts`` is non-empty, its first prompt is popped and
    generated without image conditioning. Once it is empty, a linear SVM is
    fit on the rated embeddings and its (rescaled) coefficient vector is used
    as the image embedding for the next generation.

    Args:
        embs: List of embedding tensors of previously generated clips.
        ys: List of ratings (1 or 0) tied to ``embs`` by index.
        calibrate_prompts: Remaining calibration prompts; consumed in place.

    Returns:
        ``(image, embs, ys, calibrate_prompts)`` where ``image`` is whatever
        ``generate`` returned as its first element, and ``embs`` has the new
        clip's embedding appended.
    """
    global glob_idx
    glob_idx = glob_idx + 1

    with torch.no_grad():
        if len(calibrate_prompts) > 0:
            print('######### Calibrating with sample prompts #########')
            prompt = calibrate_prompts.pop(0)
            print(prompt)
            image, img_embs = generate(prompt)
            # `+=` extends the list with the items obtained by iterating the tensor.
            embs += img_embs
            print(len(embs))
            return image, embs, ys, calibrate_prompts

        print('######### Roaming #########')
        pos_indices = [i for i in range(len(embs)) if ys[i] == 1]
        neg_indices = [i for i in range(len(embs)) if ys[i] == 0]

        # the embs & ys stay tied by index but we shuffle to drop randomly
        random.shuffle(pos_indices)
        random.shuffle(neg_indices)

        # Trim whichever class is over-represented. The first test previously
        # compared neg - pos, which dropped a positive exactly when negatives
        # already dominated; it now mirrors the negative case below.
        if len(pos_indices) - len(neg_indices) > 48/16 and len(pos_indices) > 120/16:
            pos_indices = pos_indices[1:]
        if len(neg_indices) - len(pos_indices) > 48/16 and len(neg_indices) > 200/16:
            neg_indices = neg_indices[2:]

        print(len(pos_indices), len(neg_indices))

        # Rebuild embs/ys from the kept indices (positives first, then negatives).
        indices = pos_indices + neg_indices
        embs = [embs[i] for i in indices]
        ys = [ys[i] for i in indices]
        indices = list(range(len(embs)))

        # handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike'
        if len(set(ys)) <= 1:
            embs.append(.01*torch.randn(1024))
            embs.append(.01*torch.randn(1024))
            ys.append(0)
            ys.append(1)

        # also add the latest 0 and the latest 1
        has_0 = False
        has_1 = False
        for i in reversed(range(len(ys))):
            if ys[i] == 0 and not has_0:
                indices.append(i)
                has_0 = True
            elif ys[i] == 1 and not has_1:
                indices.append(i)
                has_1 = True
            if has_0 and has_1:
                break

        # we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
        # this ends up adding a rating but losing an embedding, it seems.
        # let's take off a rating if so to continue without indexing errors.
        if len(ys) > len(embs):
            print('ys are longer than embs; popping latest rating')
            ys.pop(-1)

        # Training matrix: the selected embeddings plus the start-up embedding
        # `leave_im_emb`, which is given label 0 below.
        feature_embs = np.array(torch.stack([embs[i].to('cpu') for i in indices] + [leave_im_emb[0].to('cpu')]).to('cpu'))
        scaler = preprocessing.StandardScaler().fit(feature_embs)
        feature_embs = scaler.transform(feature_embs)
        chosen_y = np.array([ys[i] for i in indices] + [0])

        print('Gathering coefficients')
        lin_class = SVC(max_iter=50000, kernel='linear', class_weight='balanced', C=1).fit(feature_embs, chosen_y)
        coef_ = torch.tensor(lin_class.coef_, dtype=torch.double)
        # Rescale so the largest absolute coefficient is 3.
        coef_ = coef_ / coef_.abs().max() * 3
        print(coef_.shape, 'COEF')

        # Debug plot of the coefficient values, overwritten on every call.
        plt.close('all')
        plt.hist(np.array(coef_).flatten(), bins=5)
        plt.savefig('plot.jpg')
        print(coef_)
        print('Gathered')

        rng_prompt = random.choice(prompt_list)
        im_emb = coef_.to(dtype=dtype)

        # Alternate between a fixed neutral prompt and a random one.
        prompt = 'the scene' if glob_idx % 2 == 0 else rng_prompt
        print(prompt)
        image, im_emb = generate(prompt, im_emb)
        embs += im_emb

        # Cap the history length by dropping the first entry.
        if len(embs) > 700/16:
            embs = embs[1:]
            ys = ys[1:]

        return image, embs, ys, calibrate_prompts
def start(_, embs, ys, calibrate_prompts):
    """Handle the Start click: produce the first clip and unlock rating.

    The first argument (the Start button's value) is ignored. Returns, in
    order, the three rating buttons enabled, the Start button disabled, the
    generated clip, and the updated ``embs``, ``ys`` and ``calibrate_prompts``.
    """
    first_clip, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)

    rating_buttons = [
        gr.Button(value=label, interactive=True)
        for label in ('Like (L)', 'Neither (Space)', 'Dislike (A)')
    ]
    start_button = gr.Button(value='Start', interactive=False)

    return [*rating_buttons, start_button, first_clip, embs, ys, calibrate_prompts]
def choose(img, choice, embs, ys, calibrate_prompts):
    """Record the rating for the current clip and fetch the next one.

    Args:
        img: Current value of the video component; ``None`` is treated as an
            NSFW-blocked clip.
        choice: Label of the clicked button.
        embs: List of embeddings; its last entry belongs to the current clip.
        ys: List of ratings; extended in place unless the clip is skipped.
        calibrate_prompts: Remaining calibration prompts.

    Returns:
        The ``(img, embs, ys, calibrate_prompts)`` tuple from ``next_image``.
    """
    if choice == 'Neither (Space)':
        # No rating is stored, so drop the current clip's embedding as well.
        embs = embs[:-1]
        return next_image(embs, ys, calibrate_prompts)

    rating = 1 if choice == 'Like (L)' else 0

    # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
    # TODO skip allowing rating
    if img is None:
        print('NSFW -- choice is disliked')
        rating = 0

    ys.append(rating)
    return next_image(embs, ys, calibrate_prompts)
# Stylesheet passed to gr.Blocks: caps the container width, centres the
# description, and defines the fade animation applied by the script below.
css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
#description h1, #description h3{display: block}
#description p{margin-top: 0}
.fade-in-out {animation: fadeInOut 3s forwards}
@keyframes fadeInOut {
0% {
background: var(--bg-color);
}
100% {
background: var(--button-secondary-background-fill);
}
}
'''
# Script passed to gr.Blocks as `head`: maps the A / Space / L keys to clicks on
# the elements with ids 'dislike' / 'neither' / 'like', and flashes the clicked
# button with a colour via the fade-in-out class.
js_head = '''
<script>
document.addEventListener('keydown', function(event) {
if (event.key === 'a' || event.key === 'A') {
// Trigger click on 'dislike' if 'A' is pressed
document.getElementById('dislike').click();
} else if (event.key === ' ' || event.keyCode === 32) {
// Trigger click on 'neither' if Spacebar is pressed
document.getElementById('neither').click();
} else if (event.key === 'l' || event.key === 'L') {
// Trigger click on 'like' if 'L' is pressed
document.getElementById('like').click();
}
});
function fadeInOut(button, color) {
button.style.setProperty('--bg-color', color);
button.classList.remove('fade-in-out');
void button.offsetWidth; // This line forces a repaint by accessing a DOM property
button.classList.add('fade-in-out');
button.addEventListener('animationend', () => {
button.classList.remove('fade-in-out'); // Reset the animation state
}, {once: true});
}
document.body.addEventListener('click', function(event) {
const target = event.target;
if (target.id === 'dislike') {
fadeInOut(target, '#ff1717');
} else if (target.id === 'like') {
fadeInOut(target, '#006500');
} else if (target.id === 'neither') {
fadeInOut(target, '#cccccc');
}
});
</script>
'''
# Build the Gradio UI: a looping video, three rating buttons, and a Start
# button, all wired to start()/choose() with state held in gr.State components.
with gr.Blocks(css=css, head=js_head) as demo:
    gr.Markdown('''### Blue Tigers: Generative Recommenders for Exploration of Video
Explore the latent space without text prompts based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
''', elem_id="description")

    # State passed into and returned from every handler.
    embs = gr.State([])
    ys = gr.State([])
    calibrate_prompts = gr.State([
        'the moon is melting into my glass of tea',
        'a sea slug -- pair of claws scuttling -- jelly fish glowing',
        'an adorable creature. It may be a goblin or a pig or a slug.',
        'an animation about a gorgeous nebula',
        'an octopus writhes',
    ])

    def l():
        """No-op Python handler; the play event only needs its `js` snippet."""
        return None

    with gr.Row(elem_id='output-image'):
        img = gr.Video(
            label='Lightning',
            autoplay=True,
            interactive=False,
            height=sidel,
            width=sidel,
            include_audio=False,
            elem_id="video_output"
        )
    # Turn on looping for the player element whenever playback starts.
    img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')

    with gr.Row(equal_height=True):
        b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
        b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")

    # Each rating button passes its own label to choose(); registration order
    # (like, neither, dislike) matches the original wiring.
    for rating_button in (b1, b2, b3):
        rating_button.click(
            choose,
            [img, rating_button, embs, ys, calibrate_prompts],
            [img, embs, ys, calibrate_prompts]
        )

    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
                 [b4, embs, ys, calibrate_prompts],
                 [b1, b2, b3, b4, img, embs, ys, calibrate_prompts])

    with gr.Row():
        # End tags are written as </div>; the earlier "</ div>" form is not a
        # valid end tag, so the divs were never actually closed.
        html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several prompts and then roam. </div><br><br><br>
<div style='text-align:center; font-size:14px'>Note that while the AnimateDiff-Lightning model with NSFW filtering is unlikely to produce NSFW images, this may still occur, and users should avoid NSFW content when rating.
</div>
<br><br>
<div style='text-align:center; font-size:14px'>Thanks to @multimodalart for their contributions to the demo, esp. the interface and @maxbittker for feedback.
</div>''')

demo.launch(share=True)