rynmurdock committed
Commit e573858
1 Parent(s): f94c06d

lfs and sync with blue-tigers github

Files changed (5)
  1. .gitattributes +20 -0
  2. app.py +78 -108
  3. lightning_app.py +0 -452
  4. requirements.txt +1 -3
  5. twitter_prompts.csv +0 -72
.gitattributes CHANGED
@@ -1 +1,21 @@
 nsfweffnetv2-b02-3epochs.h5 filter=lfs diff=lfs merge=lfs -text
+fifth.gemb_.pt filter=lfs diff=lfs merge=lfs -text
+ninth.im_.pt filter=lfs diff=lfs merge=lfs -text
+tenth.gemb_.pt filter=lfs diff=lfs merge=lfs -text
+third.gemb_.pt filter=lfs diff=lfs merge=lfs -text
+eigth.gemb_.pt filter=lfs diff=lfs merge=lfs -text
+first.gemb_.pt filter=lfs diff=lfs merge=lfs -text
+fourth.gemb_.pt filter=lfs diff=lfs merge=lfs -text
+ninth.gemb_.pt filter=lfs diff=lfs merge=lfs -text
+sixth.gemb_.pt filter=lfs diff=lfs merge=lfs -text
+tenth.im_.pt filter=lfs diff=lfs merge=lfs -text
+eigth.im_.pt filter=lfs diff=lfs merge=lfs -text
+seventh.gemb_.pt filter=lfs diff=lfs merge=lfs -text
+sixth.im_.pt filter=lfs diff=lfs merge=lfs -text
+third.im_.pt filter=lfs diff=lfs merge=lfs -text
+fifth.im_.pt filter=lfs diff=lfs merge=lfs -text
+first.im_.pt filter=lfs diff=lfs merge=lfs -text
+fourth.im_.pt filter=lfs diff=lfs merge=lfs -text
+second.gemb_.pt filter=lfs diff=lfs merge=lfs -text
+second.im_.pt filter=lfs diff=lfs merge=lfs -text
+seventh.im_.pt filter=lfs diff=lfs merge=lfs -text
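These LFS-tracked `*.im_.pt` / `*.gemb_.pt` tensors are the precomputed calibration-video embeddings that `app.py` now loads at startup (see the changes below). A quick, hypothetical sanity check that a clone actually fetched the LFS objects rather than pointer stubs (file name chosen as an example):

```python
import torch

# If the repository was cloned without `git lfs pull`, this path holds a small
# text pointer file instead of a serialized tensor, and torch.load will raise.
emb = torch.load('first.im_.pt', map_location='cpu')
print(type(emb), getattr(emb, 'shape', None))
```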
app.py CHANGED
@@ -10,12 +10,9 @@ STEPS = 6
 output_hidden_state = False
 device = "cuda"
 dtype = torch.bfloat16
+N_IMG_EMBS = 3
 
-import matplotlib.pyplot as plt
-import matplotlib
 import logging
-
-
 import os
 import imageio
 import gradio as gr
@@ -24,8 +21,6 @@ from sklearn.svm import SVC
 from sklearn import preprocessing
 import pandas as pd
 from apscheduler.schedulers.background import BackgroundScheduler
-import sched
-import threading
 
 import random
 import time
@@ -104,7 +99,7 @@ pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", mot
 unet=unet, text_encoder=text_encoder)
 pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
 pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
-pipe.set_adapters(["lcm-lora"], [.9])
+pipe.set_adapters(["lcm-lora"], [.95])
 pipe.fuse_lora()
 
 
@@ -121,6 +116,7 @@ pipe.unet.fuse_qkv_projections()
 #pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
 
 pipe.to(device=DEVICE)
+
 #pipe.unet = torch.compile(pipe.unet)
 #pipe.vae = torch.compile(pipe.vae)
 
@@ -130,9 +126,10 @@ pipe.to(device=DEVICE)
 from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig
 
 quantization_config = BitsAndBytesConfig(load_in_4bit=True)
-pali = PaliGemmaForConditionalGeneration.from_pretrained('google/paligemma-3b-pt-224', torch_dtype=dtype, device_map='cuda').eval()
+pali = PaliGemmaForConditionalGeneration.from_pretrained('google/paligemma-3b-pt-224', torch_dtype=dtype, quantization_config=quantization_config).eval()
 processor = AutoProcessor.from_pretrained('google/paligemma-3b-pt-224')
 
+#pali = torch.compile(pali)
 
 @spaces.GPU()
 def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None):
@@ -148,19 +145,34 @@ def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None
     return inputs_embeds
 
 
+# TODO cache descriptions?
 @spaces.GPU()
-def generate_pali(user_emb):
-    with torch.no_grad():
-        prompt = 'caption en'
-        model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
-        # we need to get im_embs taken in here.
-        input_len = model_inputs["input_ids"].shape[-1]
-        input_embeds = to_wanted_embs(user_emb.squeeze()[None, None, :].repeat(1, 256, 1),
-                                      model_inputs["input_ids"].to(device),
-                                      model_inputs["attention_mask"].to(device))
-
-        generation = pali.generate(max_new_tokens=100, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
-        decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
+def generate_pali(n_embs):
+    prompt = 'caption en'
+    model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
+    # we need to get im_embs taken in here.
+
+    descs = ''
+    for n, emb in enumerate(n_embs):
+        if n < len(n_embs)-1:
+            input_len = model_inputs["input_ids"].shape[-1]
+            input_embeds = to_wanted_embs(emb,
+                                          model_inputs["input_ids"].to(device),
+                                          model_inputs["attention_mask"].to(device))
+            generation = pali.generate(max_new_tokens=20, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
+            decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
+            descs += f'Description: {decoded}\n'
+        else:
+            prompt = f'en {descs} Describe a new image that is similar.'
+            print(prompt)
+            model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
+            input_len = model_inputs["input_ids"].shape[-1]
+            input_embeds = to_wanted_embs(emb,
+                                          model_inputs["input_ids"].to(device),
+                                          model_inputs["attention_mask"].to(device))
+            generation = pali.generate(max_new_tokens=20, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
+            decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
 
     return decoded
 
 
@@ -182,7 +194,7 @@ def generate_gpu(in_im_embs, prompt='the scene'):
     im = torchvision.transforms.ToTensor()(output.frames[0][len(output.frames[0])//2]).unsqueeze(0)
     im = torch.nn.functional.interpolate(im, (224, 224))
     im = (im - .5) * 2
-    gemb = pali.vision_tower(im.to(device).to(dtype)).last_hidden_state.detach().to('cpu').to(torch.float32).mean(1)
+    gemb = pali.vision_tower(im.to(device).to(dtype)).last_hidden_state.detach().to('cpu').to(torch.float32)
    return output, im_emb, gemb
 
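One way to read the change above: dropping `.mean(1)` keeps the PaliGemma vision tower's full per-patch token sequence instead of a single pooled vector, so the new captioning path can consume it directly. A minimal sketch of the assumed shapes (random stand-in tensors; 256 and 1152 are taken from the existing code):

```python
import torch

pooled = torch.randn(1, 1152)           # old behaviour: last_hidden_state.mean(1)
sequence = torch.randn(1, 256, 1152)    # new behaviour: full last_hidden_state

# The old generate_pali re-expanded the pooled vector to 256 tokens before
# splicing it into the prompt embeddings; the unpooled sequence already has that shape.
expanded = pooled.squeeze()[None, None, :].repeat(1, 256, 1)
print(expanded.shape, sequence.shape)   # both torch.Size([1, 256, 1152])
```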
 
@@ -210,10 +222,10 @@
 def get_user_emb(embs, ys):
     # handle case where every instance of calibration videos is 'Neither' or 'Like' or 'Dislike'
 
-    if len(list(ys)) <= 7:
-        aways = [.01*torch.randn_like(embs[0]) for i in range(3)]
+    if len(list(ys)) <= 10:
+        aways = [torch.zeros_like(embs[0]) for i in range(10)]
         embs += aways
-        awal = [0 for i in range(3)]
+        awal = [0 for i in range(5)] + [1 for i in range(5)]
         ys += awal
 
     indices = list(range(len(embs)))
@@ -241,9 +253,10 @@ def get_user_emb(embs, ys):
     feature_embs = feature_embs / feature_embs.norm()
 
     #lin_class = Ridge(fit_intercept=False).fit(feature_embs, chosen_y)
-    lin_class = SVC(max_iter=20, kernel='linear', C=.1, class_weight='balanced').fit(feature_embs.squeeze(), chosen_y)
+    #class_weight='balanced'
+    lin_class = SVC(max_iter=500, kernel='linear', C=.1, ).fit(feature_embs.squeeze(), chosen_y)
     coef_ = torch.tensor(lin_class.coef_, dtype=torch.float32).detach().to('cpu')
-    coef_ = coef_ / coef_.abs().max() * 3
+    coef_ = coef_ / coef_.abs().max()
 
     w = 1# if len(embs) % 2 == 0 else 0
     im_emb = w * coef_.to(dtype=dtype)
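For context on the classifier tweak above: the "user embedding" is essentially the weight vector of a linear SVM fit on liked vs. disliked embeddings, rescaled so its largest coefficient is 1. A minimal, self-contained sketch of that idea with made-up data (the real code also standardizes and norm-scales `feature_embs` first):

```python
import torch
from sklearn.svm import SVC

# Hypothetical rated IP-Adapter image embeddings and like/dislike labels.
embs = [torch.randn(1024) for _ in range(12)]
ys = [1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1]

X = torch.stack(embs).numpy()
clf = SVC(max_iter=500, kernel='linear', C=.1).fit(X, ys)  # linear kernel exposes coef_

# The separating hyperplane's normal points from the "dislike" class toward "like",
# so it can be reused as a preference direction for the next generation.
coef = torch.tensor(clf.coef_, dtype=torch.float32)
user_emb = coef / coef.abs().max()
print(user_emb.shape)  # torch.Size([1, 1024])
```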
@@ -273,7 +286,7 @@ def background_next_image():
    # only let it get N (maybe 3) ahead of the user
    #not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
    rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
-    while len(rated_rows) < 4:
+    while len(rated_rows) < 5:
        # not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
        rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
        time.sleep(.01)
@@ -290,25 +303,21 @@ def background_next_image():
        rated_from_user = rated_rows[[i[1]['from_user_id'] == uid for i in rated_rows.iterrows()]]
 
        # we pop previous ratings if there are > n
-        if len(rated_from_user) >= 15:
+        if len(rated_from_user) >= 25:
            oldest = rated_from_user.iloc[0]['paths']
            prevs_df = prevs_df[prevs_df['paths'] != oldest]
        # we don't compute more after n are in the queue for them
-        if len(unrated_from_user) >= 10:
-            continue
-
-        if len(rated_rows) < 5:
+        if len(unrated_from_user) >= 20:
            continue
 
        embs, ys, gembs = pluck_embs_ys(uid)
-
-        user_emb = get_user_emb(embs, ys)
-
-        if len(gembs) > 4:
-            user_gem = get_user_emb(gembs, ys) / 4 # TODO scale this correctly; matplotlib, etc.
-            text = generate_pali(user_gem)
+        user_emb = get_user_emb(embs, ys) * 3
+        pos_gembs = [g for g, y in zip(gembs, ys) if y == 1]
+        if len(pos_gembs) > 4:
+            hist_gem = random.sample(pos_gembs, N_IMG_EMBS) # rng n embeddings
+            text = generate_pali(hist_gem)
        else:
-            text = generate_pali(torch.zeros(1, 1152))
+            text = 'the scene'
        img, embs, new_gem = generate(user_emb, text)
 
        if img:
@@ -351,60 +360,16 @@ def next_image(calibrate_prompts, user_id):
    if len(calibrate_prompts) > 0:
        cal_video = calibrate_prompts.pop(0)
        image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
-
        return image, calibrate_prompts, ''
    else:
        embs, ys, gembs = pluck_embs_ys(user_id)
-        user_emb = get_user_emb(embs, ys)
+        user_emb = get_user_emb(embs, ys) * 3
        image, text = pluck_img(user_id, user_emb)
        return image, calibrate_prompts, text
 
 
 
-
-
-
-done_init = False
-
 def start(_, calibrate_prompts, user_id, request: gr.Request):
-    global done_init
-    global prevs_df
-
-    if not done_init:
-        # prep our calibration videos
-        for im in [
-            './first.mp4',
-            # './second.mp4',
-            # './third.mp4',
-            # './fourth.mp4',
-            # './fifth.mp4',
-            # './sixth.mp4',
-            # './seventh.mp4',
-            # './eigth.mp4',
-            # './ninth.mp4',
-            # './tenth.mp4',
-        ]:
-            tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb'])
-            tmp_df['paths'] = [im]
-            image = list(imageio.imiter(im))
-            image = image[len(image)//2]
-
-            im = torchvision.transforms.ToTensor()(image).unsqueeze(0)
-            im = torch.nn.functional.interpolate(im, (224, 224))
-            im = (im - .5) * 2
-
-            im_emb, gemb = encode_space(image, im)
-            im_emb = im_emb.to('cpu')
-            gemb = gemb.to('cpu')
-
-            tmp_df['embeddings'] = [im_emb]
-            tmp_df['gemb'] = [gemb]
-            tmp_df['user:rating'] = [{' ': ' '}]
-            prevs_df = pd.concat((prevs_df, tmp_df))
-        done_init = True
-
-
-
    user_id = int(str(time.time())[-7:].replace('.', ''))
    image, calibrate_prompts, text = next_image(calibrate_prompts, user_id)
    return [
@@ -436,6 +401,7 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
        print('NSFW -- choice is disliked')
        choice = 0
 
+    print(prevs_df['paths'].to_list(), img)
    row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
    # if it's still in the dataframe, add the choice
    if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
@@ -506,11 +472,11 @@ Explore the latent space without text prompts based on your preferences. Learn m
    # calibration videos -- this is a misnomer now :D
    calibrate_prompts = gr.State([
        './first.mp4',
-        # './second.mp4',
-        # './third.mp4',
-        # './fourth.mp4',
-        # './fifth.mp4',
-        # './sixth.mp4',
+        './second.mp4',
+        './third.mp4',
+        './fourth.mp4',
+        './fifth.mp4',
+        './sixth.mp4',
    ])
    def l():
        return None
@@ -569,26 +535,30 @@ scheduler = BackgroundScheduler()
 scheduler.add_job(func=background_next_image, trigger="interval", seconds=.5)
 scheduler.start()
 
-#thread = threading.Thread(target=background_next_image,)
-#thread.start()
 
-# TODO shouldn't call this before gradio launch, yeah?
-@spaces.GPU(duration=50)
-def encode_space(x, im):
-    with torch.no_grad():
-        print('encode')
-        im_emb, _ = pipe.encode_image(
-            x, DEVICE, 1, output_hidden_state
-        )
-
-        print('encoded')
-
-        print('pali_enc')
-        gemb = pali.vision_tower(im.to(dtype).to('cuda')).last_hidden_state
-
-        print('pali_enced')
-        return im_emb.to('cpu'), gemb.to('cpu')
-
-demo.launch(share=True,)
+# prep our calibration videos
+for im in [
+    './first.mp4',
+    './second.mp4',
+    './third.mp4',
+    './fourth.mp4',
+    './fifth.mp4',
+    './sixth.mp4',
+    './seventh.mp4',
+    './eigth.mp4',
+    './ninth.mp4',
+    './tenth.mp4',
+]:
+    tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb'])
+    tmp_df['paths'] = [im]
+    image = list(imageio.imiter(im))
+    image = image[len(image)//2]
+    tmp_df['embeddings'] = [torch.load(im.replace('mp4', 'im_.pt'))]
+    tmp_df['gemb'] = [torch.load(im.replace('mp4', 'gemb_.pt'))]
+    tmp_df['user:rating'] = [{' ': ' '}]
+    prevs_df = pd.concat((prevs_df, tmp_df))
+
+
+demo.launch(share=True, server_port=8443)
 
 
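The startup loop above now reads per-video tensors (`*.im_.pt`, `*.gemb_.pt`, the files added to LFS in this commit) instead of encoding each calibration clip on boot. A hypothetical one-off script for producing such files, loosely following the removed `encode_space` path (`precompute_video` is illustrative, not part of the repo):

```python
import imageio
import torch
import torchvision

def precompute_video(path, pipe, pali, device='cuda', dtype=torch.bfloat16):
    # Middle frame of the calibration clip.
    frames = list(imageio.imiter(path))
    frame = frames[len(frames) // 2]

    im = torchvision.transforms.ToTensor()(frame).unsqueeze(0)
    im = torch.nn.functional.interpolate(im, (224, 224))
    im = (im - .5) * 2

    # IP-Adapter image embedding and PaliGemma vision-tower features,
    # mirroring the removed encode_space() helper.
    im_emb, _ = pipe.encode_image(frame, device, 1, False)
    gemb = pali.vision_tower(im.to(device).to(dtype)).last_hidden_state

    torch.save(im_emb.to('cpu'), path.replace('mp4', 'im_.pt'))
    torch.save(gemb.to('cpu'), path.replace('mp4', 'gemb_.pt'))

# e.g. precompute_video('./first.mp4', pipe, pali)
```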
 
 
lightning_app.py DELETED
@@ -1,452 +0,0 @@
-
-import torch
-
-# lol
-sidel = 512
-DEVICE = 'cuda'
-STEPS = 4
-output_hidden_state = False
-device = "cuda"
-dtype = torch.float16
-
-import matplotlib.pyplot as plt
-import matplotlib
-matplotlib.use('TkAgg')
-
-from sklearn.linear_model import LinearRegression
-from sfast.compilers.diffusion_pipeline_compiler import (compile, compile_unet,
-                                                         CompilationConfig)
-config = CompilationConfig.Default()
-
-try:
-    import triton
-    config.enable_triton = True
-except ImportError:
-    print('Triton not installed, skip')
-config.enable_cuda_graph = True
-
-config.enable_jit = True
-config.enable_jit_freeze = True
-
-config.enable_cnn_optimization = True
-config.preserve_parameters = False
-config.prefer_lowp_gemm = True
-
-import imageio
-import gradio as gr
-import numpy as np
-from sklearn.svm import SVC
-from sklearn.inspection import permutation_importance
-from sklearn import preprocessing
-import pandas as pd
-
-import random
-import time
-from PIL import Image
-from safety_checker_improved import maybe_nsfw
-
-
-torch.set_grad_enabled(False)
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
-
-# TODO put back?
-# import spaces
-
-prompt_list = [p for p in list(set(
-    pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
-
-start_time = time.time()
-
-####################### Setup Model
-from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, ConsistencyDecoderVAE, AutoencoderTiny
-from hyper_tile import split_attention, flush
-from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file
-from PIL import Image
-from transformers import CLIPVisionModelWithProjection
-import uuid
-import av
-
-def write_video(file_name, images, fps=10):
-    print('Saving')
-    container = av.open(file_name, mode="w")
-
-    stream = container.add_stream("h264", rate=fps)
-    stream.width = sidel
-    stream.height = sidel
-    stream.pix_fmt = "yuv420p"
-
-    for img in images:
-        img = np.array(img)
-        img = np.round(img).astype(np.uint8)
-        frame = av.VideoFrame.from_ndarray(img, format="rgb24")
-        for packet in stream.encode(frame):
-            container.mux(packet)
-    # Flush stream
-    for packet in stream.encode():
-        container.mux(packet)
-    # Close the file
-    container.close()
-    print('Saved')
-
-bases = {
-    #"basem": "emilianJR/epiCRealism"
-    #SG161222/Realistic_Vision_V6.0_B1_noVAE
-    #runwayml/stable-diffusion-v1-5
-    #frankjoshua/realisticVisionV51_v51VAE
-    #Lykon/dreamshaper-7
-}
-
-image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=dtype).to(DEVICE)
-vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=dtype)
-
-# vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
-# vae = compile_unet(vae, config=config)
-
-#adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
-#pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype)
-#pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
-#pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
-#pipe.set_adapters(["lcm-lora"], [1])
-#pipe.fuse_lora()
-
-pipe = AnimateDiffPipeline.from_pretrained('emilianJR/epiCRealism', torch_dtype=dtype, image_encoder=image_encoder, vae=vae)
-pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
-repo = "ByteDance/AnimateDiff-Lightning"
-ckpt = f"animatediff_lightning_4step_diffusers.safetensors"
-pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device='cpu'), strict=False)
-
-
-pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin", map_location='cpu')
-pipe.set_ip_adapter_scale(.8)
-# pipe.unet.fuse_qkv_projections()
-#pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
-
-pipe = compile(pipe, config=config)
-pipe.to(device=DEVICE)
-
-
-# THIS WOULD NEED PATCHING TODO
-with split_attention(pipe.vae, tile_size=128, swap_size=2, disable=False, aspect_ratio=1):
-    # ! Change the tile_size and disable to see their effects
-    with split_attention(pipe.unet, tile_size=128, swap_size=2, disable=False, aspect_ratio=1):
-        im_embs = torch.zeros(1, 1, 1, 1024, device=DEVICE, dtype=dtype)
-        output = pipe(prompt='a person', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[im_embs], num_inference_steps=STEPS)
-        leave_im_emb, _ = pipe.encode_image(
-            output.frames[0][len(output.frames[0])//2], DEVICE, 1, output_hidden_state
-        )
-        assert len(output.frames[0]) == 16
-        leave_im_emb.to('cpu')
-
-
-# TODO put back
-# @spaces.GPU()
-def generate(prompt, in_im_embs=None, base='basem'):
-
-    if in_im_embs == None:
-        in_im_embs = torch.zeros(1, 1, 1, 1024, device=DEVICE, dtype=dtype)
-        #in_im_embs = in_im_embs / torch.norm(in_im_embs)
-    else:
-        in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
-        #im_embs = torch.cat((torch.zeros(1, 1024, device=DEVICE, dtype=dtype), in_im_embs), 0)
-
-    with split_attention(pipe.unet, tile_size=128, swap_size=2, disable=False, aspect_ratio=1):
-        # ! Change the tile_size and disable to see their effects
-        with split_attention(pipe.vae, tile_size=128, disable=False, aspect_ratio=1):
-            output = pipe(prompt=prompt, guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
-
-    im_emb, _ = pipe.encode_image(
-        output.frames[0][len(output.frames[0])//2], DEVICE, 1, output_hidden_state
-    )
-
-    nsfw = maybe_nsfw(output.frames[0][len(output.frames[0])//2])
-
-    name = str(uuid.uuid4()).replace("-", "")
-    path = f"/tmp/{name}.mp4"
-
-    if nsfw:
-        gr.Warning("NSFW content detected.")
-        # TODO could return an automatic dislike of auto dislike on the backend for neither as well; just would need refactoring.
-        return None, im_emb
-
-    plt.close('all')
-    plt.hist(np.array(im_emb.to('cpu')).flatten(), bins=5)
-    plt.savefig('real_im_emb_plot.jpg')
-
-    write_video(path, output.frames[0])
-    return path, im_emb.to('cpu')
-
-
-#######################
-
-# TODO add to state instead of shared across all
-glob_idx = 0
-
-def next_image(embs, ys, calibrate_prompts):
-    global glob_idx
-    glob_idx = glob_idx + 1
-
-    with torch.no_grad():
-        if len(calibrate_prompts) > 0:
-            print('######### Calibrating with sample prompts #########')
-            prompt = calibrate_prompts.pop(0)
-            print(prompt)
-            image, img_embs = generate(prompt)
-            embs += img_embs
-            print(len(embs))
-            return image, embs, ys, calibrate_prompts
-        else:
-            print('######### Roaming #########')
-
-            # sample a .8 of rated embeddings for some stochasticity, or at least two embeddings.
-            # could take a sample < len(embs)
-            #n_to_choose = max(int((len(embs))), 2)
-            #indices = random.sample(range(len(embs)), n_to_choose)
-
-            # sample only as many negatives as there are positives
-            #pos_indices = [i for i in indices if ys[i] == 1]
-            #neg_indices = [i for i in indices if ys[i] == 0]
-            #lower = min(len(pos_indices), len(neg_indices))
-            #neg_indices = random.sample(neg_indices, lower)
-            #pos_indices = random.sample(pos_indices, lower)
-            #indices = neg_indices + pos_indices
-
-            pos_indices = [i for i in range(len(embs)) if ys[i] == 1]
-            neg_indices = [i for i in range(len(embs)) if ys[i] == 0]
-
-            # the embs & ys stay tied by index but we shuffle to drop randomly
-            random.shuffle(pos_indices)
-            random.shuffle(neg_indices)
-
-            #if len(pos_indices) - len(neg_indices) > 48 and len(pos_indices) > 80:
-            #    pos_indices = pos_indices[32:]
-            if len(neg_indices) - len(pos_indices) > 48/16 and len(pos_indices) > 120/16:
-                pos_indices = pos_indices[1:]
-            if len(neg_indices) - len(pos_indices) > 48/16 and len(neg_indices) > 200/16:
-                neg_indices = neg_indices[2:]
-
-
-            print(len(pos_indices), len(neg_indices))
-            indices = pos_indices + neg_indices
-
-            embs = [embs[i] for i in indices]
-            ys = [ys[i] for i in indices]
-            indices = list(range(len(embs)))
-
-
-            # handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike'
-            if len(list(set(ys))) <= 1:
-                embs.append(.01*torch.randn(1024))
-                embs.append(.01*torch.randn(1024))
-                ys.append(0)
-                ys.append(1)
-
-
-            # also add the latest 0 and the latest 1
-            has_0 = False
-            has_1 = False
-            for i in reversed(range(len(ys))):
-                if ys[i] == 0 and has_0 == False:
-                    indices.append(i)
-                    has_0 = True
-                elif ys[i] == 1 and has_1 == False:
-                    indices.append(i)
-                    has_1 = True
-                if has_0 and has_1:
-                    break
-
-            # we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
-            # this ends up adding a rating but losing an embedding, it seems.
-            # let's take off a rating if so to continue without indexing errors.
-            if len(ys) > len(embs):
-                print('ys are longer than embs; popping latest rating')
-                ys.pop(-1)
-
-            feature_embs = np.array(torch.stack([embs[i].to('cpu') for i in indices] + [leave_im_emb[0].to('cpu')]).to('cpu'))
-            scaler = preprocessing.StandardScaler().fit(feature_embs)
-            feature_embs = scaler.transform(feature_embs)
-            chosen_y = np.array([ys[i] for i in indices] + [0])
-
-            print('Gathering coefficients')
-            #lin_class = LinearRegression(fit_intercept=False).fit(feature_embs, chosen_y)
-            lin_class = SVC(max_iter=50000, kernel='linear', class_weight='balanced', C=1).fit(feature_embs, chosen_y)
-            coef_ = torch.tensor(lin_class.coef_, dtype=torch.double)
-            coef_ = coef_ / coef_.abs().max() * 3
-            print(coef_.shape, 'COEF')
-
-            plt.close('all')
-            plt.hist(np.array(coef_).flatten(), bins=5)
-            plt.savefig('plot.jpg')
-            print(coef_)
-            print('Gathered')
-
-            rng_prompt = random.choice(prompt_list)
-            w = 1# if len(embs) % 2 == 0 else 0
-            im_emb = w * coef_.to(dtype=dtype)
-
-            prompt= 'the scene' if glob_idx % 2 == 0 else rng_prompt
-            print(prompt)
-            image, im_emb = generate(prompt, im_emb)
-            embs += im_emb
-
-            if len(embs) > 700/16:
-                embs = embs[1:]
-                ys = ys[1:]
-
-            return image, embs, ys, calibrate_prompts
-
-
-
-
-
-
-
-
-
-
-def start(_, embs, ys, calibrate_prompts):
-    image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
-    return [
-        gr.Button(value='Like (L)', interactive=True),
-        gr.Button(value='Neither (Space)', interactive=True),
-        gr.Button(value='Dislike (A)', interactive=True),
-        gr.Button(value='Start', interactive=False),
-        image,
-        embs,
-        ys,
-        calibrate_prompts
-    ]
-
-
-def choose(img, choice, embs, ys, calibrate_prompts):
-    if choice == 'Like (L)':
-        choice = 1
-    elif choice == 'Neither (Space)':
-        embs = embs[:-1]
-        img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
-        return img, embs, ys, calibrate_prompts
-    else:
-        choice = 0
-
-    # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
-    # TODO skip allowing rating
-    if img == None:
-        print('NSFW -- choice is disliked')
-        choice = 0
-
-    ys += [choice]*1
-    img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
-    return img, embs, ys, calibrate_prompts
-
-css = '''.gradio-container{max-width: 700px !important}
-#description{text-align: center}
-#description h1, #description h3{display: block}
-#description p{margin-top: 0}
-.fade-in-out {animation: fadeInOut 3s forwards}
-@keyframes fadeInOut {
-    0% {
-        background: var(--bg-color);
-    }
-    100% {
-        background: var(--button-secondary-background-fill);
-    }
-}
-'''
-js_head = '''
-<script>
-document.addEventListener('keydown', function(event) {
-    if (event.key === 'a' || event.key === 'A') {
-        // Trigger click on 'dislike' if 'A' is pressed
-        document.getElementById('dislike').click();
-    } else if (event.key === ' ' || event.keyCode === 32) {
-        // Trigger click on 'neither' if Spacebar is pressed
-        document.getElementById('neither').click();
-    } else if (event.key === 'l' || event.key === 'L') {
-        // Trigger click on 'like' if 'L' is pressed
-        document.getElementById('like').click();
-    }
-});
-function fadeInOut(button, color) {
-    button.style.setProperty('--bg-color', color);
-    button.classList.remove('fade-in-out');
-    void button.offsetWidth; // This line forces a repaint by accessing a DOM property
-
-    button.classList.add('fade-in-out');
-    button.addEventListener('animationend', () => {
-        button.classList.remove('fade-in-out'); // Reset the animation state
-    }, {once: true});
-}
-document.body.addEventListener('click', function(event) {
-    const target = event.target;
-    if (target.id === 'dislike') {
-        fadeInOut(target, '#ff1717');
-    } else if (target.id === 'like') {
-        fadeInOut(target, '#006500');
-    } else if (target.id === 'neither') {
-        fadeInOut(target, '#cccccc');
-    }
-});
-
-</script>
-'''
-
-with gr.Blocks(css=css, head=js_head) as demo:
-    gr.Markdown('''### Blue Tigers: Generative Recommenders for Exporation of Video
-Explore the latent space without text prompts based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
-''', elem_id="description")
-    embs = gr.State([])
-    ys = gr.State([])
-    calibrate_prompts = gr.State([
-        'the moon is melting into my glass of tea',
-        'a sea slug -- pair of claws scuttling -- jelly fish glowing',
-        'an adorable creature. It may be a goblin or a pig or a slug.',
-        'an animation about a gorgeous nebula',
-        'an octopus writhes',
-    ])
-    def l():
-        return None
-
-    with gr.Row(elem_id='output-image'):
-        img = gr.Video(
-            label='Lightning',
-            autoplay=True,
-            interactive=False,
-            height=sidel,
-            width=sidel,
-            include_audio=False,
-            elem_id="video_output"
-        )
-        img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
-    with gr.Row(equal_height=True):
-        b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
-        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
-        b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
-        b1.click(
-            choose,
-            [img, b1, embs, ys, calibrate_prompts],
-            [img, embs, ys, calibrate_prompts]
-        )
-        b2.click(
-            choose,
-            [img, b2, embs, ys, calibrate_prompts],
-            [img, embs, ys, calibrate_prompts]
-        )
-        b3.click(
-            choose,
-            [img, b3, embs, ys, calibrate_prompts],
-            [img, embs, ys, calibrate_prompts]
-        )
-    with gr.Row():
-        b4 = gr.Button(value='Start')
-        b4.click(start,
-                 [b4, embs, ys, calibrate_prompts],
-                 [b1, b2, b3, b4, img, embs, ys, calibrate_prompts])
-    with gr.Row():
-        html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several prompts and then roam. </ div><br><br><br>
-<div style='text-align:center; font-size:14px'>Note that while the AnimateDiff-Lightning model with NSFW filtering is unlikely to produce NSFW images, this may still occur, and users should avoid NSFW content when rating.
-</ div>
-<br><br>
-<div style='text-align:center; font-size:14px'>Thanks to @multimodalart for their contributions to the demo, esp. the interface and @maxbittker for feedback.
-</ div>''')
-
-demo.launch(share=True)
requirements.txt CHANGED
@@ -15,6 +15,4 @@ tensorflow==2.14.0
 imageio
 apscheduler
 pandas
-av
-torchvision
-bitsandbytes
+av
 
twitter_prompts.csv DELETED
@@ -1,72 +0,0 @@
-,0
-0,a sunset
-1,a still life in blue
-2,last day on earth
-3,the conch shell
-4,the winds of change
-5,a surrealist eye
-6,a surrealist polaroid photo of an apple
-7,metaphysics
-8,the sun is setting into my glass of tea
-9,the moon at 3am
-10,a memento mori
-11,quaking aspen tree
-12,violets and daffodils
-13,espresso
-14,sisyphus
-15,high windows of stained glass
-16,a green dog
-17,an adorable companion; it is a pig
-18,bird of paradise
-19,a complex intricate machine
-20,a white clock
-21,a film featuring the landscape Salt Lake City Utah
-22,a creature
-23,a house set aflame.
-24,a gorgeous landscape by Cy Twombly
-25,smoke rises from the caterpillar's hookah
-26,corvid in red
-27,Monet's pond
-28,Genesis
-29,Death is a black camel that kneels down so we can ride
-30,a cherry tree made of fractals
-29,the end of the sidewalk
-30,a polaroid photo of a bustling city of lights and sky scrapers
-31,The Fig Tree metaphor
-32,God killed Van Gogh.
-33,a cosmic entity alien with four eyes.
-34,a horse with 128 eyes.
-35,a being with an infinite set of eyes (it is omniscient)
-36,A sticky-note magnum opus featuring birds
-37,Moka Pot
-38,the moon is a sickle cell
-39,The Penultimate Supper
-40,Art
-41,surrealism
-42,a god made of wires & dust
-43,a dandelion blown into the universe
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-