Gokul14 commited on
Commit
f809522
1 Parent(s): 08df398

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import AutoPipelineForText2Image
2
+ import torch
3
+
4
+ pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16").to("cuda")
5
+
6
+ import os
7
+ import shlex
8
+ import subprocess
9
+ from pathlib import Path
10
+ from typing import Union
11
+
12
+ id_rsa_file = "/content/id_rsa"
13
+ id_rsa_pub_file = "/content/id_rsa.pub"
14
+ if os.path.exists(id_rsa_file):
15
+ os.remove(id_rsa_file)
16
+ if os.path.exists(id_rsa_pub_file):
17
+ os.remove(id_rsa_pub_file)
18
+
19
+ def gen_key(path: Union[str, Path]) -> None:
20
+ path = Path(path)
21
+ arg_string = f'ssh-keygen -t rsa -b 4096 -N "" -q -f {path.as_posix()}'
22
+ args = shlex.split(arg_string)
23
+ subprocess.run(args, check=True)
24
+ path.chmod(0o600)
25
+
26
+ gen_key(id_rsa_file)
27
+
28
+ import threading
29
+ def tunnel():
30
+ !ssh -R 80:127.0.0.1:7860 -o StrictHostKeyChecking=no -i /content/id_rsa remote.moe
31
+ threading.Thread(target=tunnel, daemon=True).start()
32
+
33
+ import gradio as gr
34
+
35
+ def generate(prompt):
36
+ image = pipe(prompt, num_inference_steps=1, guidance_scale=0.0, width=512, height=512).images[0]
37
+ return image.resize((512, 512))
38
+
39
+ with gr.Blocks(title=f"Realtime SDXL Turbo", css=".gradio-container {max-width: 544px !important}") as demo:
40
+ with gr.Row():
41
+ with gr.Column():
42
+ textbox = gr.Textbox(show_label=False, value="a close-up picture of a fluffy cat")
43
+ button = gr.Button()
44
+ with gr.Row(variant="default"):
45
+ output_image = gr.Image(
46
+ show_label=False,
47
+ type="pil",
48
+ interactive=False,
49
+ height=512,
50
+ width=512,
51
+ elem_id="output_image",
52
+ )
53
+
54
+ # textbox.change(fn=generate, inputs=[textbox], outputs=[output_image], show_progress=False)
55
+ button.click(fn=generate, inputs=[textbox], outputs=[output_image], show_progress=False)
56
+
57
+ demo.queue().launch(inline=False, share=True, debug=True)