merve HF staff commited on
Commit
a6e8639
1 Parent(s): a7c611e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -10,7 +10,8 @@ import cv2
10
  import os
11
 
12
 
13
-
 
14
 
15
 
16
  def image_grid(imgs, rows, cols):
@@ -41,7 +42,6 @@ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
41
  pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
42
  "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32
43
  )
44
- pipe = pipe.to("cuda")
45
 
46
  def infer(prompts, negative_prompts, image):
47
 
 
10
  import os
11
 
12
 
13
+ from jax import device
14
+ jax.config.update('jax_platform_name', 'gpu')
15
 
16
 
17
  def image_grid(imgs, rows, cols):
 
42
  pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
43
  "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32
44
  )
 
45
 
46
  def infer(prompts, negative_prompts, image):
47