qninhdt committed on
Commit
0f9e661
1 Parent(s): 478d8f0

Upload 53 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +10 -0
  2. LICENSE +21 -0
  3. README.md +228 -0
  4. assets/cat_2x.gif +3 -0
  5. assets/clear2rainy_results.jpg +3 -0
  6. assets/day2night_results.jpg +3 -0
  7. assets/edge_to_image_results.jpg +3 -0
  8. assets/examples/bird.png +3 -0
  9. assets/examples/bird_canny.png +0 -0
  10. assets/examples/bird_canny_blue.png +0 -0
  11. assets/examples/circles_inference_input.png +0 -0
  12. assets/examples/circles_inference_output.png +0 -0
  13. assets/examples/clear2rainy_input.png +0 -0
  14. assets/examples/clear2rainy_output.png +0 -0
  15. assets/examples/day2night_input.png +0 -0
  16. assets/examples/day2night_output.png +0 -0
  17. assets/examples/my_horse2zebra_input.jpg +0 -0
  18. assets/examples/my_horse2zebra_output.jpg +0 -0
  19. assets/examples/night2day_input.png +0 -0
  20. assets/examples/night2day_output.png +0 -0
  21. assets/examples/rainy2clear_input.png +0 -0
  22. assets/examples/rainy2clear_output.png +0 -0
  23. assets/examples/sketch_input.png +0 -0
  24. assets/examples/sketch_output.png +0 -0
  25. assets/examples/training_evaluation.png +0 -0
  26. assets/examples/training_evaluation_unpaired.png +0 -0
  27. assets/examples/training_step_0.png +0 -0
  28. assets/examples/training_step_500.png +0 -0
  29. assets/examples/training_step_6000.png +0 -0
  30. assets/fish_2x.gif +3 -0
  31. assets/gen_variations.jpg +3 -0
  32. assets/method.jpg +0 -0
  33. assets/night2day_results.jpg +3 -0
  34. assets/rainy2clear.jpg +3 -0
  35. assets/teaser_results.jpg +3 -0
  36. docs/training_cyclegan_turbo.md +98 -0
  37. docs/training_pix2pix_turbo.md +118 -0
  38. environment.yaml +34 -0
  39. gradio_canny2image.py +78 -0
  40. gradio_sketch2image.py +382 -0
  41. requirements.txt +29 -0
  42. scripts/download_fill50k.sh +5 -0
  43. scripts/download_horse2zebra.sh +5 -0
  44. src/cyclegan_turbo.py +254 -0
  45. src/image_prep.py +12 -0
  46. src/inference_paired.py +75 -0
  47. src/inference_unpaired.py +58 -0
  48. src/model.py +73 -0
  49. src/my_utils/dino_struct.py +185 -0
  50. src/my_utils/training_utils.py +409 -0
.gitattributes CHANGED
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/cat_2x.gif filter=lfs diff=lfs merge=lfs -text
37
+ assets/clear2rainy_results.jpg filter=lfs diff=lfs merge=lfs -text
38
+ assets/day2night_results.jpg filter=lfs diff=lfs merge=lfs -text
39
+ assets/edge_to_image_results.jpg filter=lfs diff=lfs merge=lfs -text
40
+ assets/examples/bird.png filter=lfs diff=lfs merge=lfs -text
41
+ assets/fish_2x.gif filter=lfs diff=lfs merge=lfs -text
42
+ assets/gen_variations.jpg filter=lfs diff=lfs merge=lfs -text
43
+ assets/night2day_results.jpg filter=lfs diff=lfs merge=lfs -text
44
+ assets/rainy2clear.jpg filter=lfs diff=lfs merge=lfs -text
45
+ assets/teaser_results.jpg filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 img-to-img-turbo
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,228 @@
1
+ # img2img-turbo
2
+
3
+ [**Paper**](https://arxiv.org/abs/2403.12036) | [**Sketch2Image Demo**](https://huggingface.co/spaces/gparmar/img2img-turbo-sketch)
4
+ #### **Quick start:** [**Running Locally**](#getting-started) | [**Gradio (locally hosted)**](#gradio-demo) | [**Training**](#training-with-your-own-data)
5
+
6
+ ### Cat Sketching
7
+ <p align="left" >
8
+ <img src="https://raw.githubusercontent.com/GaParmar/img2img-turbo/main/assets/cat_2x.gif" width="800" />
9
+ </p>
10
+
11
+ ### Fish Sketching
12
+ <p align="left">
13
+ <img src="https://raw.githubusercontent.com/GaParmar/img2img-turbo/main/assets/fish_2x.gif" width="800" />
14
+ </p>
15
+
16
+
17
+ We propose a general method for adapting a single-step diffusion model, such as SD-Turbo, to new tasks and domains through adversarial learning. This enables us to leverage the internal knowledge of pre-trained diffusion models while achieving efficient inference (e.g., 0.29 seconds on an A6000 and 0.11 seconds on an A100 for a 512x512 image).
18
+
19
+ Our one-step conditional models **CycleGAN-Turbo** and **pix2pix-turbo** can perform various image-to-image translation tasks for both unpaired and paired settings. CycleGAN-Turbo outperforms existing GAN-based and diffusion-based methods, while pix2pix-turbo is on par with recent works such as ControlNet for Sketch2Photo and Edge2Image, but with one-step inference.
20
+
21
+ [One-Step Image Translation with Text-to-Image Models](https://arxiv.org/abs/2403.12036)<br>
22
+ [Gaurav Parmar](https://gauravparmar.com/), [Taesung Park](https://taesung.me/), [Srinivasa Narasimhan](https://www.cs.cmu.edu/~srinivas/), [Jun-Yan Zhu](https://github.com/junyanz/)<br>
23
+ CMU and Adobe, arXiv 2403.12036
24
+
25
+ <br>
26
+ <div>
27
+ <p align="center">
28
+ <img src='assets/teaser_results.jpg' align="center" width=1000px>
29
+ </p>
30
+ </div>
31
+
32
+
33
+
34
+
35
+ ## Results
36
+
37
+ ### Paired Translation with pix2pix-turbo
38
+ **Edge to Image**
39
+ <div>
40
+ <p align="center">
41
+ <img src='assets/edge_to_image_results.jpg' align="center" width=800px>
42
+ </p>
43
+ </div>
44
+
45
+ <!-- **Sketch to Image**
46
+ TODO -->
47
+ ### Generating Diverse Outputs
48
+ By varying the input noise map, our method can generate diverse outputs from the same input conditioning.
49
+ The output style can be controlled by changing the text prompt.
50
+ <div> <p align="center">
51
+ <img src='assets/gen_variations.jpg' align="center" width=800px>
52
+ </p> </div>
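For reference, a minimal sketch of generating several variations programmatically. It mirrors the call pattern in `gradio_sketch2image.py` from this commit (`deterministic=False` with an explicit `noise_map`); the input path, prompt, seeds, and output directory are illustrative.

```python
import os
import torch
import torchvision.transforms.functional as F
from PIL import Image
from src.pix2pix_turbo import Pix2Pix_Turbo

model = Pix2Pix_Turbo("sketch_to_image_stochastic")
sketch = Image.open("assets/examples/sketch_input.png").convert("RGB")
c_t = (F.to_tensor(sketch) > 0.5).unsqueeze(0).cuda().float()  # binarized sketch conditioning

prompt = "ethereal fantasy concept art of an asteroid, magical, dreamy"
os.makedirs("outputs", exist_ok=True)
with torch.no_grad():
    for seed in (0, 1, 2):  # each seed gives a different noise map, hence a different output
        torch.manual_seed(seed)
        noise = torch.randn((1, 4, c_t.shape[2] // 8, c_t.shape[3] // 8), device=c_t.device)
        out = model(c_t, prompt, deterministic=False, r=0.4, noise_map=noise)
        F.to_pil_image(out[0].cpu() * 0.5 + 0.5).save(f"outputs/variation_{seed}.png")
```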
53
+
54
+ ### Unpaired Translation with CycleGAN-Turbo
55
+
56
+ **Day to Night**
57
+ <div> <p align="center">
58
+ <img src='assets/day2night_results.jpg' align="center" width=800px>
59
+ </p> </div>
60
+
61
+ **Night to Day**
62
+ <div><p align="center">
63
+ <img src='assets/night2day_results.jpg' align="center" width=800px>
64
+ </p> </div>
65
+
66
+ **Clear to Rainy**
67
+ <div>
68
+ <p align="center">
69
+ <img src='assets/clear2rainy_results.jpg' align="center" width=800px>
70
+ </p>
71
+ </div>
72
+
73
+ **Rainy to Clear**
74
+ <div>
75
+ <p align="center">
76
+ <img src='assets/rainy2clear.jpg' align="center" width=800px>
77
+ </p>
78
+ </div>
79
+ <hr>
80
+
81
+
82
+ ## Method
83
+ **Our Generator Architecture:**
84
+ We tightly integrate three separate modules of the original latent diffusion model into a single end-to-end network with a small number of trainable weights. This architecture allows us to translate the input image x to the output y while retaining the input scene structure. We use LoRA adapters in each module, introduce skip connections and Zero-Convs between the input and output, and retrain the first layer of the U-Net (a minimal code sketch follows the figure below). Blue boxes indicate trainable layers; semi-transparent layers are frozen. The same generator can be used for various GAN objectives.
85
+ <div>
86
+ <p align="center">
87
+ <img src='assets/method.jpg' align="center" width=900px>
88
+ </p>
89
+ </div>
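A minimal sketch of the recipe described above, condensed from `src/cyclegan_turbo.py` in this commit: the pre-trained SD-Turbo weights stay frozen, LoRA adapters are attached to the U-Net and VAE, and zero-initialized 1x1 skip convolutions route encoder activations to the decoder. The ranks and target-module lists are abbreviated here for illustration, and the repository additionally overrides the VAE encoder/decoder forward passes (see `src/model.py`) so the skips are actually used.

```python
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from peft import LoraConfig

# Frozen SD-Turbo U-Net with low-rank LoRA adapters on its projection layers.
unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
unet.requires_grad_(False)
unet_lora = LoraConfig(r=8, lora_alpha=8, init_lora_weights="gaussian",
                       target_modules=["to_k", "to_q", "to_v", "to_out.0"])
unet.add_adapter(unet_lora, adapter_name="unet_lora")

# Frozen VAE with a LoRA adapter and a zero-initialized 1x1 skip convolution
# (a "Zero-Conv"), so training starts from the unmodified base model.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
vae.requires_grad_(False)
vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=1, bias=False)
torch.nn.init.constant_(vae.decoder.skip_conv_1.weight, 1e-5)
vae_lora = LoraConfig(r=4, init_lora_weights="gaussian",
                      target_modules=["conv1", "conv2", "conv_in", "conv_out",
                                      "to_k", "to_q", "to_v", "to_out.0"])
vae.add_adapter(vae_lora, adapter_name="vae_skip")
```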
90
+
91
+
92
+ ## Getting Started
93
+ **Environment Setup**
94
+ - We provide a [conda env file](environment.yaml) that contains all the required dependencies.
95
+ ```
96
+ conda env create -f environment.yaml
97
+ ```
98
+ - Then activate the conda environment with the command below.
99
+ ```
100
+ conda activate img2img-turbo
101
+ ```
102
+ - Alternatively, use a Python virtual environment:
103
+ ```
104
+ python3 -m venv venv
105
+ source venv/bin/activate
106
+ pip install -r requirements.txt
107
+ ```
108
+ **Paired Image Translation (pix2pix-turbo)**
109
+ - The following command takes an image file and a prompt as inputs, extracts the Canny edges, and saves the results in the specified directory (a programmatic sketch follows the table below).
110
+ ```bash
111
+ python src/inference_paired.py --model_name "edge_to_image" \
112
+ --input_image "assets/examples/bird.png" \
113
+ --prompt "a blue bird" \
114
+ --output_dir "outputs"
115
+ ```
116
+ <table>
117
+ <th>Input Image</th>
118
+ <th>Canny Edges</th>
119
+ <th>Model Output</th>
120
+ </tr>
121
+ <tr>
122
+ <td><img src='assets/examples/bird.png' width="200px"></td>
123
+ <td><img src='assets/examples/bird_canny.png' width="200px"></td>
124
+ <td><img src='assets/examples/bird_canny_blue.png' width="200px"></td>
125
+ </tr>
126
+ </table>
127
+ <br>
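The edge-to-image model can also be called directly from Python. The sketch below mirrors `gradio_canny2image.py` from this commit (including its resize to a multiple of 8); the output filename is illustrative.

```python
import os
import torch
from PIL import Image
from torchvision import transforms
from src.image_prep import canny_from_pil
from src.pix2pix_turbo import Pix2Pix_Turbo

model = Pix2Pix_Turbo("edge_to_image")
img = Image.open("assets/examples/bird.png").convert("RGB")
img = img.resize((img.width - img.width % 8, img.height - img.height % 8))  # sides must be multiples of 8
canny = canny_from_pil(img, low_threshold=100, high_threshold=200)

os.makedirs("outputs", exist_ok=True)
with torch.no_grad():
    c_t = transforms.ToTensor()(canny).unsqueeze(0).cuda()
    out = model(c_t, "a blue bird")  # one forward pass, no iterative sampling
transforms.ToPILImage()(out[0].cpu() * 0.5 + 0.5).save("outputs/bird_from_edges.png")
```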
128
+
129
+ - The following command takes a sketch and a prompt as inputs and saves the results in the specified directory.
130
+ ```bash
131
+ python src/inference_paired.py --model_name "sketch_to_image_stochastic" \
132
+ --input_image "assets/examples/sketch_input.png" --gamma 0.4 \
133
+ --prompt "ethereal fantasy concept art of an asteroid. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy" \
134
+ --output_dir "outputs"
135
+ ```
136
+ <table>
137
+ <th>Input</th>
138
+ <th>Model Output</th>
139
+ </tr>
140
+ <tr>
141
+ <td><img src='assets/examples/sketch_input.png' width="400px"></td>
142
+ <td><img src='assets/examples/sketch_output.png' width="400px"></td>
143
+ </tr>
144
+ </table>
145
+ <br>
146
+
147
+ **Unpaired Image Translation (CycleGAN-Turbo)**
148
+ - The following command takes a **day** image as input and saves the translated **night** image in the specified directory.
149
+ ```
150
+ python src/inference_unpaired.py --model_name "day_to_night" \
151
+ --input_image "assets/examples/day2night_input.png" --output_dir "outputs"
152
+ ```
153
+ <table>
154
+ <th>Input (day)</th>
155
+ <th>Model Output (night)</th>
156
+ </tr>
157
+ <tr>
158
+ <td><img src='assets/examples/day2night_input.png' width="400px"></td>
159
+ <td><img src='assets/examples/day2night_output.png' width="400px"></td>
160
+ </tr>
161
+ </table>
162
+
163
+ - The following command takes a **night** image as input and saves the translated **day** image in the specified directory.
164
+ ```
165
+ python src/inference_unpaired.py --model_name "night_to_day" \
166
+ --input_image "assets/examples/night2day_input.png" --output_dir "outputs"
167
+ ```
168
+ <table>
169
+ <th>Input (night)</th>
170
+ <th>Model Output (day)</th>
171
+ </tr>
172
+ <tr>
173
+ <td><img src='assets/examples/night2day_input.png' width="400px"></td>
174
+ <td><img src='assets/examples/night2day_output.png' width="400px"></td>
175
+ </tr>
176
+ </table>
177
+
178
+ - The following command takes a **clear** image as input and saves the translated **rainy** image in the specified directory.
179
+ ```
180
+ python src/inference_unpaired.py --model_name "clear_to_rainy" \
181
+ --input_image "assets/examples/clear2rainy_input.png" --output_dir "outputs"
182
+ ```
183
+ <table>
184
+ <th>Input (clear)</th>
185
+ <th>Model Output (rainy)</th>
186
+ </tr>
187
+ <tr>
188
+ <td><img src='assets/examples/clear2rainy_input.png' width="400px"></td>
189
+ <td><img src='assets/examples/clear2rainy_output.png' width="400px"></td>
190
+ </tr>
191
+ </table>
192
+
193
+ - The following command takes a **rainy** image as input and saves the translated **clear** image in the specified directory (a programmatic sketch follows the table below).
194
+ ```
195
+ python src/inference_unpaired.py --model_name "rainy_to_clear" \
196
+ --input_image "assets/examples/rainy2clear_input.png" --output_dir "outputs"
197
+ ```
198
+ <table>
199
+ <th>Input (rainy)</th>
200
+ <th>Model Output (clear)</th>
201
+ </tr>
202
+ <tr>
203
+ <td><img src='assets/examples/rainy2clear_input.png' width="400px"></td>
204
+ <td><img src='assets/examples/rainy2clear_output.png' width="400px"></td>
205
+ </tr>
206
+ </table>
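A minimal sketch of calling a pretrained unpaired checkpoint directly from Python, based on the `CycleGAN_Turbo` class in `src/cyclegan_turbo.py` from this commit. The input normalization to [-1, 1] and the output path are assumptions; `src/inference_unpaired.py` is the reference entry point.

```python
import os
import torch
from PIL import Image
from torchvision import transforms
from src.cyclegan_turbo import CycleGAN_Turbo

model = CycleGAN_Turbo(pretrained_name="day_to_night")  # downloads the released checkpoint
model.eval()

img = Image.open("assets/examples/day2night_input.png").convert("RGB")  # sides assumed to be multiples of 8
x = transforms.ToTensor()(img).unsqueeze(0).cuda() * 2 - 1  # assumed [-1, 1] input range
os.makedirs("outputs", exist_ok=True)
with torch.no_grad():
    y = model(x)  # uses the checkpoint's default caption and translation direction
transforms.ToPILImage()(y[0].cpu() * 0.5 + 0.5).save("outputs/day2night.png")
```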
207
+
208
+
209
+
210
+ ## Gradio Demo
211
+ - We provide a Gradio demo for the paired image translation tasks.
212
+ - The following command launches the sketch-to-image demo locally using Gradio.
213
+ ```
214
+ gradio gradio_sketch2image.py
215
+ ```
216
+ - The following command launches the Canny-edge-to-image Gradio demo locally.
217
+ ```
218
+ gradio gradio_canny2image.py
219
+ ```
220
+
221
+
222
+ ## Training with your own data
223
+ - See the steps [here](docs/training_pix2pix_turbo.md) for training a pix2pix-turbo model on your paired data.
224
+ - See the steps [here](docs/training_cyclegan_turbo.md) for training a CycleGAN-Turbo model on your unpaired data.
225
+
226
+
227
+ ## Acknowledgment
228
+ Our work uses Stable Diffusion Turbo (SD-Turbo) as the base model, released under the following [LICENSE](https://huggingface.co/stabilityai/sd-turbo/blob/main/LICENSE).
assets/cat_2x.gif ADDED

Git LFS Details

  • SHA256: 65a49403cf594d7b5300547edded6794e1306b61fb5f6837a96320a17954e826
  • Pointer size: 132 Bytes
  • Size of remote file: 4.63 MB
assets/clear2rainy_results.jpg ADDED

Git LFS Details

  • SHA256: f8b03789185cdb546080d0a3173e1e7054a4a013c2f3581d4d69fb4f99fe94d2
  • Pointer size: 132 Bytes
  • Size of remote file: 2.87 MB
assets/day2night_results.jpg ADDED

Git LFS Details

  • SHA256: 152448e2de3e09184f34e2d4bf8f41af02669fb6dafd77f4994a5da3b50410bf
  • Pointer size: 132 Bytes
  • Size of remote file: 2.91 MB
assets/edge_to_image_results.jpg ADDED

Git LFS Details

  • SHA256: c0e900c2fe954443b87c8643980c287ff91066a5adb21fbec75595c00a4ab615
  • Pointer size: 132 Bytes
  • Size of remote file: 2.37 MB
assets/examples/bird.png ADDED

Git LFS Details

  • SHA256: cad49fc7d3071b2bcd078bc8dde365f8fa62eaa6d43705fd50c212794a3aac35
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB
assets/examples/bird_canny.png ADDED
assets/examples/bird_canny_blue.png ADDED
assets/examples/circles_inference_input.png ADDED
assets/examples/circles_inference_output.png ADDED
assets/examples/clear2rainy_input.png ADDED
assets/examples/clear2rainy_output.png ADDED
assets/examples/day2night_input.png ADDED
assets/examples/day2night_output.png ADDED
assets/examples/my_horse2zebra_input.jpg ADDED
assets/examples/my_horse2zebra_output.jpg ADDED
assets/examples/night2day_input.png ADDED
assets/examples/night2day_output.png ADDED
assets/examples/rainy2clear_input.png ADDED
assets/examples/rainy2clear_output.png ADDED
assets/examples/sketch_input.png ADDED
assets/examples/sketch_output.png ADDED
assets/examples/training_evaluation.png ADDED
assets/examples/training_evaluation_unpaired.png ADDED
assets/examples/training_step_0.png ADDED
assets/examples/training_step_500.png ADDED
assets/examples/training_step_6000.png ADDED
assets/fish_2x.gif ADDED

Git LFS Details

  • SHA256: 9668ef45316f92d7c36db1e6d1854d2d413a2d87b32d73027149aeb02cc94e9d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.48 MB
assets/gen_variations.jpg ADDED

Git LFS Details

  • SHA256: f9443d34ae70cc7d6d5123f7517b7f6e601ba6a59fedd63935e8dcd2dbf507e7
  • Pointer size: 132 Bytes
  • Size of remote file: 3.33 MB
assets/method.jpg ADDED
assets/night2day_results.jpg ADDED

Git LFS Details

  • SHA256: 2c2e0c3e5673e803482d881ab4df66e4e3103803e52daf48da43fb398742a3e8
  • Pointer size: 132 Bytes
  • Size of remote file: 2.37 MB
assets/rainy2clear.jpg ADDED

Git LFS Details

  • SHA256: ba435223d2c72430a9defeb7da94d43af9ddf67c32f11beb78c463f6a95347f5
  • Pointer size: 132 Bytes
  • Size of remote file: 2.49 MB
assets/teaser_results.jpg ADDED

Git LFS Details

  • SHA256: 55f14cff3825bf475ed7cf3847182a9689d4e7745204acbcd6ae8023d855e9ea
  • Pointer size: 132 Bytes
  • Size of remote file: 2.06 MB
docs/training_cyclegan_turbo.md ADDED
@@ -0,0 +1,98 @@
1
+ ## Training with Unpaired Data (CycleGAN-turbo)
2
+ Here, we show how to train a CycleGAN-turbo model using unpaired data.
3
+ We will use the [horse2zebra dataset](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/datasets.md) introduced by [CycleGAN](https://junyanz.github.io/CycleGAN/) as an example dataset.
4
+
5
+
6
+ ### Step 1. Get the Dataset
7
+ - First download the horse2zebra dataset from [here](https://www.cs.cmu.edu/~img2img-turbo/data/my_horse2zebra.zip) using the command below.
8
+ ```
9
+ bash scripts/download_horse2zebra.sh
10
+ ```
11
+
12
+ - Our training scripts expect the dataset to be in the following format:
13
+ ```
14
+ data
15
+ ├── dataset_name
16
+ │ ├── train_A
17
+ │ │ ├── 000000.png
18
+ │ │ ├── 000001.png
19
+ │ │ └── ...
20
+ │ ├── train_B
21
+ │ │ ├── 000000.png
22
+ │ │ ├── 000001.png
23
+ │ │ └── ...
24
+ │ ├── fixed_prompt_a.txt
25
+ │ ├── fixed_prompt_b.txt
26
+ │
27
+ │ ├── test_A
28
+ │ │ ├── 000000.png
29
+ │ │ ├── 000001.png
30
+ │ │ └── ...
31
+ │ ├── test_B
32
+ │ │ ├── 000000.png
33
+ │ │ ├── 000001.png
34
+ │ │ └── ...
35
+ ```
36
+ - The `fixed_prompt_a.txt` and `fixed_prompt_b.txt` files contain the **fixed captions** used for the source and target domains, respectively (see the sketch below).
37
+
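A minimal sketch of preparing this layout and the two fixed-caption files; the dataset name and captions are illustrative (the captions follow the horse2zebra example).

```python
from pathlib import Path

# Lay out an unpaired dataset in the format expected by the training script.
root = Path("data/my_dataset")  # hypothetical dataset name
for split in ["train_A", "train_B", "test_A", "test_B"]:
    (root / split).mkdir(parents=True, exist_ok=True)

# One fixed caption per domain (domain A -> source, domain B -> target).
(root / "fixed_prompt_a.txt").write_text("picture of a horse")
(root / "fixed_prompt_b.txt").write_text("picture of a zebra")
```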
38
+
39
+ ### Step 2. Train the Model
40
+ - Initialize the `accelerate` environment with the following command:
41
+ ```
42
+ accelerate config
43
+ ```
44
+
45
+ - Run the following command to train the model (the `--lambda_*` flags weight the training objective summarized below).
46
+ ```
47
+ export NCCL_P2P_DISABLE=1
48
+ accelerate launch --main_process_port 29501 src/train_cyclegan_turbo.py \
49
+ --pretrained_model_name_or_path="stabilityai/sd-turbo" \
50
+ --output_dir="output/cyclegan_turbo/my_horse2zebra" \
51
+ --dataset_folder "data/my_horse2zebra" \
52
+ --train_img_prep "resize_286_randomcrop_256x256_hflip" --val_img_prep "no_resize" \
53
+ --learning_rate="1e-5" --max_train_steps=25000 \
54
+ --train_batch_size=1 --gradient_accumulation_steps=1 \
55
+ --report_to "wandb" --tracker_project_name "gparmar_unpaired_h2z_cycle_debug_v2" \
56
+ --enable_xformers_memory_efficient_attention --validation_steps 250 \
57
+ --lambda_gan 0.5 --lambda_idt 1 --lambda_cycle 1
58
+ ```
59
+
60
+ - Additional optional flags:
61
+ - `--enable_xformers_memory_efficient_attention`: Enable memory-efficient attention in the model.
62
+
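For reference, the `--lambda_*` flags weight the three terms of the usual unpaired-translation objective (adversarial, cycle-consistency, and identity losses). The exact formulation lives in `src/train_cyclegan_turbo.py`, which is outside this 50-file view, so treat this as a schematic summary:

$$
\mathcal{L} \;=\; \lambda_{\text{gan}}\,\mathcal{L}_{\text{GAN}}
\;+\; \lambda_{\text{cycle}}\,\mathcal{L}_{\text{cycle}}
\;+\; \lambda_{\text{idt}}\,\mathcal{L}_{\text{idt}},
\qquad \lambda_{\text{gan}}=0.5,\;\; \lambda_{\text{cycle}}=1,\;\; \lambda_{\text{idt}}=1 .
$$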
63
+ ### Step 3. Monitor the training progress
64
+ - You can monitor the training progress using the [Weights & Biases](https://wandb.ai/site) dashboard.
65
+
66
+ - The training script will visualizing the training batch, the training losses, and validation set L2, LPIPS, and FID scores (if specified).
67
+ <div>
68
+ <p align="center">
69
+ <img src='../assets/examples/training_evaluation_unpaired.png' align="center" width=800px>
70
+ </p>
71
+ </div>
72
+
73
+
74
+ - The model checkpoints will be saved in the `<output_dir>/checkpoints` directory.
75
+
76
+
77
+ ### Step 4. Running Inference with the trained models
78
+
79
+ - You can run inference with the trained model using the following command:
80
+ ```
81
+ python src/inference_unpaired.py --model_path "output/cyclegan_turbo/my_horse2zebra/checkpoints/model_1001.pkl" \
82
+ --input_image "data/my_horse2zebra/test_A/n02381460_20.jpg" \
83
+ --prompt "picture of a zebra" --direction "a2b" \
84
+ --output_dir "outputs" --image_prep "no_resize"
85
+ ```
86
+
87
+ - The above command should generate the following output:
88
+ <table>
89
+ <tr>
90
+ <th>Model Input</th>
91
+ <th>Model Output</th>
92
+ </tr>
93
+ <tr>
94
+ <td><img src='../assets/examples/my_horse2zebra_input.jpg' width="200px"></td>
95
+ <td><img src='../assets/examples/my_horse2zebra_output.jpg' width="200px"></td>
96
+ </tr>
97
+ </table>
98
+
docs/training_pix2pix_turbo.md ADDED
@@ -0,0 +1,118 @@
1
+ ## Training with Paired Data (pix2pix-turbo)
2
+ Here, we show how to train a pix2pix-turbo model using paired data.
3
+ We will use the [Fill50k dataset](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md) used by [ControlNet](https://github.com/lllyasviel/ControlNet) as an example dataset.
4
+
5
+
6
+ ### Step 1. Get the Dataset
7
+ - First download a modified Fill50k dataset from [here](https://www.cs.cmu.edu/~img2img-turbo/data/my_fill50k.zip) using the command below.
8
+ ```
9
+ bash scripts/download_fill50k.sh
10
+ ```
11
+
12
+ - Our training scripts expect the dataset to be in the following format:
13
+ ```
14
+ data
15
+ ├── dataset_name
16
+ │ ├── train_A
17
+ │ │ ├── 000000.png
18
+ │ │ ├── 000001.png
19
+ │ │ └── ...
20
+ │ ├── train_B
21
+ │ │ ├── 000000.png
22
+ │ │ ├── 000001.png
23
+ │ │ └── ...
24
+ │ ├── train_prompts.json
25
+ │
26
+ │ ├── test_A
27
+ │ │ ├── 000000.png
28
+ │ │ ├── 000001.png
29
+ │ │ └── ...
30
+ │ ├── test_B
31
+ │ │ ├── 000000.png
32
+ │ │ ├── 000001.png
33
+ │ │ └── ...
34
+ │ └── test_prompts.json
35
+ ```
36
+
37
+
38
+ ### Step 2. Train the Model
39
+ - Initialize the `accelerate` environment with the following command:
40
+ ```
41
+ accelerate config
42
+ ```
43
+
44
+ - Run the following command to train the model.
45
+ ```
46
+ accelerate launch src/train_pix2pix_turbo.py \
47
+ --pretrained_model_name_or_path="stabilityai/sd-turbo" \
48
+ --output_dir="output/pix2pix_turbo/fill50k" \
49
+ --dataset_folder="data/my_fill50k" \
50
+ --resolution=512 \
51
+ --train_batch_size=2 \
52
+ --enable_xformers_memory_efficient_attention --viz_freq 25 \
53
+ --track_val_fid \
54
+ --report_to "wandb" --tracker_project_name "pix2pix_turbo_fill50k"
55
+ ```
56
+
57
+ - Additional optional flags:
58
+ - `--track_val_fid`: Track FID score on the validation set using the [Clean-FID](https://github.com/GaParmar/clean-fid) implementation.
59
+ - `--enable_xformers_memory_efficient_attention`: Enable memory-efficient attention in the model.
60
+ - `--viz_freq`: Frequency of visualizing the results during training.
61
+
62
+ ### Step 3. Monitor the training progress
63
+ - You can monitor the training progress using the [Weights & Biases](https://wandb.ai/site) dashboard.
64
+
65
+ - The training script will visualizing the training batch, the training losses, and validation set L2, LPIPS, and FID scores (if specified).
66
+ <div>
67
+ <p align="center">
68
+ <img src='../assets/examples/training_evaluation.png' align="center" width=800px>
69
+ </p>
70
+ </div>
71
+
72
+
73
+ - The model checkpoints will be saved in the `<output_dir>/checkpoints` directory.
74
+
75
+ - Screenshots of the training progress are shown below:
76
+ - Step 0:
77
+ <div>
78
+ <p align="center">
79
+ <img src='../assets/examples/training_step_0.png' align="center" width=800px>
80
+ </p>
81
+ </div>
82
+
83
+ - Step 500:
84
+ <div>
85
+ <p align="center">
86
+ <img src='../assets/examples/training_step_500.png' align="center" width=800px>
87
+ </p>
88
+ </div>
89
+
90
+ - Step 6000:
91
+ <div>
92
+ <p align="center">
93
+ <img src='../assets/examples/training_step_6000.png' align="center" width=800px>
94
+ </p>
95
+ </div>
96
+
97
+
98
+ ### Step 4. Running Inference with the trained models
99
+
100
+ - You can run inference using the trained model using the following command:
101
+ ```
102
+ python src/inference_paired.py --model_path "output/pix2pix_turbo/fill50k/checkpoints/model_6001.pkl" \
103
+ --input_image "data/my_fill50k/test_A/40000.png" \
104
+ --prompt "violet circle with orange background" \
105
+ --output_dir "outputs"
106
+ ```
107
+
108
+ - The above command should generate the following output:
109
+ <table>
110
+ <tr>
111
+ <th>Model Input</th>
112
+ <th>Model Output</th>
113
+ </tr>
114
+ <tr>
115
+ <td><img src='../assets/examples/circles_inference_input.png' width="200px"></td>
116
+ <td><img src='../assets/examples/circles_inference_output.png' width="200px"></td>
117
+ </tr>
118
+ </table>
environment.yaml ADDED
@@ -0,0 +1,34 @@
1
+ name: img2img-turbo
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.10
7
+ - pip:
8
+ - clip @ git+https://github.com/openai/CLIP.git
9
+ - einops>=0.6.1
10
+ - numpy>=1.24.4
11
+ - open-clip-torch>=2.20.0
12
+ - opencv-python==4.6.0.66
13
+ - pillow>=9.5.0
14
+ - scipy==1.11.1
15
+ - timm>=0.9.2
16
+ - tokenizers
17
+ - torch>=2.0.1
18
+
19
+ - torchaudio>=2.0.2
20
+ - torchdata==0.6.1
21
+ - torchmetrics>=1.0.1
22
+ - torchvision>=0.15.2
23
+
24
+ - tqdm>=4.65.0
25
+ - transformers==4.35.2
26
+ - urllib3<1.27,>=1.25.4
27
+ - xformers>=0.0.20
28
+ - streamlit-keyup==0.2.0
29
+ - lpips
30
+ - clean-fid
31
+ - peft
32
+ - dominate
33
+ - diffusers==0.25.1
34
+ - gradio==3.43.1
gradio_canny2image.py ADDED
@@ -0,0 +1,78 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ import torch
4
+ from torchvision import transforms
5
+ import gradio as gr
6
+ from src.image_prep import canny_from_pil
7
+ from src.pix2pix_turbo import Pix2Pix_Turbo
8
+
9
+ model = Pix2Pix_Turbo("edge_to_image")
10
+
11
+
12
+ def process(input_image, prompt, low_threshold, high_threshold):
13
+ # resize to be a multiple of 8
14
+ new_width = input_image.width - input_image.width % 8
15
+ new_height = input_image.height - input_image.height % 8
16
+ input_image = input_image.resize((new_width, new_height))
17
+ canny = canny_from_pil(input_image, low_threshold, high_threshold)
18
+ with torch.no_grad():
19
+ c_t = transforms.ToTensor()(canny).unsqueeze(0).cuda()
20
+ output_image = model(c_t, prompt)
21
+ output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
22
+ # invert the canny edge map for visualization (map 0s to 1s and 1s to 0s)
23
+ canny_viz = 1 - (np.array(canny) / 255)
24
+ canny_viz = Image.fromarray((canny_viz * 255).astype(np.uint8))
25
+ return canny_viz, output_pil
26
+
27
+
28
+ if __name__ == "__main__":
29
+ # build the gradio interface (the model is loaded at module level above)
30
+ with gr.Blocks() as demo:
31
+ gr.Markdown("# Pix2pix-Turbo: **Canny Edge -> Image**")
32
+ with gr.Row():
33
+ with gr.Column():
34
+ input_image = gr.Image(sources="upload", type="pil")
35
+ prompt = gr.Textbox(label="Prompt")
36
+ low_threshold = gr.Slider(
37
+ label="Canny low threshold",
38
+ minimum=1,
39
+ maximum=255,
40
+ value=100,
41
+ step=10,
42
+ )
43
+ high_threshold = gr.Slider(
44
+ label="Canny high threshold",
45
+ minimum=1,
46
+ maximum=255,
47
+ value=200,
48
+ step=10,
49
+ )
50
+ run_button = gr.Button(value="Run")
51
+ with gr.Column():
52
+ result_canny = gr.Image(type="pil")
53
+ with gr.Column():
54
+ result_output = gr.Image(type="pil")
55
+
56
+ prompt.submit(
57
+ fn=process,
58
+ inputs=[input_image, prompt, low_threshold, high_threshold],
59
+ outputs=[result_canny, result_output],
60
+ )
61
+ low_threshold.change(
62
+ fn=process,
63
+ inputs=[input_image, prompt, low_threshold, high_threshold],
64
+ outputs=[result_canny, result_output],
65
+ )
66
+ high_threshold.change(
67
+ fn=process,
68
+ inputs=[input_image, prompt, low_threshold, high_threshold],
69
+ outputs=[result_canny, result_output],
70
+ )
71
+ run_button.click(
72
+ fn=process,
73
+ inputs=[input_image, prompt, low_threshold, high_threshold],
74
+ outputs=[result_canny, result_output],
75
+ )
76
+
77
+ demo.queue()
78
+ demo.launch(debug=True, share=False)
gradio_sketch2image.py ADDED
@@ -0,0 +1,382 @@
1
+ import random
2
+ import numpy as np
3
+ from PIL import Image
4
+ import base64
5
+ from io import BytesIO
6
+
7
+ import torch
8
+ import torchvision.transforms.functional as F
9
+ import gradio as gr
10
+
11
+ from src.pix2pix_turbo import Pix2Pix_Turbo
12
+
13
+ model = Pix2Pix_Turbo("sketch_to_image_stochastic")
14
+
15
+ style_list = [
16
+ {
17
+ "name": "Cinematic",
18
+ "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
19
+ },
20
+ {
21
+ "name": "3D Model",
22
+ "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
23
+ },
24
+ {
25
+ "name": "Anime",
26
+ "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
27
+ },
28
+ {
29
+ "name": "Digital Art",
30
+ "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
31
+ },
32
+ {
33
+ "name": "Photographic",
34
+ "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
35
+ },
36
+ {
37
+ "name": "Pixel art",
38
+ "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
39
+ },
40
+ {
41
+ "name": "Fantasy art",
42
+ "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
43
+ },
44
+ {
45
+ "name": "Neonpunk",
46
+ "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
47
+ },
48
+ {
49
+ "name": "Manga",
50
+ "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
51
+ },
52
+ ]
53
+
54
+ styles = {k["name"]: k["prompt"] for k in style_list}
55
+ STYLE_NAMES = list(styles.keys())
56
+ DEFAULT_STYLE_NAME = "Fantasy art"
57
+ MAX_SEED = np.iinfo(np.int32).max
58
+
59
+
60
+ def pil_image_to_data_uri(img, format="PNG"):
61
+ buffered = BytesIO()
62
+ img.save(buffered, format=format)
63
+ img_str = base64.b64encode(buffered.getvalue()).decode()
64
+ return f"data:image/{format.lower()};base64,{img_str}"
65
+
66
+
67
+ def run(image, prompt, prompt_template, style_name, seed, val_r):
68
+ print(f"prompt: {prompt}")
69
+ print("sketch updated")
70
+ if image is None:
71
+ ones = Image.new("L", (512, 512), 255)
72
+ temp_uri = pil_image_to_data_uri(ones)
73
+ return ones, gr.update(link=temp_uri), gr.update(link=temp_uri)
74
+ prompt = prompt_template.replace("{prompt}", prompt)
75
+ image = image.convert("RGB")
76
+ image_t = F.to_tensor(image) > 0.5
77
+ print(f"r_val={val_r}, seed={seed}")
78
+ with torch.no_grad():
79
+ c_t = image_t.unsqueeze(0).cuda().float()
80
+ torch.manual_seed(seed)
81
+ B, C, H, W = c_t.shape
82
+ noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
83
+ output_image = model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
84
+ output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
85
+ input_sketch_uri = pil_image_to_data_uri(Image.fromarray(255 - np.array(image)))
86
+ output_image_uri = pil_image_to_data_uri(output_pil)
87
+ return (
88
+ output_pil,
89
+ gr.update(link=input_sketch_uri),
90
+ gr.update(link=output_image_uri),
91
+ )
92
+
93
+
94
+ def update_canvas(use_line, use_eraser):
95
+ if use_eraser:
96
+ _color = "#ffffff"
97
+ brush_size = 20
98
+ if use_line:
99
+ _color = "#000000"
100
+ brush_size = 4
101
+ return gr.update(brush_radius=brush_size, brush_color=_color, interactive=True)
102
+
103
+
104
+ def upload_sketch(file):
105
+ _img = Image.open(file.name)
106
+ _img = _img.convert("L")
107
+ return gr.update(value=_img, source="upload", interactive=True)
108
+
109
+
110
+ scripts = """
111
+ async () => {
112
+ globalThis.theSketchDownloadFunction = () => {
113
+ console.log("test")
114
+ var link = document.createElement("a");
115
+ dataUri = document.getElementById('download_sketch').href
116
+ link.setAttribute("href", dataUri)
117
+ link.setAttribute("download", "sketch.png")
118
+ document.body.appendChild(link); // Required for Firefox
119
+ link.click();
120
+ document.body.removeChild(link); // Clean up
121
+
122
+ // also call the output download function
123
+ theOutputDownloadFunction();
124
+ return false
125
+ }
126
+
127
+ globalThis.theOutputDownloadFunction = () => {
128
+ console.log("test output download function")
129
+ var link = document.createElement("a");
130
+ dataUri = document.getElementById('download_output').href
131
+ link.setAttribute("href", dataUri);
132
+ link.setAttribute("download", "output.png");
133
+ document.body.appendChild(link); // Required for Firefox
134
+ link.click();
135
+ document.body.removeChild(link); // Clean up
136
+ return false
137
+ }
138
+
139
+ globalThis.UNDO_SKETCH_FUNCTION = () => {
140
+ console.log("undo sketch function")
141
+ var button_undo = document.querySelector('#input_image > div.image-container.svelte-p3y7hu > div.svelte-s6ybro > button:nth-child(1)');
142
+ // Create a new 'click' event
143
+ var event = new MouseEvent('click', {
144
+ 'view': window,
145
+ 'bubbles': true,
146
+ 'cancelable': true
147
+ });
148
+ button_undo.dispatchEvent(event);
149
+ }
150
+
151
+ globalThis.DELETE_SKETCH_FUNCTION = () => {
152
+ console.log("delete sketch function")
153
+ var button_del = document.querySelector('#input_image > div.image-container.svelte-p3y7hu > div.svelte-s6ybro > button:nth-child(2)');
154
+ // Create a new 'click' event
155
+ var event = new MouseEvent('click', {
156
+ 'view': window,
157
+ 'bubbles': true,
158
+ 'cancelable': true
159
+ });
160
+ button_del.dispatchEvent(event);
161
+ }
162
+
163
+ globalThis.togglePencil = () => {
164
+ el_pencil = document.getElementById('my-toggle-pencil');
165
+ el_pencil.classList.toggle('clicked');
166
+ // simulate a click on the gradio button
167
+ btn_gradio = document.querySelector("#cb-line > label > input");
168
+ var event = new MouseEvent('click', {
169
+ 'view': window,
170
+ 'bubbles': true,
171
+ 'cancelable': true
172
+ });
173
+ btn_gradio.dispatchEvent(event);
174
+ if (el_pencil.classList.contains('clicked')) {
175
+ document.getElementById('my-toggle-eraser').classList.remove('clicked');
176
+ document.getElementById('my-div-pencil').style.backgroundColor = "gray";
177
+ document.getElementById('my-div-eraser').style.backgroundColor = "white";
178
+ }
179
+ else {
180
+ document.getElementById('my-toggle-eraser').classList.add('clicked');
181
+ document.getElementById('my-div-pencil').style.backgroundColor = "white";
182
+ document.getElementById('my-div-eraser').style.backgroundColor = "gray";
183
+ }
184
+ }
185
+
186
+ globalThis.toggleEraser = () => {
187
+ element = document.getElementById('my-toggle-eraser');
188
+ element.classList.toggle('clicked');
189
+ // simulate a click on the gradio button
190
+ btn_gradio = document.querySelector("#cb-eraser > label > input");
191
+ var event = new MouseEvent('click', {
192
+ 'view': window,
193
+ 'bubbles': true,
194
+ 'cancelable': true
195
+ });
196
+ btn_gradio.dispatchEvent(event);
197
+ if (element.classList.contains('clicked')) {
198
+ document.getElementById('my-toggle-pencil').classList.remove('clicked');
199
+ document.getElementById('my-div-pencil').style.backgroundColor = "white";
200
+ document.getElementById('my-div-eraser').style.backgroundColor = "gray";
201
+ }
202
+ else {
203
+ document.getElementById('my-toggle-pencil').classList.add('clicked');
204
+ document.getElementById('my-div-pencil').style.backgroundColor = "gray";
205
+ document.getElementById('my-div-eraser').style.backgroundColor = "white";
206
+ }
207
+ }
208
+ }
209
+ """
210
+
211
+ with gr.Blocks(css="style.css") as demo:
212
+
213
+ gr.HTML(
214
+ """
215
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
216
+ <div>
217
+ <h2><a href="https://github.com/GaParmar/img2img-turbo">One-Step Image Translation with Text-to-Image Models</a></h2>
218
+ <div>
219
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
220
+ <a href='https://gauravparmar.com/'>Gaurav Parmar, </a>
221
+ &nbsp;
222
+ <a href='https://taesung.me/'> Taesung Park,</a>
223
+ &nbsp;
224
+ <a href='https://www.cs.cmu.edu/~srinivas/'>Srinivasa Narasimhan, </a>
225
+ &nbsp;
226
+ <a href='https://www.cs.cmu.edu/~junyanz/'> Jun-Yan Zhu </a>
227
+ </div>
228
+ </div>
229
+ </br>
230
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
231
+ <a href='https://arxiv.org/abs/2403.12036'>
232
+ <img src="https://img.shields.io/badge/arXiv-2403.12036-red">
233
+ </a>
234
+ &nbsp;
235
+ <a href='https://github.com/GaParmar/img2img-turbo'>
236
+ <img src='https://img.shields.io/badge/github-%23121011.svg'>
237
+ </a>
238
+ &nbsp;
239
+ <a href='https://github.com/GaParmar/img2img-turbo/blob/main/LICENSE'>
240
+ <img src='https://img.shields.io/badge/license-MIT-lightgrey'>
241
+ </a>
242
+ </div>
243
+ </div>
244
+ </div>
245
+ <div>
246
+ </br>
247
+ </div>
248
+ """
249
+ )
250
+
251
+ # these are hidden buttons that are used to trigger the canvas changes
252
+ line = gr.Checkbox(label="line", value=False, elem_id="cb-line")
253
+ eraser = gr.Checkbox(label="eraser", value=False, elem_id="cb-eraser")
254
+ with gr.Row(elem_id="main_row"):
255
+ with gr.Column(elem_id="column_input"):
256
+ gr.Markdown("## INPUT", elem_id="input_header")
257
+ image = gr.Image(
258
+ source="canvas",
259
+ tool="color-sketch",
260
+ type="pil",
261
+ image_mode="L",
262
+ invert_colors=True,
263
+ shape=(512, 512),
264
+ brush_radius=4,
265
+ height=440,
266
+ width=440,
267
+ brush_color="#000000",
268
+ interactive=True,
269
+ show_download_button=True,
270
+ elem_id="input_image",
271
+ show_label=False,
272
+ )
273
+ download_sketch = gr.Button(
274
+ "Download sketch", scale=1, elem_id="download_sketch"
275
+ )
276
+
277
+ gr.HTML(
278
+ """
279
+ <div class="button-row">
280
+ <div id="my-div-pencil" class="pad2"> <button id="my-toggle-pencil" onclick="return togglePencil(this)"></button> </div>
281
+ <div id="my-div-eraser" class="pad2"> <button id="my-toggle-eraser" onclick="return toggleEraser(this)"></button> </div>
282
+ <div class="pad2"> <button id="my-button-undo" onclick="return UNDO_SKETCH_FUNCTION(this)"></button> </div>
283
+ <div class="pad2"> <button id="my-button-clear" onclick="return DELETE_SKETCH_FUNCTION(this)"></button> </div>
284
+ <div class="pad2"> <button href="TODO" download="image" id="my-button-down" onclick='return theSketchDownloadFunction()'></button> </div>
285
+ </div>
286
+ """
287
+ )
288
+ # gr.Markdown("## Prompt", elem_id="tools_header")
289
+ prompt = gr.Textbox(label="Prompt", value="", show_label=True)
290
+ with gr.Row():
291
+ style = gr.Dropdown(
292
+ label="Style",
293
+ choices=STYLE_NAMES,
294
+ value=DEFAULT_STYLE_NAME,
295
+ scale=1,
296
+ )
297
+ prompt_temp = gr.Textbox(
298
+ label="Prompt Style Template",
299
+ value=styles[DEFAULT_STYLE_NAME],
300
+ scale=2,
301
+ max_lines=1,
302
+ )
303
+
304
+ with gr.Row():
305
+ val_r = gr.Slider(
306
+ label="Sketch guidance: ",
307
+ show_label=True,
308
+ minimum=0,
309
+ maximum=1,
310
+ value=0.4,
311
+ step=0.01,
312
+ scale=3,
313
+ )
314
+ seed = gr.Textbox(label="Seed", value=42, scale=1, min_width=50)
315
+ randomize_seed = gr.Button("Random", scale=1, min_width=50)
316
+
317
+ with gr.Column(elem_id="column_process", min_width=50, scale=0.4):
318
+ gr.Markdown("## pix2pix-turbo", elem_id="description")
319
+ run_button = gr.Button("Run", min_width=50)
320
+
321
+ with gr.Column(elem_id="column_output"):
322
+ gr.Markdown("## OUTPUT", elem_id="output_header")
323
+ result = gr.Image(
324
+ label="Result",
325
+ height=440,
326
+ width=440,
327
+ elem_id="output_image",
328
+ show_label=False,
329
+ show_download_button=True,
330
+ )
331
+ download_output = gr.Button("Download output", elem_id="download_output")
332
+ gr.Markdown("### Instructions")
333
+ gr.Markdown("**1**. Enter a text prompt (e.g. cat)")
334
+ gr.Markdown("**2**. Start sketching")
335
+ gr.Markdown("**3**. Change the image style using a style template")
336
+ gr.Markdown("**4**. Adjust the effect of sketch guidance using the slider")
337
+ gr.Markdown("**5**. Try different seeds to generate different results")
338
+
339
+ eraser.change(
340
+ fn=lambda x: gr.update(value=not x),
341
+ inputs=[eraser],
342
+ outputs=[line],
343
+ queue=False,
344
+ api_name=False,
345
+ ).then(update_canvas, [line, eraser], [image])
346
+ line.change(
347
+ fn=lambda x: gr.update(value=not x),
348
+ inputs=[line],
349
+ outputs=[eraser],
350
+ queue=False,
351
+ api_name=False,
352
+ ).then(update_canvas, [line, eraser], [image])
353
+
354
+ demo.load(None, None, None, _js=scripts)
355
+ randomize_seed.click(
356
+ lambda x: random.randint(0, MAX_SEED),
357
+ inputs=[],
358
+ outputs=seed,
359
+ queue=False,
360
+ api_name=False,
361
+ )
362
+ inputs = [image, prompt, prompt_temp, style, seed, val_r]
363
+ outputs = [result, download_sketch, download_output]
364
+ prompt.submit(fn=run, inputs=inputs, outputs=outputs, api_name=False)
365
+ style.change(
366
+ lambda x: styles[x],
367
+ inputs=[style],
368
+ outputs=[prompt_temp],
369
+ queue=False,
370
+ api_name=False,
371
+ ).then(
372
+ fn=run,
373
+ inputs=inputs,
374
+ outputs=outputs,
375
+ api_name=False,
376
+ )
377
+ val_r.change(run, inputs=inputs, outputs=outputs, queue=False, api_name=False)
378
+ run_button.click(fn=run, inputs=inputs, outputs=outputs, api_name=False)
379
+ image.change(run, inputs=inputs, outputs=outputs, queue=False, api_name=False)
380
+
381
+ if __name__ == "__main__":
382
+ demo.queue().launch(debug=True, share=True)
requirements.txt ADDED
@@ -0,0 +1,29 @@
1
+ clip @ git+https://github.com/openai/CLIP.git
2
+ einops>=0.6.1
3
+ numpy>=1.24.4
4
+ open-clip-torch>=2.20.0
5
+ opencv-python==4.6.0.66
6
+ pillow>=9.5.0
7
+ scipy==1.11.1
8
+ timm>=0.9.2
9
+ tokenizers
10
+ torch>=2.0.1
11
+
12
+ torchaudio>=2.0.2
13
+ torchdata==0.6.1
14
+ torchmetrics>=1.0.1
15
+ torchvision>=0.15.2
16
+
17
+ tqdm>=4.65.0
18
+ transformers==4.35.2
19
+ urllib3<1.27,>=1.25.4
20
+ xformers>=0.0.20
21
+ streamlit-keyup==0.2.0
22
+ lpips
23
+ clean-fid
24
+ peft
25
+ dominate
26
+ diffusers==0.25.1
27
+ gradio==3.43.1
28
+
29
+ vision_aided_loss
scripts/download_fill50k.sh ADDED
@@ -0,0 +1,5 @@
1
+ mkdir -p data
2
+ wget https://www.cs.cmu.edu/~img2img-turbo/data/my_fill50k.zip -O data/my_fill50k.zip
3
+ cd data
4
+ unzip my_fill50k.zip
5
+ rm my_fill50k.zip
scripts/download_horse2zebra.sh ADDED
@@ -0,0 +1,5 @@
1
+ mkdir -p data
2
+ wget https://www.cs.cmu.edu/~img2img-turbo/data/my_horse2zebra.zip -O data/my_horse2zebra.zip
3
+ cd data
4
+ unzip my_horse2zebra.zip
5
+ rm my_horse2zebra.zip
src/cyclegan_turbo.py ADDED
@@ -0,0 +1,254 @@
1
+ import os
2
+ import sys
3
+ import copy
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import AutoTokenizer, CLIPTextModel
7
+ from diffusers import AutoencoderKL, UNet2DConditionModel
8
+ from peft import LoraConfig
9
+ from peft.utils import get_peft_model_state_dict
10
+ p = "src/"
11
+ sys.path.append(p)
12
+ from model import make_1step_sched, my_vae_encoder_fwd, my_vae_decoder_fwd, download_url
13
+
14
+
15
+ class VAE_encode(nn.Module):
16
+ def __init__(self, vae, vae_b2a=None):
17
+ super(VAE_encode, self).__init__()
18
+ self.vae = vae
19
+ self.vae_b2a = vae_b2a
20
+
21
+ def forward(self, x, direction):
22
+ assert direction in ["a2b", "b2a"]
23
+ if direction == "a2b":
24
+ _vae = self.vae
25
+ else:
26
+ _vae = self.vae_b2a
27
+ return _vae.encode(x).latent_dist.sample() * _vae.config.scaling_factor
28
+
29
+
30
+ class VAE_decode(nn.Module):
31
+ def __init__(self, vae, vae_b2a=None):
32
+ super(VAE_decode, self).__init__()
33
+ self.vae = vae
34
+ self.vae_b2a = vae_b2a
35
+
36
+ def forward(self, x, direction):
37
+ assert direction in ["a2b", "b2a"]
38
+ if direction == "a2b":
39
+ _vae = self.vae
40
+ else:
41
+ _vae = self.vae_b2a
42
+ assert _vae.encoder.current_down_blocks is not None
43
+ _vae.decoder.incoming_skip_acts = _vae.encoder.current_down_blocks
44
+ x_decoded = (_vae.decode(x / _vae.config.scaling_factor).sample).clamp(-1, 1)
45
+ return x_decoded
46
+
47
+
48
+ def initialize_unet(rank, return_lora_module_names=False):
49
+ unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
50
+ unet.requires_grad_(False)
51
+ unet.train()
52
+ l_target_modules_encoder, l_target_modules_decoder, l_modules_others = [], [], []
53
+ l_grep = ["to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_in", "conv_shortcut", "conv_out", "proj_out", "proj_in", "ff.net.2", "ff.net.0.proj"]
54
+ for n, p in unet.named_parameters():
55
+ if "bias" in n or "norm" in n: continue
56
+ for pattern in l_grep:
57
+ if pattern in n and ("down_blocks" in n or "conv_in" in n):
58
+ l_target_modules_encoder.append(n.replace(".weight",""))
59
+ break
60
+ elif pattern in n and "up_blocks" in n:
61
+ l_target_modules_decoder.append(n.replace(".weight",""))
62
+ break
63
+ elif pattern in n:
64
+ l_modules_others.append(n.replace(".weight",""))
65
+ break
66
+ lora_conf_encoder = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_target_modules_encoder, lora_alpha=rank)
67
+ lora_conf_decoder = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_target_modules_decoder, lora_alpha=rank)
68
+ lora_conf_others = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_modules_others, lora_alpha=rank)
69
+ unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
70
+ unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
71
+ unet.add_adapter(lora_conf_others, adapter_name="default_others")
72
+ unet.set_adapters(["default_encoder", "default_decoder", "default_others"])
73
+ if return_lora_module_names:
74
+ return unet, l_target_modules_encoder, l_target_modules_decoder, l_modules_others
75
+ else:
76
+ return unet
77
+
78
+
79
+ def initialize_vae(rank=4, return_lora_module_names=False):
80
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
81
+ vae.requires_grad_(False)
82
+ vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
83
+ vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
84
+ vae.requires_grad_(True)
85
+ vae.train()
86
+ # add the skip connection convs
87
+ vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
88
+ vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
89
+ vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
90
+ vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
91
+ torch.nn.init.constant_(vae.decoder.skip_conv_1.weight, 1e-5)
92
+ torch.nn.init.constant_(vae.decoder.skip_conv_2.weight, 1e-5)
93
+ torch.nn.init.constant_(vae.decoder.skip_conv_3.weight, 1e-5)
94
+ torch.nn.init.constant_(vae.decoder.skip_conv_4.weight, 1e-5)
95
+ vae.decoder.ignore_skip = False
96
+ vae.decoder.gamma = 1
97
+ l_vae_target_modules = ["conv1","conv2","conv_in", "conv_shortcut",
98
+ "conv", "conv_out", "skip_conv_1", "skip_conv_2", "skip_conv_3",
99
+ "skip_conv_4", "to_k", "to_q", "to_v", "to_out.0",
100
+ ]
101
+ vae_lora_config = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_vae_target_modules)
102
+ vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
103
+ if return_lora_module_names:
104
+ return vae, l_vae_target_modules
105
+ else:
106
+ return vae
107
+
108
+
109
+ class CycleGAN_Turbo(torch.nn.Module):
110
+ def __init__(self, pretrained_name=None, pretrained_path=None, ckpt_folder="checkpoints", lora_rank_unet=8, lora_rank_vae=4):
111
+ super().__init__()
112
+ self.tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer")
113
+ self.text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder").cuda()
114
+ self.sched = make_1step_sched()
115
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
116
+ unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
117
+ vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
118
+ vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
119
+ # add the skip connection convs
120
+ vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
121
+ vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
122
+ vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
123
+ vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
124
+ vae.decoder.ignore_skip = False
125
+ self.unet, self.vae = unet, vae
126
+ if pretrained_name == "day_to_night":
127
+ url = "https://www.cs.cmu.edu/~img2img-turbo/models/day2night.pkl"
128
+ self.load_ckpt_from_url(url, ckpt_folder)
129
+ self.timesteps = torch.tensor([999], device="cuda").long()
130
+ self.caption = "driving in the night"
131
+ self.direction = "a2b"
132
+ elif pretrained_name == "night_to_day":
133
+ url = "https://www.cs.cmu.edu/~img2img-turbo/models/night2day.pkl"
134
+ self.load_ckpt_from_url(url, ckpt_folder)
135
+ self.timesteps = torch.tensor([999], device="cuda").long()
136
+ self.caption = "driving in the day"
137
+ self.direction = "b2a"
138
+ elif pretrained_name == "clear_to_rainy":
139
+ url = "https://www.cs.cmu.edu/~img2img-turbo/models/clear2rainy.pkl"
140
+ self.load_ckpt_from_url(url, ckpt_folder)
141
+ self.timesteps = torch.tensor([999], device="cuda").long()
142
+ self.caption = "driving in heavy rain"
143
+ self.direction = "a2b"
144
+ elif pretrained_name == "rainy_to_clear":
145
+ url = "https://www.cs.cmu.edu/~img2img-turbo/models/rainy2clear.pkl"
146
+ self.load_ckpt_from_url(url, ckpt_folder)
147
+ self.timesteps = torch.tensor([999], device="cuda").long()
148
+ self.caption = "driving in the day"
149
+ self.direction = "b2a"
150
+
151
+ elif pretrained_path is not None:
152
+ sd = torch.load(pretrained_path)
153
+ self.load_ckpt_from_state_dict(sd)
154
+ self.timesteps = torch.tensor([999], device="cuda").long()
155
+ self.caption = None
156
+ self.direction = None
157
+
158
+ self.vae_enc.cuda()
159
+ self.vae_dec.cuda()
160
+ self.unet.cuda()
161
+
162
+ def load_ckpt_from_state_dict(self, sd):
163
+ lora_conf_encoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_target_modules_encoder"], lora_alpha=sd["rank_unet"])
164
+ lora_conf_decoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_target_modules_decoder"], lora_alpha=sd["rank_unet"])
165
+ lora_conf_others = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_modules_others"], lora_alpha=sd["rank_unet"])
166
+ self.unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
167
+ self.unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
168
+ self.unet.add_adapter(lora_conf_others, adapter_name="default_others")
169
+ for n, p in self.unet.named_parameters():
170
+ name_sd = n.replace(".default_encoder.weight", ".weight")
171
+ if "lora" in n and "default_encoder" in n:
172
+ p.data.copy_(sd["sd_encoder"][name_sd])
173
+ for n, p in self.unet.named_parameters():
174
+ name_sd = n.replace(".default_decoder.weight", ".weight")
175
+ if "lora" in n and "default_decoder" in n:
176
+ p.data.copy_(sd["sd_decoder"][name_sd])
177
+ for n, p in self.unet.named_parameters():
178
+ name_sd = n.replace(".default_others.weight", ".weight")
179
+ if "lora" in n and "default_others" in n:
180
+ p.data.copy_(sd["sd_other"][name_sd])
181
+ self.unet.set_adapter(["default_encoder", "default_decoder", "default_others"])
182
+
183
+ vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
184
+ self.vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
185
+ self.vae.decoder.gamma = 1
186
+ self.vae_b2a = copy.deepcopy(self.vae)
187
+ self.vae_enc = VAE_encode(self.vae, vae_b2a=self.vae_b2a)
188
+ self.vae_enc.load_state_dict(sd["sd_vae_enc"])
189
+ self.vae_dec = VAE_decode(self.vae, vae_b2a=self.vae_b2a)
190
+ self.vae_dec.load_state_dict(sd["sd_vae_dec"])
191
+
192
+ def load_ckpt_from_url(self, url, ckpt_folder):
193
+ os.makedirs(ckpt_folder, exist_ok=True)
194
+ outf = os.path.join(ckpt_folder, os.path.basename(url))
195
+ download_url(url, outf)
196
+ sd = torch.load(outf)
197
+ self.load_ckpt_from_state_dict(sd)
198
+
199
+ @staticmethod
200
+ def forward_with_networks(x, direction, vae_enc, unet, vae_dec, sched, timesteps, text_emb):
201
+ B = x.shape[0]
202
+ assert direction in ["a2b", "b2a"]
203
+ x_enc = vae_enc(x, direction=direction).to(x.dtype)
204
+ model_pred = unet(x_enc, timesteps, encoder_hidden_states=text_emb,).sample
205
+ x_out = torch.stack([sched.step(model_pred[i], timesteps[i], x_enc[i], return_dict=True).prev_sample for i in range(B)])
206
+ x_out_decoded = vae_dec(x_out, direction=direction)
207
+ return x_out_decoded
208
+
209
+ @staticmethod
210
+ def get_traininable_params(unet, vae_a2b, vae_b2a):
211
+ # add all unet parameters
212
+ params_gen = list(unet.conv_in.parameters())
213
+ unet.conv_in.requires_grad_(True)
214
+ unet.set_adapters(["default_encoder", "default_decoder", "default_others"])
215
+ for n,p in unet.named_parameters():
216
+ if "lora" in n and "default" in n:
217
+ assert p.requires_grad
218
+ params_gen.append(p)
219
+
220
+ # add all vae_a2b parameters
221
+ for n,p in vae_a2b.named_parameters():
222
+ if "lora" in n and "vae_skip" in n:
223
+ assert p.requires_grad
224
+ params_gen.append(p)
225
+ params_gen = params_gen + list(vae_a2b.decoder.skip_conv_1.parameters())
226
+ params_gen = params_gen + list(vae_a2b.decoder.skip_conv_2.parameters())
227
+ params_gen = params_gen + list(vae_a2b.decoder.skip_conv_3.parameters())
228
+ params_gen = params_gen + list(vae_a2b.decoder.skip_conv_4.parameters())
229
+
230
+ # add all vae_b2a parameters
231
+ for n,p in vae_b2a.named_parameters():
232
+ if "lora" in n and "vae_skip" in n:
233
+ assert p.requires_grad
234
+ params_gen.append(p)
235
+ params_gen = params_gen + list(vae_b2a.decoder.skip_conv_1.parameters())
236
+ params_gen = params_gen + list(vae_b2a.decoder.skip_conv_2.parameters())
237
+ params_gen = params_gen + list(vae_b2a.decoder.skip_conv_3.parameters())
238
+ params_gen = params_gen + list(vae_b2a.decoder.skip_conv_4.parameters())
239
+ return params_gen
240
+
241
+ def forward(self, x_t, direction=None, caption=None, caption_emb=None):
242
+ if direction is None:
243
+ assert self.direction is not None
244
+ direction = self.direction
245
+ if caption is None and caption_emb is None:
246
+ assert self.caption is not None
247
+ caption = self.caption
248
+ if caption_emb is not None:
249
+ caption_enc = caption_emb
250
+ else:
251
+ caption_tokens = self.tokenizer(caption, max_length=self.tokenizer.model_max_length,
252
+ padding="max_length", truncation=True, return_tensors="pt").input_ids.to(x_t.device)
253
+ caption_enc = self.text_encoder(caption_tokens)[0].detach().clone()
254
+ return self.forward_with_networks(x_t, direction, self.vae_enc, self.unet, self.vae_dec, self.sched, self.timesteps, caption_enc)
src/image_prep.py ADDED
@@ -0,0 +1,12 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ import cv2
4
+
5
+
6
+ def canny_from_pil(image, low_threshold=100, high_threshold=200):
7
+ image = np.array(image)
8
+ image = cv2.Canny(image, low_threshold, high_threshold)
9
+ image = image[:, :, None]
10
+ image = np.concatenate([image, image, image], axis=2)
11
+ control_image = Image.fromarray(image)
12
+ return control_image
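A minimal usage sketch of canny_from_pil, assuming it is run from the src/ directory; the input and output paths are placeholders:

from PIL import Image
from image_prep import canny_from_pil

img = Image.open("input.png").convert("RGB")                        # placeholder path
edges = canny_from_pil(img, low_threshold=100, high_threshold=200)  # white edges on black, replicated to 3 channels
edges.save("input_canny.png")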
src/inference_paired.py ADDED
@@ -0,0 +1,75 @@
1
+ import os
2
+ import argparse
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ from torchvision import transforms
7
+ import torchvision.transforms.functional as F
8
+ from pix2pix_turbo import Pix2Pix_Turbo
9
+ from image_prep import canny_from_pil
10
+
11
+ if __name__ == "__main__":
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument('--input_image', type=str, required=True, help='path to the input image')
14
+ parser.add_argument('--prompt', type=str, required=True, help='the prompt to be used')
15
+ parser.add_argument('--model_name', type=str, default='', help='name of the pretrained model to be used')
16
+ parser.add_argument('--model_path', type=str, default='', help='path to a model state dict to be used')
17
+ parser.add_argument('--output_dir', type=str, default='output', help='the directory to save the output')
18
+ parser.add_argument('--low_threshold', type=int, default=100, help='Canny low threshold')
19
+ parser.add_argument('--high_threshold', type=int, default=200, help='Canny high threshold')
20
+ parser.add_argument('--gamma', type=float, default=0.4, help='The sketch interpolation guidance amount')
21
+ parser.add_argument('--seed', type=int, default=42, help='Random seed to be used')
22
+ parser.add_argument('--use_fp16', action='store_true', help='Use Float16 precision for faster inference')
23
+ args = parser.parse_args()
24
+
25
+ # only one of model_name and model_path should be provided
26
+ if (args.model_name == '') == (args.model_path == ''):
27
+ raise ValueError('Exactly one of model_name and model_path should be provided')
28
+
29
+ os.makedirs(args.output_dir, exist_ok=True)
30
+
31
+ # initialize the model
32
+ model = Pix2Pix_Turbo(pretrained_name=args.model_name, pretrained_path=args.model_path)
33
+ model.set_eval()
34
+ if args.use_fp16:
35
+ model.half()
36
+
37
+ # make sure that the input image is a multiple of 8
38
+ input_image = Image.open(args.input_image).convert('RGB')
39
+ new_width = input_image.width - input_image.width % 8
40
+ new_height = input_image.height - input_image.height % 8
41
+ input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
42
+ bname = os.path.basename(args.input_image)
43
+
44
+ # translate the image
45
+ with torch.no_grad():
46
+ if args.model_name == 'edge_to_image':
47
+ canny = canny_from_pil(input_image, args.low_threshold, args.high_threshold)
48
+ canny_viz_inv = Image.fromarray(255 - np.array(canny))
49
+ canny_viz_inv.save(os.path.join(args.output_dir, bname.replace('.png', '_canny.png')))
50
+ c_t = F.to_tensor(canny).unsqueeze(0).cuda()
51
+ if args.use_fp16:
52
+ c_t = c_t.half()
53
+ output_image = model(c_t, args.prompt)
54
+
55
+ elif args.model_name == 'sketch_to_image_stochastic':
56
+ image_t = F.to_tensor(input_image) < 0.5
57
+ c_t = image_t.unsqueeze(0).cuda().float()
58
+ torch.manual_seed(args.seed)
59
+ B, C, H, W = c_t.shape
60
+ noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
61
+ if args.use_fp16:
62
+ c_t = c_t.half()
63
+ noise = noise.half()
64
+ output_image = model(c_t, args.prompt, deterministic=False, r=args.gamma, noise_map=noise)
65
+
66
+ else:
67
+ c_t = F.to_tensor(input_image).unsqueeze(0).cuda()
68
+ if args.use_fp16:
69
+ c_t = c_t.half()
70
+ output_image = model(c_t, args.prompt)
71
+
72
+ output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
73
+
74
+ # save the output image
75
+ output_pil.save(os.path.join(args.output_dir, bname))
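A hedged programmatic sketch of the edge-to-image path handled by the script above, assuming it is run from the src/ directory; the prompt and file paths are placeholders:

import torch
import torchvision.transforms.functional as F
from torchvision import transforms
from PIL import Image
from pix2pix_turbo import Pix2Pix_Turbo
from image_prep import canny_from_pil

model = Pix2Pix_Turbo(pretrained_name="edge_to_image", pretrained_path="")
model.set_eval()
img = Image.open("input.png").convert("RGB").resize((512, 512), Image.LANCZOS)   # placeholder path
c_t = F.to_tensor(canny_from_pil(img)).unsqueeze(0).cuda()                        # conditioning in [0, 1]
with torch.no_grad():
    out = model(c_t, "a photo of a red bird")                                     # placeholder prompt; output in [-1, 1]
transforms.ToPILImage()(out[0].cpu() * 0.5 + 0.5).save("output.png")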
src/inference_unpaired.py ADDED
@@ -0,0 +1,58 @@
1
+ import os
2
+ import argparse
3
+ from PIL import Image
4
+ import torch
5
+ from torchvision import transforms
6
+ from cyclegan_turbo import CycleGAN_Turbo
7
+ from my_utils.training_utils import build_transform
8
+
9
+
10
+ if __name__ == "__main__":
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument('--input_image', type=str, required=True, help='path to the input image')
13
+ parser.add_argument('--prompt', type=str, required=False, help='the prompt to be used. It is required when loading a custom model_path.')
14
+ parser.add_argument('--model_name', type=str, default=None, help='name of the pretrained model to be used')
15
+ parser.add_argument('--model_path', type=str, default=None, help='path to a local model state dict to be used')
16
+ parser.add_argument('--output_dir', type=str, default='output', help='the directory to save the output')
17
+ parser.add_argument('--image_prep', type=str, default='resize_512x512', help='the image preparation method')
18
+ parser.add_argument('--direction', type=str, default=None, help='the direction of translation. None for pretrained models, a2b or b2a for custom paths.')
19
+ parser.add_argument('--use_fp16', action='store_true', help='Use Float16 precision for faster inference')
20
+ args = parser.parse_args()
21
+
22
+ # only one of model_name and model_path should be provided
23
+ if (args.model_name is None) == (args.model_path is None):
24
+ raise ValueError('Exactly one of model_name and model_path should be provided')
25
+
26
+ if args.model_path is not None and args.prompt is None:
27
+ raise ValueError('prompt is required when loading a custom model_path.')
28
+
29
+ if args.model_name is not None:
30
+ assert args.prompt is None, 'prompt must not be provided when loading a pretrained model.'
31
+ assert args.direction is None, 'direction must not be provided when loading a pretrained model.'
32
+
33
+ # initialize the model
34
+ model = CycleGAN_Turbo(pretrained_name=args.model_name, pretrained_path=args.model_path)
35
+ model.eval()
36
+ model.unet.enable_xformers_memory_efficient_attention()
37
+ if args.use_fp16:
38
+ model.half()
39
+
40
+ T_val = build_transform(args.image_prep)
41
+
42
+ input_image = Image.open(args.input_image).convert('RGB')
43
+ # translate the image
44
+ with torch.no_grad():
45
+ input_img = T_val(input_image)
46
+ x_t = transforms.ToTensor()(input_img)
47
+ x_t = transforms.Normalize([0.5], [0.5])(x_t).unsqueeze(0).cuda()
48
+ if args.use_fp16:
49
+ x_t = x_t.half()
50
+ output = model(x_t, direction=args.direction, caption=args.prompt)
51
+
52
+ output_pil = transforms.ToPILImage()(output[0].cpu() * 0.5 + 0.5)
53
+ output_pil = output_pil.resize((input_image.width, input_image.height), Image.LANCZOS)
54
+
55
+ # save the output image
56
+ bname = os.path.basename(args.input_image)
57
+ os.makedirs(args.output_dir, exist_ok=True)
58
+ output_pil.save(os.path.join(args.output_dir, bname))
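A hedged programmatic counterpart of the script above for a locally trained checkpoint, assuming it is run from the src/ directory; the checkpoint path, caption, and direction are placeholders for whatever was used at training time:

import torch
from PIL import Image
from torchvision import transforms
from cyclegan_turbo import CycleGAN_Turbo

model = CycleGAN_Turbo(pretrained_name=None, pretrained_path="checkpoints/model_1001.pkl")  # placeholder path
model.eval()
img = Image.open("input.png").convert("RGB").resize((512, 512), Image.LANCZOS)
x_t = transforms.Normalize([0.5], [0.5])(transforms.ToTensor()(img)).unsqueeze(0).cuda()
with torch.no_grad():
    out = model(x_t, direction="a2b", caption="driving in heavy rain")  # placeholder caption/direction
transforms.ToPILImage()(out[0].cpu() * 0.5 + 0.5).save("output.png")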
src/model.py ADDED
@@ -0,0 +1,73 @@
1
+ import os
2
+ import requests
3
+ from tqdm import tqdm
4
+ from diffusers import DDPMScheduler
5
+
6
+
7
+ def make_1step_sched():
8
+ noise_scheduler_1step = DDPMScheduler.from_pretrained("stabilityai/sd-turbo", subfolder="scheduler")
9
+ noise_scheduler_1step.set_timesteps(1, device="cuda")
10
+ noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
11
+ return noise_scheduler_1step
12
+
13
+
14
+ def my_vae_encoder_fwd(self, sample):
15
+ sample = self.conv_in(sample)
16
+ l_blocks = []
17
+ # down
18
+ for down_block in self.down_blocks:
19
+ l_blocks.append(sample)
20
+ sample = down_block(sample)
21
+ # middle
22
+ sample = self.mid_block(sample)
23
+ sample = self.conv_norm_out(sample)
24
+ sample = self.conv_act(sample)
25
+ sample = self.conv_out(sample)
26
+ self.current_down_blocks = l_blocks
27
+ return sample
28
+
29
+
30
+ def my_vae_decoder_fwd(self, sample, latent_embeds=None):
31
+ sample = self.conv_in(sample)
32
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
33
+ # middle
34
+ sample = self.mid_block(sample, latent_embeds)
35
+ sample = sample.to(upscale_dtype)
36
+ if not self.ignore_skip:
37
+ skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
38
+ # up
39
+ for idx, up_block in enumerate(self.up_blocks):
40
+ skip_in = skip_convs[idx](self.incoming_skip_acts[::-1][idx] * self.gamma)
41
+ # add skip
42
+ sample = sample + skip_in
43
+ sample = up_block(sample, latent_embeds)
44
+ else:
45
+ for idx, up_block in enumerate(self.up_blocks):
46
+ sample = up_block(sample, latent_embeds)
47
+ # post-process
48
+ if latent_embeds is None:
49
+ sample = self.conv_norm_out(sample)
50
+ else:
51
+ sample = self.conv_norm_out(sample, latent_embeds)
52
+ sample = self.conv_act(sample)
53
+ sample = self.conv_out(sample)
54
+ return sample
55
+
56
+
57
+ def download_url(url, outf):
58
+ if not os.path.exists(outf):
59
+ print(f"Downloading checkpoint to {outf}")
60
+ response = requests.get(url, stream=True)
61
+ total_size_in_bytes = int(response.headers.get('content-length', 0))
62
+ block_size = 1024 # 1 Kibibyte
63
+ progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
64
+ with open(outf, 'wb') as file:
65
+ for data in response.iter_content(block_size):
66
+ progress_bar.update(len(data))
67
+ file.write(data)
68
+ progress_bar.close()
69
+ if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
70
+ print("ERROR, something went wrong")
71
+ print(f"Downloaded successfully to {outf}")
72
+ else:
73
+ print(f"Skipping download, {outf} already exists")
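The two forward functions above take self explicitly because they are meant to be bound onto a diffusers VAE instance. A hedged sketch of that binding, assuming it is run from the src/ directory and with the skip connections disabled so that no extra skip_conv_* layers are required:

import torch
from diffusers import AutoencoderKL
from model import my_vae_encoder_fwd, my_vae_decoder_fwd

vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
vae.decoder.ignore_skip = True                    # skip path off: gamma / skip_conv_* / incoming_skip_acts not needed
with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)               # dummy image-sized tensor
    latent = vae.encode(x).latent_dist.sample()
    down_feats = vae.encoder.current_down_blocks  # per-resolution activations captured by the patched encoder forward
    recon = vae.decode(latent).sample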
src/my_utils/dino_struct.py ADDED
@@ -0,0 +1,185 @@
1
+ import torch
2
+ import torchvision
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def attn_cosine_sim(x, eps=1e-08):
7
+ x = x[0] # TEMP: getting rid of redundant dimension, TBF
8
+ norm1 = x.norm(dim=2, keepdim=True)
9
+ factor = torch.clamp(norm1 @ norm1.permute(0, 2, 1), min=eps)
10
+ sim_matrix = (x @ x.permute(0, 2, 1)) / factor
11
+ return sim_matrix
12
+
13
+
14
+ class VitExtractor:
15
+ BLOCK_KEY = 'block'
16
+ ATTN_KEY = 'attn'
17
+ PATCH_IMD_KEY = 'patch_imd'
18
+ QKV_KEY = 'qkv'
19
+ KEY_LIST = [BLOCK_KEY, ATTN_KEY, PATCH_IMD_KEY, QKV_KEY]
20
+
21
+ def __init__(self, model_name, device):
22
+ # pdb.set_trace()
23
+ self.model = torch.hub.load('facebookresearch/dino:main', model_name).to(device)
24
+ self.model.eval()
25
+ self.model_name = model_name
26
+ self.hook_handlers = []
27
+ self.layers_dict = {}
28
+ self.outputs_dict = {}
29
+ for key in VitExtractor.KEY_LIST:
30
+ self.layers_dict[key] = []
31
+ self.outputs_dict[key] = []
32
+ self._init_hooks_data()
33
+
34
+ def _init_hooks_data(self):
35
+ self.layers_dict[VitExtractor.BLOCK_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
36
+ self.layers_dict[VitExtractor.ATTN_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
37
+ self.layers_dict[VitExtractor.QKV_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
38
+ self.layers_dict[VitExtractor.PATCH_IMD_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
39
+ for key in VitExtractor.KEY_LIST:
40
+ # self.layers_dict[key] = kwargs[key] if key in kwargs.keys() else []
41
+ self.outputs_dict[key] = []
42
+
43
+ def _register_hooks(self, **kwargs):
44
+ for block_idx, block in enumerate(self.model.blocks):
45
+ if block_idx in self.layers_dict[VitExtractor.BLOCK_KEY]:
46
+ self.hook_handlers.append(block.register_forward_hook(self._get_block_hook()))
47
+ if block_idx in self.layers_dict[VitExtractor.ATTN_KEY]:
48
+ self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_attn_hook()))
49
+ if block_idx in self.layers_dict[VitExtractor.QKV_KEY]:
50
+ self.hook_handlers.append(block.attn.qkv.register_forward_hook(self._get_qkv_hook()))
51
+ if block_idx in self.layers_dict[VitExtractor.PATCH_IMD_KEY]:
52
+ self.hook_handlers.append(block.attn.register_forward_hook(self._get_patch_imd_hook()))
53
+
54
+ def _clear_hooks(self):
55
+ for handler in self.hook_handlers:
56
+ handler.remove()
57
+ self.hook_handlers = []
58
+
59
+ def _get_block_hook(self):
60
+ def _get_block_output(model, input, output):
61
+ self.outputs_dict[VitExtractor.BLOCK_KEY].append(output)
62
+
63
+ return _get_block_output
64
+
65
+ def _get_attn_hook(self):
66
+ def _get_attn_output(model, inp, output):
67
+ self.outputs_dict[VitExtractor.ATTN_KEY].append(output)
68
+
69
+ return _get_attn_output
70
+
71
+ def _get_qkv_hook(self):
72
+ def _get_qkv_output(model, inp, output):
73
+ self.outputs_dict[VitExtractor.QKV_KEY].append(output)
74
+
75
+ return _get_qkv_output
76
+
77
+ # TODO: CHECK ATTN OUTPUT TUPLE
78
+ def _get_patch_imd_hook(self):
79
+ def _get_attn_output(model, inp, output):
80
+ self.outputs_dict[VitExtractor.PATCH_IMD_KEY].append(output[0])
81
+
82
+ return _get_attn_output
83
+
84
+ def get_feature_from_input(self, input_img): # List([B, N, D])
85
+ self._register_hooks()
86
+ self.model(input_img)
87
+ feature = self.outputs_dict[VitExtractor.BLOCK_KEY]
88
+ self._clear_hooks()
89
+ self._init_hooks_data()
90
+ return feature
91
+
92
+ def get_qkv_feature_from_input(self, input_img):
93
+ self._register_hooks()
94
+ self.model(input_img)
95
+ feature = self.outputs_dict[VitExtractor.QKV_KEY]
96
+ self._clear_hooks()
97
+ self._init_hooks_data()
98
+ return feature
99
+
100
+ def get_attn_feature_from_input(self, input_img):
101
+ self._register_hooks()
102
+ self.model(input_img)
103
+ feature = self.outputs_dict[VitExtractor.ATTN_KEY]
104
+ self._clear_hooks()
105
+ self._init_hooks_data()
106
+ return feature
107
+
108
+ def get_patch_size(self):
109
+ return 8 if "8" in self.model_name else 16
110
+
111
+ def get_width_patch_num(self, input_img_shape):
112
+ b, c, h, w = input_img_shape
113
+ patch_size = self.get_patch_size()
114
+ return w // patch_size
115
+
116
+ def get_height_patch_num(self, input_img_shape):
117
+ b, c, h, w = input_img_shape
118
+ patch_size = self.get_patch_size()
119
+ return h // patch_size
120
+
121
+ def get_patch_num(self, input_img_shape):
122
+ patch_num = 1 + (self.get_height_patch_num(input_img_shape) * self.get_width_patch_num(input_img_shape))
123
+ return patch_num
124
+
125
+ def get_head_num(self):
126
+ if "dino" in self.model_name:
127
+ return 6 if "s" in self.model_name else 12
128
+ return 6 if "small" in self.model_name else 12
129
+
130
+ def get_embedding_dim(self):
131
+ if "dino" in self.model_name:
132
+ return 384 if "s" in self.model_name else 768
133
+ return 384 if "small" in self.model_name else 768
134
+
135
+ def get_queries_from_qkv(self, qkv, input_img_shape):
136
+ patch_num = self.get_patch_num(input_img_shape)
137
+ head_num = self.get_head_num()
138
+ embedding_dim = self.get_embedding_dim()
139
+ q = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[0]
140
+ return q
141
+
142
+ def get_keys_from_qkv(self, qkv, input_img_shape):
143
+ patch_num = self.get_patch_num(input_img_shape)
144
+ head_num = self.get_head_num()
145
+ embedding_dim = self.get_embedding_dim()
146
+ k = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[1]
147
+ return k
148
+
149
+ def get_values_from_qkv(self, qkv, input_img_shape):
150
+ patch_num = self.get_patch_num(input_img_shape)
151
+ head_num = self.get_head_num()
152
+ embedding_dim = self.get_embedding_dim()
153
+ v = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[2]
154
+ return v
155
+
156
+ def get_keys_from_input(self, input_img, layer_num):
157
+ qkv_features = self.get_qkv_feature_from_input(input_img)[layer_num]
158
+ keys = self.get_keys_from_qkv(qkv_features, input_img.shape)
159
+ return keys
160
+
161
+ def get_keys_self_sim_from_input(self, input_img, layer_num):
162
+ keys = self.get_keys_from_input(input_img, layer_num=layer_num)
163
+ h, t, d = keys.shape
164
+ concatenated_keys = keys.transpose(0, 1).reshape(t, h * d)
165
+ ssim_map = attn_cosine_sim(concatenated_keys[None, None, ...])
166
+ return ssim_map
167
+
168
+
169
+ class DinoStructureLoss:
170
+ def __init__(self, ):
171
+ self.extractor = VitExtractor(model_name="dino_vitb8", device="cuda")
172
+ self.preprocess = torchvision.transforms.Compose([
173
+ torchvision.transforms.Resize(224),
174
+ torchvision.transforms.ToTensor(),
175
+ torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
176
+ ])
177
+
178
+ def calculate_global_ssim_loss(self, outputs, inputs):
179
+ loss = 0.0
180
+ for a, b in zip(inputs, outputs): # avoid memory limitations
181
+ with torch.no_grad():
182
+ target_keys_self_sim = self.extractor.get_keys_self_sim_from_input(a.unsqueeze(0), layer_num=11)
183
+ keys_ssim = self.extractor.get_keys_self_sim_from_input(b.unsqueeze(0), layer_num=11)
184
+ loss += F.mse_loss(keys_ssim, target_keys_self_sim)
185
+ return loss
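A hedged usage sketch of the structure metric above, assuming it is run from the src/ directory on a CUDA machine (the extractor is created on "cuda"); the image paths are placeholders:

import torch
from PIL import Image
from my_utils.dino_struct import DinoStructureLoss

dino = DinoStructureLoss()
src = Image.open("input.png").convert("RGB").resize((224, 224), Image.LANCZOS)   # placeholder paths
out = Image.open("output.png").convert("RGB").resize((224, 224), Image.LANCZOS)
a = dino.preprocess(src).unsqueeze(0).cuda()
b = dino.preprocess(out).unsqueeze(0).cuda()
with torch.no_grad():
    dist = dino.calculate_global_ssim_loss(b, a)  # lower means the DINO self-similarity structure is better preserved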
src/my_utils/training_utils.py ADDED
@@ -0,0 +1,409 @@
1
+ import os
2
+ import random
3
+ import argparse
4
+ import json
5
+ import torch
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ import torchvision.transforms.functional as F
9
+ from glob import glob
10
+
11
+
12
+ def parse_args_paired_training(input_args=None):
13
+ """
14
+ Parses command-line arguments used for configuring a paired training session (pix2pix-Turbo).
15
+ This function sets up an argument parser to handle various training options.
16
+
17
+ Returns:
18
+ argparse.Namespace: The parsed command-line arguments.
19
+ """
20
+ parser = argparse.ArgumentParser()
21
+ # args for the loss function
22
+ parser.add_argument("--gan_disc_type", default="vagan_clip")
23
+ parser.add_argument("--gan_loss_type", default="multilevel_sigmoid_s")
24
+ parser.add_argument("--lambda_gan", default=0.5, type=float)
25
+ parser.add_argument("--lambda_lpips", default=5, type=float)
26
+ parser.add_argument("--lambda_l2", default=1.0, type=float)
27
+ parser.add_argument("--lambda_clipsim", default=5.0, type=float)
28
+
29
+ # dataset options
30
+ parser.add_argument("--dataset_folder", required=True, type=str)
31
+ parser.add_argument("--train_image_prep", default="resized_crop_512", type=str)
32
+ parser.add_argument("--test_image_prep", default="resized_crop_512", type=str)
33
+
34
+ # validation eval args
35
+ parser.add_argument("--eval_freq", default=100, type=int)
36
+ parser.add_argument("--track_val_fid", default=False, action="store_true")
37
+ parser.add_argument("--num_samples_eval", type=int, default=100, help="Number of samples to use for all evaluation")
38
+
39
+ parser.add_argument("--viz_freq", type=int, default=100, help="Frequency of visualizing the outputs.")
40
+ parser.add_argument("--tracker_project_name", type=str, default="train_pix2pix_turbo", help="The name of the wandb project to log to.")
41
+
42
+ # details about the model architecture
43
+ parser.add_argument("--pretrained_model_name_or_path")
44
+ parser.add_argument("--revision", type=str, default=None,)
45
+ parser.add_argument("--variant", type=str, default=None,)
46
+ parser.add_argument("--tokenizer_name", type=str, default=None)
47
+ parser.add_argument("--lora_rank_unet", default=8, type=int)
48
+ parser.add_argument("--lora_rank_vae", default=4, type=int)
49
+
50
+ # training details
51
+ parser.add_argument("--output_dir", required=True)
52
+ parser.add_argument("--cache_dir", default=None,)
53
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
54
+ parser.add_argument("--resolution", type=int, default=512,)
55
+ parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.")
56
+ parser.add_argument("--num_training_epochs", type=int, default=10)
57
+ parser.add_argument("--max_train_steps", type=int, default=10_000,)
58
+ parser.add_argument("--checkpointing_steps", type=int, default=500,)
59
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of update steps to accumulate before performing a backward/update pass.",)
60
+ parser.add_argument("--gradient_checkpointing", action="store_true",)
61
+ parser.add_argument("--learning_rate", type=float, default=5e-6)
62
+ parser.add_argument("--lr_scheduler", type=str, default="constant",
63
+ help=(
64
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
65
+ ' "constant", "constant_with_warmup"]'
66
+ ),
67
+ )
68
+ parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
69
+ parser.add_argument("--lr_num_cycles", type=int, default=1,
70
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
71
+ )
72
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
73
+
74
+ parser.add_argument("--dataloader_num_workers", type=int, default=0,)
75
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
76
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
77
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
78
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
79
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
80
+ parser.add_argument("--allow_tf32", action="store_true",
81
+ help=(
82
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
83
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
84
+ ),
85
+ )
86
+ parser.add_argument("--report_to", type=str, default="wandb",
87
+ help=(
88
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
89
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
90
+ ),
91
+ )
92
+ parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],)
93
+ parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
94
+ parser.add_argument("--set_grads_to_none", action="store_true",)
95
+
96
+ if input_args is not None:
97
+ args = parser.parse_args(input_args)
98
+ else:
99
+ args = parser.parse_args()
100
+
101
+ return args
102
+
103
+
104
+ def parse_args_unpaired_training():
105
+ """
106
+ Parses command-line arguments used for configuring an unpaired session (CycleGAN-Turbo).
107
+ This function sets up an argument parser to handle various training options.
108
+
109
+ Returns:
110
+ argparse.Namespace: The parsed command-line arguments.
111
+ """
112
+
113
+ parser = argparse.ArgumentParser(description="Training options for an unpaired image-to-image translation session (CycleGAN-Turbo).")
114
+
115
+ # fixed random seed
116
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
117
+
118
+ # args for the loss function
119
+ parser.add_argument("--gan_disc_type", default="vagan_clip")
120
+ parser.add_argument("--gan_loss_type", default="multilevel_sigmoid")
121
+ parser.add_argument("--lambda_gan", default=0.5, type=float)
122
+ parser.add_argument("--lambda_idt", default=1, type=float)
123
+ parser.add_argument("--lambda_cycle", default=1, type=float)
124
+ parser.add_argument("--lambda_cycle_lpips", default=10.0, type=float)
125
+ parser.add_argument("--lambda_idt_lpips", default=1.0, type=float)
126
+
127
+ # args for dataset and dataloader options
128
+ parser.add_argument("--dataset_folder", required=True, type=str)
129
+ parser.add_argument("--train_img_prep", required=True)
130
+ parser.add_argument("--val_img_prep", required=True)
131
+ parser.add_argument("--dataloader_num_workers", type=int, default=0)
132
+ parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.")
133
+ parser.add_argument("--max_train_epochs", type=int, default=100)
134
+ parser.add_argument("--max_train_steps", type=int, default=None)
135
+
136
+ # args for the model
137
+ parser.add_argument("--pretrained_model_name_or_path", default="stabilityai/sd-turbo")
138
+ parser.add_argument("--revision", default=None, type=str)
139
+ parser.add_argument("--variant", default=None, type=str)
140
+ parser.add_argument("--lora_rank_unet", default=128, type=int)
141
+ parser.add_argument("--lora_rank_vae", default=4, type=int)
142
+
143
+ # args for validation and logging
144
+ parser.add_argument("--viz_freq", type=int, default=20)
145
+ parser.add_argument("--output_dir", type=str, required=True)
146
+ parser.add_argument("--report_to", type=str, default="wandb")
147
+ parser.add_argument("--tracker_project_name", type=str, required=True)
148
+ parser.add_argument("--validation_steps", type=int, default=500,)
149
+ parser.add_argument("--validation_num_images", type=int, default=-1, help="Number of images to use for validation. -1 to use all images.")
150
+ parser.add_argument("--checkpointing_steps", type=int, default=500)
151
+
152
+ # args for the optimization options
153
+ parser.add_argument("--learning_rate", type=float, default=5e-6,)
154
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
155
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
156
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
157
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
158
+ parser.add_argument("--max_grad_norm", default=10.0, type=float, help="Max gradient norm.")
159
+ parser.add_argument("--lr_scheduler", type=str, default="constant", help=(
160
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
161
+ ' "constant", "constant_with_warmup"]'
162
+ ),
163
+ )
164
+ parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
165
+ parser.add_argument("--lr_num_cycles", type=int, default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.",)
166
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
167
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
168
+
169
+ # memory saving options
170
+ parser.add_argument("--allow_tf32", action="store_true",
171
+ help=(
172
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
173
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
174
+ ),
175
+ )
176
+ parser.add_argument("--gradient_checkpointing", action="store_true",
177
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.")
178
+ parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
179
+
180
+ args = parser.parse_args()
181
+ return args
182
+
183
+
184
+ def build_transform(image_prep):
185
+ """
186
+ Constructs a transformation pipeline based on the specified image preparation method.
187
+
188
+ Parameters:
189
+ - image_prep (str): A string describing the desired image preparation
190
+
191
+ Returns:
192
+ - torchvision.transforms.Compose: A composable sequence of transformations to be applied to images.
193
+ """
194
+ if image_prep == "resized_crop_512":
195
+ T = transforms.Compose([
196
+ transforms.Resize(512, interpolation=transforms.InterpolationMode.LANCZOS),
197
+ transforms.CenterCrop(512),
198
+ ])
199
+ elif image_prep == "resize_286_randomcrop_256x256_hflip":
200
+ T = transforms.Compose([
201
+ transforms.Resize((286, 286), interpolation=Image.LANCZOS),
202
+ transforms.RandomCrop((256, 256)),
203
+ transforms.RandomHorizontalFlip(),
204
+ ])
205
+ elif image_prep in ["resize_256", "resize_256x256"]:
206
+ T = transforms.Compose([
207
+ transforms.Resize((256, 256), interpolation=Image.LANCZOS)
208
+ ])
209
+ elif image_prep in ["resize_512", "resize_512x512"]:
210
+ T = transforms.Compose([
211
+ transforms.Resize((512, 512), interpolation=Image.LANCZOS)
212
+ ])
213
+ elif image_prep == "no_resize":
214
+ T = transforms.Lambda(lambda x: x)
215
+ return T
216
+
217
+
218
+ class PairedDataset(torch.utils.data.Dataset):
219
+ def __init__(self, dataset_folder, split, image_prep, tokenizer):
220
+ """
221
+ Initialize the paired dataset object for loading and transforming paired data samples
222
+ from specified dataset folders.
223
+
224
+ This constructor sets up the paths to input and output folders based on the specified 'split',
225
+ loads the captions (or prompts) for the input images, and prepares the transformations and
226
+ tokenizer to be applied on the data.
227
+
228
+ Parameters:
229
+ - dataset_folder (str): The root folder containing the dataset, expected to include
230
+ sub-folders for different splits (e.g., 'train_A', 'train_B').
231
+ - split (str): The dataset split to use ('train' or 'test'), used to select the appropriate
232
+ sub-folders and caption files within the dataset folder.
233
+ - image_prep (str): The image preprocessing transformation to apply to each image.
234
+ - tokenizer: The tokenizer used for tokenizing the captions (or prompts).
235
+ """
236
+ super().__init__()
237
+ if split == "train":
238
+ self.input_folder = os.path.join(dataset_folder, "train_A")
239
+ self.output_folder = os.path.join(dataset_folder, "train_B")
240
+ captions = os.path.join(dataset_folder, "train_prompts.json")
241
+ elif split == "test":
242
+ self.input_folder = os.path.join(dataset_folder, "test_A")
243
+ self.output_folder = os.path.join(dataset_folder, "test_B")
244
+ captions = os.path.join(dataset_folder, "test_prompts.json")
245
+ with open(captions, "r") as f:
246
+ self.captions = json.load(f)
247
+ self.img_names = list(self.captions.keys())
248
+ self.T = build_transform(image_prep)
249
+ self.tokenizer = tokenizer
250
+
251
+ def __len__(self):
252
+ """
253
+ Returns:
254
+ int: The total number of items in the dataset.
255
+ """
256
+ return len(self.captions)
257
+
258
+ def __getitem__(self, idx):
259
+ """
260
+ Retrieves a dataset item given its index. Each item consists of an input image,
261
+ its corresponding output image, the captions associated with the input image,
262
+ and the tokenized form of this caption.
263
+
264
+ This method performs the necessary preprocessing on both the input and output images,
265
+ including scaling and normalization, as well as tokenizing the caption using a provided tokenizer.
266
+
267
+ Parameters:
268
+ - idx (int): The index of the item to retrieve.
269
+
270
+ Returns:
271
+ dict: A dictionary containing the following key-value pairs:
272
+ - "output_pixel_values": a tensor of the preprocessed output image with pixel values
273
+ scaled to [-1, 1].
274
+ - "conditioning_pixel_values": a tensor of the preprocessed input image with pixel values
275
+ scaled to [0, 1].
276
+ - "caption": the text caption.
277
+ - "input_ids": a tensor of the tokenized caption.
278
+
279
+ Note:
280
+ The actual preprocessing steps (scaling and normalization) for images are defined externally
281
+ and passed to this class through the `image_prep` parameter during initialization. The
282
+ tokenization process relies on the `tokenizer` also provided at initialization, which
283
+ should be compatible with the models intended to be used with this dataset.
284
+ """
285
+ img_name = self.img_names[idx]
286
+ input_img = Image.open(os.path.join(self.input_folder, img_name))
287
+ output_img = Image.open(os.path.join(self.output_folder, img_name))
288
+ caption = self.captions[img_name]
289
+
290
+ # input images scaled to 0,1
291
+ img_t = self.T(input_img)
292
+ img_t = F.to_tensor(img_t)
293
+ # output images scaled to -1,1
294
+ output_t = self.T(output_img)
295
+ output_t = F.to_tensor(output_t)
296
+ output_t = F.normalize(output_t, mean=[0.5], std=[0.5])
297
+
298
+ input_ids = self.tokenizer(
299
+ caption, max_length=self.tokenizer.model_max_length,
300
+ padding="max_length", truncation=True, return_tensors="pt"
301
+ ).input_ids
302
+
303
+ return {
304
+ "output_pixel_values": output_t,
305
+ "conditioning_pixel_values": img_t,
306
+ "caption": caption,
307
+ "input_ids": input_ids,
308
+ }
309
+
310
+
311
+ class UnpairedDataset(torch.utils.data.Dataset):
312
+ def __init__(self, dataset_folder, split, image_prep, tokenizer):
313
+ """
314
+ A dataset class for loading unpaired data samples from two distinct domains (source and target),
315
+ typically used in unsupervised learning tasks like image-to-image translation.
316
+
317
+ The class supports loading images from specified dataset folders, applying predefined image
318
+ preprocessing transformations, and utilizing fixed textual prompts (captions) for each domain,
319
+ tokenized using a provided tokenizer.
320
+
321
+ Parameters:
322
+ - dataset_folder (str): Base directory of the dataset containing subdirectories (train_A, train_B, test_A, test_B)
323
+ - split (str): Indicates the dataset split to use. Expected values are 'train' or 'test'.
324
+ - image_prep (str): The image preprocessing transformation to apply to each image.
325
+ - tokenizer: The tokenizer used for tokenizing the captions (or prompts).
326
+ """
327
+ super().__init__()
328
+ if split == "train":
329
+ self.source_folder = os.path.join(dataset_folder, "train_A")
330
+ self.target_folder = os.path.join(dataset_folder, "train_B")
331
+ elif split == "test":
332
+ self.source_folder = os.path.join(dataset_folder, "test_A")
333
+ self.target_folder = os.path.join(dataset_folder, "test_B")
334
+ self.tokenizer = tokenizer
335
+ with open(os.path.join(dataset_folder, "fixed_prompt_a.txt"), "r") as f:
336
+ self.fixed_caption_src = f.read().strip()
337
+ self.input_ids_src = self.tokenizer(
338
+ self.fixed_caption_src, max_length=self.tokenizer.model_max_length,
339
+ padding="max_length", truncation=True, return_tensors="pt"
340
+ ).input_ids
341
+
342
+ with open(os.path.join(dataset_folder, "fixed_prompt_b.txt"), "r") as f:
343
+ self.fixed_caption_tgt = f.read().strip()
344
+ self.input_ids_tgt = self.tokenizer(
345
+ self.fixed_caption_tgt, max_length=self.tokenizer.model_max_length,
346
+ padding="max_length", truncation=True, return_tensors="pt"
347
+ ).input_ids
348
+ # find all images in the source and target folders with all IMG extensions
349
+ self.l_imgs_src = []
350
+ for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif"]:
351
+ self.l_imgs_src.extend(glob(os.path.join(self.source_folder, ext)))
352
+ self.l_imgs_tgt = []
353
+ for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif"]:
354
+ self.l_imgs_tgt.extend(glob(os.path.join(self.target_folder, ext)))
355
+ self.T = build_transform(image_prep)
356
+
357
+ def __len__(self):
358
+ """
359
+ Returns:
360
+ int: The total number of items in the dataset.
361
+ """
362
+ return len(self.l_imgs_src) + len(self.l_imgs_tgt)
363
+
364
+ def __getitem__(self, index):
365
+ """
366
+ Fetches a pair of unaligned images from the source and target domains along with their
367
+ corresponding tokenized captions.
368
+
369
+ For the source domain, if the requested index is within the range of available images,
370
+ the specific image at that index is chosen. If the index exceeds the number of source
371
+ images, a random source image is selected. For the target domain,
372
+ an image is always randomly selected, irrespective of the index, to maintain the
373
+ unpaired nature of the dataset.
374
+
375
+ Both images are preprocessed according to the specified image transformation `T`, and normalized.
376
+ The fixed captions for both domains
377
+ are included along with their tokenized forms.
378
+
379
+ Parameters:
380
+ - index (int): The index of the source image to retrieve.
381
+
382
+ Returns:
383
+ dict: A dictionary containing processed data for a single training example, with the following keys:
384
+ - "pixel_values_src": The processed source image
385
+ - "pixel_values_tgt": The processed target image
386
+ - "caption_src": The fixed caption of the source domain.
387
+ - "caption_tgt": The fixed caption of the target domain.
388
+ - "input_ids_src": The source domain's fixed caption tokenized.
389
+ - "input_ids_tgt": The target domain's fixed caption tokenized.
390
+ """
391
+ if index < len(self.l_imgs_src):
392
+ img_path_src = self.l_imgs_src[index]
393
+ else:
394
+ img_path_src = random.choice(self.l_imgs_src)
395
+ img_path_tgt = random.choice(self.l_imgs_tgt)
396
+ img_pil_src = Image.open(img_path_src).convert("RGB")
397
+ img_pil_tgt = Image.open(img_path_tgt).convert("RGB")
398
+ img_t_src = F.to_tensor(self.T(img_pil_src))
399
+ img_t_tgt = F.to_tensor(self.T(img_pil_tgt))
400
+ img_t_src = F.normalize(img_t_src, mean=[0.5], std=[0.5])
401
+ img_t_tgt = F.normalize(img_t_tgt, mean=[0.5], std=[0.5])
402
+ return {
403
+ "pixel_values_src": img_t_src,
404
+ "pixel_values_tgt": img_t_tgt,
405
+ "caption_src": self.fixed_caption_src,
406
+ "caption_tgt": self.fixed_caption_tgt,
407
+ "input_ids_src": self.input_ids_src,
408
+ "input_ids_tgt": self.input_ids_tgt,
409
+ }