multimodalart HF staff commited on
Commit
4120479
1 Parent(s): 2905c29

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from time import sleep
3
+ from diffusers import DiffusionPipeline
4
+ import torch
5
+ import json
6
+ import random
7
+
8
+ lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")
9
+
10
+ with open(lora_list, "r") as file:
11
+ data = json.load(file)
12
+ sdxl_loras = [
13
+ {
14
+ "image": item["image"],
15
+ "title": item["title"],
16
+ "repo": item["repo"],
17
+ "trigger_word": item["trigger_word"],
18
+ "weights": item["weights"],
19
+ "is_compatible": item["is_compatible"],
20
+ "is_pivotal": item.get("is_pivotal", False),
21
+ "text_embedding_weights": item.get("text_embedding_weights", None),
22
+ "is_nc": item.get("is_nc", False)
23
+ }
24
+ for item in data
25
+ ]
26
+
27
+ saved_names = [
28
+ hf_hub_download(item["repo"], item["weights"]) for item in sdxl_loras
29
+ ]
30
+
31
+ css = '''
32
+ #title{text-align:center}
33
+ #plus_column{align-self: center}
34
+ #plus_button{font-size: 250%; text-align: center}
35
+ .gradio-container{width: 700px !important; margin: 0 auto !important}
36
+ #prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
37
+ #run_button{position:absolute;margin-top: 57px;right: 0;margin-right: 0.8em;border-bottom-left-radius: 0px;
38
+ border-top-left-radius: 0px;}
39
+ '''
40
+
41
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
42
+ original_pipe = copy.deepcopy(pipe)
43
+ def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, progress=gr.Progress(track_tqdm=True)):
44
+ pipe = copy.deepcopy(original_pipe)
45
+ pipe.load_lora_weights(shuffled_items[0]['repo'], weight_name=shuffled_items[0]['weights'])
46
+ pipe.fuse_lora(lora_1_scale)
47
+ pipe.load_lora_weights(shuffled_items[1]['repo'], weight_name=shuffled_items[1]['weights'])
48
+ pipe.fuse_lora(lora_2_scale)
49
+
50
+ pipe.to(torch_dtype=torch.float16)
51
+ pipe.to("cuda")
52
+ if negative_prompt == "":
53
+ negative_prompt = False
54
+ image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=25, guidance_scale=7).images[0]
55
+ return image
56
+
57
+ def get_description(item):
58
+ trigger_word = item["trigger_word"]
59
+ return f"LoRA trigger word: `{trigger_word}`" if trigger_word else "LoRA trigger word: `none`, will be applied automatically", trigger_word
60
+
61
+ def shuffle_images():
62
+ compatible_items = [item for item in data if item['is_compatible']]
63
+ random.shuffle(compatible_items)
64
+ two_shuffled_items = compatible_items[:2]
65
+ title_1 = gr.update(label=two_shuffled_items[0]['title'], value=two_shuffled_items[0]['image'])
66
+ title_2 = gr.update(label=two_shuffled_items[1]['title'], value=two_shuffled_items[1]['image'])
67
+
68
+ description_1, trigger_word_1 = get_description(two_shuffled_items[0])
69
+ description_2, trigger_word_2 = get_description(two_shuffled_items[1])
70
+
71
+ prompt = gr.update(value=f"{trigger_word_1} {trigger_word_2}")
72
+ return title_1,description_1,title_2,description_2,prompt, two_shuffled_items
73
+
74
+ with gr.Blocks(css=css) as demo:
75
+ shuffled_items = gr.State()
76
+ title = gr.HTML(
77
+ '''<h1>LoRA Roulette 🎲</h1>
78
+ <h4>Two LoRAs are loaded to SDXL at random, find a way to combine them for your art 🎨</h4>
79
+ ''',
80
+ elem_id="title"
81
+ )
82
+ with gr.Row():
83
+ with gr.Column(min_width=10, scale=6):
84
+ lora_1 = gr.Image(interactive=False, height=350)
85
+ lora_1_prompt = gr.Markdown()
86
+ with gr.Column(min_width=10, scale=1, elem_id="plus_column"):
87
+ plus = gr.HTML("+", elem_id="plus_button")
88
+ with gr.Column(min_width=10, scale=6):
89
+ lora_2 = gr.Image(interactive=False, height=350)
90
+ lora_2_prompt = gr.Markdown()
91
+ with gr.Row():
92
+ prompt = gr.Textbox(label="Your prompt", info="arrange the trigger words of the two LoRAs in a coherent sentence", interactive=True, elem_id="prompt")
93
+ run_btn = gr.Button("Run", elem_id="run_button")
94
+
95
+ output_image = gr.Image()
96
+ with gr.Accordion("Advanced settings", open=False):
97
+ negative_prompt = gr.Textbox(label="Negative prompt")
98
+ with gr.Row():
99
+ lora_1_scale = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
100
+ lora_2_scale = gr.Slider(label="LoRa 2 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
101
+ shuffle_button = gr.Button("Reshuffle LoRAs!")
102
+
103
+ demo.load(shuffle_images, inputs=[], outputs=[lora_1,lora_1_prompt,lora_2,lora_2_prompt, prompt, shuffled_items], queue=False, show_progress="hidden")
104
+ shuffle_button.click(shuffle_images, outputs=[lora_1,lora_1_prompt,lora_2,lora_2_prompt, prompt, shuffled_items], queue=False, show_progress="hidden")
105
+
106
+ run_btn.click(merge_and_run, inputs=[prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale], outputs=[output_image])
107
+ prompt.submit(merge_and_run, inputs=[prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale], outputs=[output_image])
108
+ demo.queue()
109
+ demo.launch()