Update app.py
Browse files
app.py
CHANGED
@@ -10,7 +10,8 @@ import cv2
|
|
10 |
import os
|
11 |
|
12 |
|
13 |
-
|
|
|
14 |
|
15 |
|
16 |
def image_grid(imgs, rows, cols):
|
@@ -41,7 +42,6 @@ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
|
|
41 |
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
42 |
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32
|
43 |
)
|
44 |
-
pipe = pipe.to("cuda")
|
45 |
|
46 |
def infer(prompts, negative_prompts, image):
|
47 |
|
|
|
10 |
import os
|
11 |
|
12 |
|
13 |
+
from jax import device
|
14 |
+
jax.config.update('jax_platform_name', 'gpu')
|
15 |
|
16 |
|
17 |
def image_grid(imgs, rows, cols):
|
|
|
42 |
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
43 |
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32
|
44 |
)
|
|
|
45 |
|
46 |
def infer(prompts, negative_prompts, image):
|
47 |
|