multimodalart HF staff commited on
Commit
5042a41
1 Parent(s): 9155e06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -9
app.py CHANGED
@@ -12,6 +12,7 @@ import json
12
  import random
13
  import copy
14
  import gc
 
15
 
16
  lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")
17
 
@@ -44,16 +45,20 @@ with open(lora_list, "r") as file:
44
  for item in data
45
  ]
46
 
 
 
47
  for item in sdxl_loras:
48
  saved_name = hf_hub_download(item["repo"], item["weights"])
49
 
50
- if saved_name.endswith('.safetensors'):
51
- state_dict = load_file(saved_name)
52
- else:
53
  state_dict = torch.load(saved_name)
54
-
55
- item["saved_name"] = saved_name
56
- item["state_dict"] = state_dict #{k: v.to(device="cuda", dtype=torch.float16) for k, v in state_dict.items() if torch.is_tensor(v)}
 
 
 
 
57
 
58
  css = '''
59
  .gradio-container{max-width: 650px! important}
@@ -82,9 +87,12 @@ div#share-btn-container > div {flex-direction: row;background: black;align-items
82
 
83
  original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
84
 
 
85
  def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, seed=-1, progress=gr.Progress(track_tqdm=True)):
86
- state_dict_1 = copy.deepcopy(shuffled_items[0]['state_dict'])
87
- state_dict_2 = copy.deepcopy(shuffled_items[1]['state_dict'])
 
 
88
  pipe = copy.deepcopy(original_pipe)
89
  pipe.to("cuda")
90
 
@@ -102,7 +110,6 @@ def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lor
102
  image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=20, width=768, height=768, generator=generator).images[0]
103
  del pipe
104
  gc.collect()
105
- torch.cuda.empty_cache()
106
  return image, gr.update(visible=True), seed
107
 
108
  def get_description(item):
 
12
  import random
13
  import copy
14
  import gc
15
+ import spaces
16
 
17
  lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")
18
 
 
45
  for item in data
46
  ]
47
 
48
+ state_dicts = {}
49
+
50
  for item in sdxl_loras:
51
  saved_name = hf_hub_download(item["repo"], item["weights"])
52
 
53
+ if not saved_name.endswith('.safetensors'):
 
 
54
  state_dict = torch.load(saved_name)
55
+ else:
56
+ state_dict = load_file(saved_name)
57
+
58
+ state_dicts[item["repo"]] = {
59
+ "saved_name": saved_name,
60
+ "state_dict": state_dict
61
+ }
62
 
63
  css = '''
64
  .gradio-container{max-width: 650px! important}
 
87
 
88
  original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
89
 
90
+ @spaces.GPU
91
  def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, seed=-1, progress=gr.Progress(track_tqdm=True)):
92
+ repo_id_1 = shuffled_items[0]['repo']
93
+ repo_id_2 = shuffled_items[1]['repo']
94
+ state_dict_1 = copy.deepcopy(state_dicts[repo_id_1]["state_dict"])
95
+ state_dict_2 = copy.deepcopy(state_dicts[repo_id_2]["state_dict"])
96
  pipe = copy.deepcopy(original_pipe)
97
  pipe.to("cuda")
98
 
 
110
  image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=20, width=768, height=768, generator=generator).images[0]
111
  del pipe
112
  gc.collect()
 
113
  return image, gr.update(visible=True), seed
114
 
115
  def get_description(item):