# TODO save & restart from (if it exists) dataframe parquet # TODO is spaces GPU blocking? Seems significantly slower while it's generating new in background import torch # lol DEVICE = 'cuda' STEPS = 6 output_hidden_state = False device = "cuda" dtype = torch.bfloat16 import matplotlib.pyplot as plt import matplotlib import logging from sklearn.linear_model import Ridge import os import imageio import gradio as gr import numpy as np from sklearn.svm import SVC from sklearn.inspection import permutation_importance from sklearn import preprocessing import pandas as pd from apscheduler.schedulers.background import BackgroundScheduler import random import time from PIL import Image from safety_checker_improved import maybe_nsfw torch.set_grad_enabled(False) torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate']) import spaces prompt_list = [p for p in list(set( pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str] start_time = time.time() ####################### Setup Model from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL from transformers import CLIPTextModel from huggingface_hub import hf_hub_download from safetensors.torch import load_file from PIL import Image from transformers import CLIPVisionModelWithProjection import uuid import av def write_video(file_name, images, fps=17): print('Saving') container = av.open(file_name, mode="w") stream = container.add_stream("h264", rate=fps) # stream.options = {'preset': 'faster'} stream.thread_count = 0 stream.width = 512 stream.height = 512 stream.pix_fmt = "yuv420p" for img in images: img = np.array(img) img = np.round(img).astype(np.uint8) frame = av.VideoFrame.from_ndarray(img, format="rgb24") for packet in stream.encode(frame): container.mux(packet) # Flush stream for packet in stream.encode(): container.mux(packet) # Close the file container.close() print('Saved') image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=dtype).to(DEVICE) #vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=dtype) # vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype) # vae = compile_unet(vae, config=config) #finetune_path = '''/home/ryn_mote/Misc/finetune-sd1.5/dreambooth-model best''''' #unet = UNet2DConditionModel.from_pretrained(finetune_path+'/unet/').to(dtype) #text_encoder = CLIPTextModel.from_pretrained(finetune_path+'/text_encoder/').to(dtype) unet = UNet2DConditionModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='unet').to(dtype) text_encoder = CLIPTextModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='text_encoder').to(dtype) adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM") pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype, unet=unet, text_encoder=text_encoder) pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear") pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",) pipe.set_adapters(["lcm-lora"], [.9]) pipe.fuse_lora() #pipe = AnimateDiffPipeline.from_pretrained('emilianJR/epiCRealism', torch_dtype=dtype, image_encoder=image_encoder) #pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear") #repo = "ByteDance/AnimateDiff-Lightning" #ckpt = f"animatediff_lightning_4step_diffusers.safetensors" pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15_vit-G.bin", map_location='cpu') # This IP adapter improves outputs substantially. pipe.set_ip_adapter_scale(.8) pipe.unet.fuse_qkv_projections() #pipe.enable_free_init(method="gaussian", use_fast_sampling=True) pipe.to(device=DEVICE) #pipe.unet = torch.compile(pipe.unet) #pipe.vae = torch.compile(pipe.vae) #im_embs = torch.zeros(1, 1, 1, 1280, device=DEVICE, dtype=dtype) #output = pipe(prompt='a person', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[im_embs], num_inference_steps=STEPS) #leave_im_emb, _ = pipe.encode_image( # output.frames[0][len(output.frames[0])//2], DEVICE, 1, output_hidden_state #) #assert len(output.frames[0]) == 16 #leave_im_emb.detach().to('cpu') @spaces.GPU(duration=20) def generate_gpu(in_im_embs): print('start gen') in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0) #im_embs = torch.cat((torch.zeros(1, 1280, device=DEVICE, dtype=dtype), in_im_embs), 0) output = pipe(prompt='a scene', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS) print('image is made') im_emb, _ = pipe.encode_image( output.frames[0][len(output.frames[0])//2], 'cuda', 1, output_hidden_state ) print('im_emb is made') im_emb = im_emb.detach().to('cpu').to(torch.float32) return output, im_emb def generate(in_im_embs): output, im_emb = generate_gpu(in_im_embs) nsfw = maybe_nsfw(output.frames[0][len(output.frames[0])//2]) name = str(uuid.uuid4()).replace("-", "") path = f"/tmp/{name}.mp4" if nsfw: gr.Warning("NSFW content detected.") # TODO could return an automatic dislike of auto dislike on the backend for neither as well; just would need refactoring. return None, im_emb output.frames[0] = output.frames[0] + list(reversed(output.frames[0])) write_video(path, output.frames[0]) return path, im_emb ####################### # TODO add to state instead of shared across all glob_idx = 0 # TODO # We can keep a df of media paths, embeddings, and user ratings. # We can drop by lowest user ratings to keep enough RAM available when we get too many rows. # We can continuously update by who is most recently active in the background & server as we go, plucking using "has been seen" and similarity # to user embeds def get_user_emb(embs, ys): # handle case where every instance of calibration videos is 'Neither' or 'Like' or 'Dislike' if len(list(set(ys))) <= 1: embs.append(.01*torch.randn(1280)) embs.append(.01*torch.randn(1280)) ys.append(0) ys.append(1) print('Fixing only one feedback class available.\n') indices = list(range(len(embs))) # sample only as many negatives as there are positives pos_indices = [i for i in indices if ys[i] == 1] neg_indices = [i for i in indices if ys[i] == 0] #lower = min(len(pos_indices), len(neg_indices)) #neg_indices = random.sample(neg_indices, lower) #pos_indices = random.sample(pos_indices, lower) print(len(neg_indices), len(pos_indices)) # we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749); # this ends up adding a rating but losing an embedding, it seems. # let's take off a rating if so to continue without indexing errors. if len(ys) > len(embs): print('ys are longer than embs; popping latest rating') ys.pop(-1) feature_embs = np.array(torch.stack([embs[i].squeeze().to('cpu') for i in indices]).to('cpu')) #scaler = preprocessing.StandardScaler().fit(feature_embs) #feature_embs = scaler.transform(feature_embs) chosen_y = np.array([ys[i] for i in indices]) print('Gathering coefficients') #lin_class = Ridge(fit_intercept=False).fit(feature_embs, chosen_y) lin_class = SVC(max_iter=20, kernel='linear', C=.1, class_weight='balanced').fit(feature_embs, chosen_y) coef_ = torch.tensor(lin_class.coef_, dtype=torch.double).detach().to('cpu') coef_ = coef_ / coef_.abs().max() * 3 print('Gathered') w = 1# if len(embs) % 2 == 0 else 0 im_emb = w * coef_.to(dtype=dtype) return im_emb def pluck_img(user_id, user_emb): print(user_id, 'user_id') not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]] while len(not_rated_rows) == 0: not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]] time.sleep(.01) print(f'user {user_id} has rated all rows') # TODO optimize this lol best_sim = -100000 for i in not_rated_rows.iterrows(): # TODO sloppy .to but it is 3am. sim = torch.cosine_similarity(i[1]['embeddings'].detach().to('cpu'), user_emb.detach().to('cpu')) if sim > best_sim: best_sim = sim best_row = i[1] img = best_row['paths'] return img def background_next_image(): global prevs_df # only let it get N (maybe 3) ahead of the user #not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]] rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]] while len(rated_rows) < 4: # not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]] rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]] time.sleep(.01) print('all users have 8 or more rows left to rate') latest_user_id = rated_rows.iloc[-1]['latest_user_to_rate'] #rated_rows = prevs_df[[i[1]['user:rating'].get(latest_user_id, None) is not None for i in prevs_df.iterrows()]] #while len(rated_rows) < 4: # rated_rows = prevs_df[[i[1]['user:rating'].get(latest_user_id, None) is not None for i in prevs_df.iterrows()]] # time.sleep(.01) # print('latest user has < 4 rated_rows') print(latest_user_id) embs, ys = pluck_embs_ys(latest_user_id) user_emb = get_user_emb(embs, ys) img, embs = generate(user_emb) print(img) if img: tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate']) tmp_df['paths'] = [img] tmp_df['embeddings'] = [embs] tmp_df['user:rating'] = [{' ': ' '}] prevs_df = pd.concat((prevs_df, tmp_df)) # we can free up storage by deleting the image if len(prevs_df) > 50: oldest_path = prevs_df.iloc[6]['paths'] if os.path.isfile(oldest_path): os.remove(oldest_path) else: # If it fails, inform the user. print("Error: %s file not found" % oldest_path) # only keep 50 images & embeddings & ips, then remove oldest besides calibrating prevs_df = pd.concat((prevs_df.iloc[:6], prevs_df.iloc[7:])) def pluck_embs_ys(user_id): rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]] #not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) == None for i in prevs_df.iterrows()]] #while len(not_rated_rows) == 0: # not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) == None for i in prevs_df.iterrows()]] # rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]] # time.sleep(.01) # print('current user has 0 not_rated_rows') embs = rated_rows['embeddings'].to_list() ys = [i[user_id] for i in rated_rows['user:rating'].to_list()] print('embs', 'ys', embs, ys) return embs, ys def next_image(calibrate_prompts, user_id): global glob_idx glob_idx = glob_idx + 1 print(prevs_df) with torch.no_grad(): if len(calibrate_prompts) > 0: print('######### Calibrating with sample media #########') cal_video = calibrate_prompts.pop(0) image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0] return image, calibrate_prompts else: print('######### Roaming #########') embs, ys = pluck_embs_ys(user_id) user_emb = get_user_emb(embs, ys) image = pluck_img(user_id, user_emb) return image, calibrate_prompts def start(_, calibrate_prompts, user_id, request: gr.Request): image, calibrate_prompts = next_image(calibrate_prompts, user_id) return [ gr.Button(value='Like (L)', interactive=True), gr.Button(value='Neither (Space)', interactive=True, visible=False), gr.Button(value='Dislike (A)', interactive=True), gr.Button(value='Start', interactive=False), image, calibrate_prompts ] def choose(img, choice, calibrate_prompts, user_id, request: gr.Request): global prevs_df if choice == 'Like (L)': choice = 1 elif choice == 'Neither (Space)': img, calibrate_prompts = next_image(calibrate_prompts, user_id) return img, calibrate_prompts else: choice = 0 # if we detected NSFW, leave that area of latent space regardless of how they rated chosen. # TODO skip allowing rating & just continue if img == None: print('NSFW -- choice is disliked') choice = 0 row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()] prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id] img, calibrate_prompts = next_image(calibrate_prompts, user_id) return img, calibrate_prompts css = '''.gradio-container{max-width: 700px !important} #description{text-align: center} #description h1, #description h3{display: block} #description p{margin-top: 0} .fade-in-out {animation: fadeInOut 3s forwards} @keyframes fadeInOut { 0% { background: var(--bg-color); } 100% { background: var(--button-secondary-background-fill); } } ''' js_head = ''' ''' with gr.Blocks(css=css, head=js_head) as demo: gr.Markdown('''# Blue Tigers ### Generative Recommenders for Exporation of Video Explore the latent space without text prompts based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/). ''', elem_id="description") user_id = gr.State(int(str(time.time())[-7:].replace('.', ''))) print('USER_ID: ',user_id) calibrate_prompts = gr.State([ './first.mp4', './second.mp4', './third.mp4', './fourth.mp4', './fifth.mp4', './sixth.mp4', ]) def l(): return None with gr.Row(elem_id='output-image'): img = gr.Video( label='Lightning', autoplay=True, interactive=False, height=512, width=512, include_audio=False, elem_id="video_output" ) img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''') with gr.Row(equal_height=True): b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike") b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither", visible=False) b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like") b1.click( choose, [img, b1, calibrate_prompts, user_id], [img, calibrate_prompts], ) b2.click( choose, [img, b2, calibrate_prompts, user_id], [img, calibrate_prompts], ) b3.click( choose, [img, b3, calibrate_prompts, user_id], [img, calibrate_prompts], ) with gr.Row(): b4 = gr.Button(value='Start') b4.click(start, [b4, calibrate_prompts, user_id], [b1, b2, b3, b4, img, calibrate_prompts] ) with gr.Row(): html = gr.HTML('''
You will calibrate for several videos and then roam.


Note that while the AnimateLCM model with NSFW filtering is unlikely to produce NSFW images, this may still occur, and users should avoid NSFW content when rating.

Thanks to @multimodalart for their contributions to the demo, esp. the interface and @maxbittker for feedback. ''') # TODO quite logging log = logging.getLogger('log_here') log.setLevel(logging.ERROR) scheduler = BackgroundScheduler() scheduler.add_job(func=background_next_image, trigger="interval", seconds=4) scheduler.start() def encode_space(x): im_emb, _ = pipe.encode_image( image, 'cpu', 1, output_hidden_state ) return im_emb.detach().to('cpu').to(torch.float32) # prep our calibration prompts for im in [ './first.mp4', './second.mp4', './third.mp4', './fourth.mp4', './fifth.mp4', './sixth.mp4', ]: tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating']) tmp_df['paths'] = [im] image = list(imageio.imiter(im)) image = image[len(image)//2] im_emb = encode_space(image) tmp_df['embeddings'] = [im_emb.detach().to('cpu')] tmp_df['user:rating'] = [{' ': ' '}] prevs_df = pd.concat((prevs_df, tmp_df)) demo.launch(share=True)