LarryTsai commited on
Commit
8b0e3d0
1 Parent(s): 4cc135a

Update text_encoder/config.json

Browse files
scheduler/scheduler_config.json CHANGED
@@ -1,6 +1,6 @@
1
  {
2
  "_class_name": "EulerAncestralDiscreteScheduler",
3
- "_diffusers_version": "0.30.3",
4
  "beta_end": 0.02,
5
  "beta_schedule": "linear",
6
  "beta_start": 0.0001,
 
1
  {
2
  "_class_name": "EulerAncestralDiscreteScheduler",
3
+ "_diffusers_version": "0.28.0",
4
  "beta_end": 0.02,
5
  "beta_schedule": "linear",
6
  "beta_start": 0.0001,
text_encoder/config.json CHANGED
@@ -1,9 +1,7 @@
1
  {
2
- "_name_or_path": "/cpfs/data/user/larrytsai/Projects/Yi-VG/allegro/text_encoder",
3
  "architectures": [
4
  "T5EncoderModel"
5
  ],
6
- "classifier_dropout": 0.0,
7
  "d_ff": 10240,
8
  "d_kv": 64,
9
  "d_model": 4096,
@@ -26,7 +24,7 @@
26
  "relative_attention_num_buckets": 32,
27
  "tie_word_embeddings": false,
28
  "torch_dtype": "float32",
29
- "transformers_version": "4.40.1",
30
  "use_cache": true,
31
  "vocab_size": 32128
32
  }
 
1
  {
 
2
  "architectures": [
3
  "T5EncoderModel"
4
  ],
 
5
  "d_ff": 10240,
6
  "d_kv": 64,
7
  "d_model": 4096,
 
24
  "relative_attention_num_buckets": 32,
25
  "tie_word_embeddings": false,
26
  "torch_dtype": "float32",
27
+ "transformers_version": "4.21.1",
28
  "use_cache": true,
29
  "vocab_size": 32128
30
  }
text_encoder/model-00001-of-00004.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7a68b2c8c080696a10109612a649bc69330991ecfea65930ccfdfbdb011f2686
3
- size 4989319680
 
 
 
 
text_encoder/model-00002-of-00004.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b8ed6556d7507e38af5b428c605fb2a6f2bdb7e80bd481308b865f7a40c551ca
3
- size 4999830656
 
 
 
 
text_encoder/model-00003-of-00004.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c831635f83041f83faf0024b39c6ecb21b45d70dd38a63ea5bac6c7c6e5e558c
3
- size 4865612720
 
 
 
 
text_encoder/model-00004-of-00004.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:02a5f2d69205be92ad48fe5d712d38c2ff55627969116aeffc58bd75a28da468
3
- size 4194506688
 
 
 
 
text_encoder/model.safetensors.index.json DELETED
@@ -1,226 +0,0 @@
1
- {
2
- "metadata": {
3
- "total_size": 19049242624
4
- },
5
- "weight_map": {
6
- "encoder.block.0.layer.0.SelfAttention.k.weight": "model-00001-of-00004.safetensors",
7
- "encoder.block.0.layer.0.SelfAttention.o.weight": "model-00001-of-00004.safetensors",
8
- "encoder.block.0.layer.0.SelfAttention.q.weight": "model-00001-of-00004.safetensors",
9
- "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight": "model-00001-of-00004.safetensors",
10
- "encoder.block.0.layer.0.SelfAttention.v.weight": "model-00001-of-00004.safetensors",
11
- "encoder.block.0.layer.0.layer_norm.weight": "model-00001-of-00004.safetensors",
12
- "encoder.block.0.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00004.safetensors",
13
- "encoder.block.0.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00004.safetensors",
14
- "encoder.block.0.layer.1.DenseReluDense.wo.weight": "model-00001-of-00004.safetensors",
15
- "encoder.block.0.layer.1.layer_norm.weight": "model-00001-of-00004.safetensors",
16
- "encoder.block.1.layer.0.SelfAttention.k.weight": "model-00001-of-00004.safetensors",
17
- "encoder.block.1.layer.0.SelfAttention.o.weight": "model-00001-of-00004.safetensors",
18
- "encoder.block.1.layer.0.SelfAttention.q.weight": "model-00001-of-00004.safetensors",
19
- "encoder.block.1.layer.0.SelfAttention.v.weight": "model-00001-of-00004.safetensors",
20
- "encoder.block.1.layer.0.layer_norm.weight": "model-00001-of-00004.safetensors",
21
- "encoder.block.1.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00004.safetensors",
22
- "encoder.block.1.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00004.safetensors",
23
- "encoder.block.1.layer.1.DenseReluDense.wo.weight": "model-00001-of-00004.safetensors",
24
- "encoder.block.1.layer.1.layer_norm.weight": "model-00001-of-00004.safetensors",
25
- "encoder.block.10.layer.0.SelfAttention.k.weight": "model-00002-of-00004.safetensors",
26
- "encoder.block.10.layer.0.SelfAttention.o.weight": "model-00002-of-00004.safetensors",
27
- "encoder.block.10.layer.0.SelfAttention.q.weight": "model-00002-of-00004.safetensors",
28
- "encoder.block.10.layer.0.SelfAttention.v.weight": "model-00002-of-00004.safetensors",
29
- "encoder.block.10.layer.0.layer_norm.weight": "model-00002-of-00004.safetensors",
30
- "encoder.block.10.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00004.safetensors",
31
- "encoder.block.10.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00004.safetensors",
32
- "encoder.block.10.layer.1.DenseReluDense.wo.weight": "model-00002-of-00004.safetensors",
33
- "encoder.block.10.layer.1.layer_norm.weight": "model-00002-of-00004.safetensors",
34
- "encoder.block.11.layer.0.SelfAttention.k.weight": "model-00002-of-00004.safetensors",
35
- "encoder.block.11.layer.0.SelfAttention.o.weight": "model-00002-of-00004.safetensors",
36
- "encoder.block.11.layer.0.SelfAttention.q.weight": "model-00002-of-00004.safetensors",
37
- "encoder.block.11.layer.0.SelfAttention.v.weight": "model-00002-of-00004.safetensors",
38
- "encoder.block.11.layer.0.layer_norm.weight": "model-00002-of-00004.safetensors",
39
- "encoder.block.11.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00004.safetensors",
40
- "encoder.block.11.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00004.safetensors",
41
- "encoder.block.11.layer.1.DenseReluDense.wo.weight": "model-00002-of-00004.safetensors",
42
- "encoder.block.11.layer.1.layer_norm.weight": "model-00002-of-00004.safetensors",
43
- "encoder.block.12.layer.0.SelfAttention.k.weight": "model-00002-of-00004.safetensors",
44
- "encoder.block.12.layer.0.SelfAttention.o.weight": "model-00003-of-00004.safetensors",
45
- "encoder.block.12.layer.0.SelfAttention.q.weight": "model-00002-of-00004.safetensors",
46
- "encoder.block.12.layer.0.SelfAttention.v.weight": "model-00002-of-00004.safetensors",
47
- "encoder.block.12.layer.0.layer_norm.weight": "model-00003-of-00004.safetensors",
48
- "encoder.block.12.layer.1.DenseReluDense.wi_0.weight": "model-00003-of-00004.safetensors",
49
- "encoder.block.12.layer.1.DenseReluDense.wi_1.weight": "model-00003-of-00004.safetensors",
50
- "encoder.block.12.layer.1.DenseReluDense.wo.weight": "model-00003-of-00004.safetensors",
51
- "encoder.block.12.layer.1.layer_norm.weight": "model-00003-of-00004.safetensors",
52
- "encoder.block.13.layer.0.SelfAttention.k.weight": "model-00003-of-00004.safetensors",
53
- "encoder.block.13.layer.0.SelfAttention.o.weight": "model-00003-of-00004.safetensors",
54
- "encoder.block.13.layer.0.SelfAttention.q.weight": "model-00003-of-00004.safetensors",
55
- "encoder.block.13.layer.0.SelfAttention.v.weight": "model-00003-of-00004.safetensors",
56
- "encoder.block.13.layer.0.layer_norm.weight": "model-00003-of-00004.safetensors",
57
- "encoder.block.13.layer.1.DenseReluDense.wi_0.weight": "model-00003-of-00004.safetensors",
58
- "encoder.block.13.layer.1.DenseReluDense.wi_1.weight": "model-00003-of-00004.safetensors",
59
- "encoder.block.13.layer.1.DenseReluDense.wo.weight": "model-00003-of-00004.safetensors",
60
- "encoder.block.13.layer.1.layer_norm.weight": "model-00003-of-00004.safetensors",
61
- "encoder.block.14.layer.0.SelfAttention.k.weight": "model-00003-of-00004.safetensors",
62
- "encoder.block.14.layer.0.SelfAttention.o.weight": "model-00003-of-00004.safetensors",
63
- "encoder.block.14.layer.0.SelfAttention.q.weight": "model-00003-of-00004.safetensors",
64
- "encoder.block.14.layer.0.SelfAttention.v.weight": "model-00003-of-00004.safetensors",
65
- "encoder.block.14.layer.0.layer_norm.weight": "model-00003-of-00004.safetensors",
66
- "encoder.block.14.layer.1.DenseReluDense.wi_0.weight": "model-00003-of-00004.safetensors",
67
- "encoder.block.14.layer.1.DenseReluDense.wi_1.weight": "model-00003-of-00004.safetensors",
68
- "encoder.block.14.layer.1.DenseReluDense.wo.weight": "model-00003-of-00004.safetensors",
69
- "encoder.block.14.layer.1.layer_norm.weight": "model-00003-of-00004.safetensors",
70
- "encoder.block.15.layer.0.SelfAttention.k.weight": "model-00003-of-00004.safetensors",
71
- "encoder.block.15.layer.0.SelfAttention.o.weight": "model-00003-of-00004.safetensors",
72
- "encoder.block.15.layer.0.SelfAttention.q.weight": "model-00003-of-00004.safetensors",
73
- "encoder.block.15.layer.0.SelfAttention.v.weight": "model-00003-of-00004.safetensors",
74
- "encoder.block.15.layer.0.layer_norm.weight": "model-00003-of-00004.safetensors",
75
- "encoder.block.15.layer.1.DenseReluDense.wi_0.weight": "model-00003-of-00004.safetensors",
76
- "encoder.block.15.layer.1.DenseReluDense.wi_1.weight": "model-00003-of-00004.safetensors",
77
- "encoder.block.15.layer.1.DenseReluDense.wo.weight": "model-00003-of-00004.safetensors",
78
- "encoder.block.15.layer.1.layer_norm.weight": "model-00003-of-00004.safetensors",
79
- "encoder.block.16.layer.0.SelfAttention.k.weight": "model-00003-of-00004.safetensors",
80
- "encoder.block.16.layer.0.SelfAttention.o.weight": "model-00003-of-00004.safetensors",
81
- "encoder.block.16.layer.0.SelfAttention.q.weight": "model-00003-of-00004.safetensors",
82
- "encoder.block.16.layer.0.SelfAttention.v.weight": "model-00003-of-00004.safetensors",
83
- "encoder.block.16.layer.0.layer_norm.weight": "model-00003-of-00004.safetensors",
84
- "encoder.block.16.layer.1.DenseReluDense.wi_0.weight": "model-00003-of-00004.safetensors",
85
- "encoder.block.16.layer.1.DenseReluDense.wi_1.weight": "model-00003-of-00004.safetensors",
86
- "encoder.block.16.layer.1.DenseReluDense.wo.weight": "model-00003-of-00004.safetensors",
87
- "encoder.block.16.layer.1.layer_norm.weight": "model-00003-of-00004.safetensors",
88
- "encoder.block.17.layer.0.SelfAttention.k.weight": "model-00003-of-00004.safetensors",
89
- "encoder.block.17.layer.0.SelfAttention.o.weight": "model-00003-of-00004.safetensors",
90
- "encoder.block.17.layer.0.SelfAttention.q.weight": "model-00003-of-00004.safetensors",
91
- "encoder.block.17.layer.0.SelfAttention.v.weight": "model-00003-of-00004.safetensors",
92
- "encoder.block.17.layer.0.layer_norm.weight": "model-00003-of-00004.safetensors",
93
- "encoder.block.17.layer.1.DenseReluDense.wi_0.weight": "model-00003-of-00004.safetensors",
94
- "encoder.block.17.layer.1.DenseReluDense.wi_1.weight": "model-00003-of-00004.safetensors",
95
- "encoder.block.17.layer.1.DenseReluDense.wo.weight": "model-00003-of-00004.safetensors",
96
- "encoder.block.17.layer.1.layer_norm.weight": "model-00003-of-00004.safetensors",
97
- "encoder.block.18.layer.0.SelfAttention.k.weight": "model-00003-of-00004.safetensors",
98
- "encoder.block.18.layer.0.SelfAttention.o.weight": "model-00003-of-00004.safetensors",
99
- "encoder.block.18.layer.0.SelfAttention.q.weight": "model-00003-of-00004.safetensors",
100
- "encoder.block.18.layer.0.SelfAttention.v.weight": "model-00003-of-00004.safetensors",
101
- "encoder.block.18.layer.0.layer_norm.weight": "model-00003-of-00004.safetensors",
102
- "encoder.block.18.layer.1.DenseReluDense.wi_0.weight": "model-00003-of-00004.safetensors",
103
- "encoder.block.18.layer.1.DenseReluDense.wi_1.weight": "model-00004-of-00004.safetensors",
104
- "encoder.block.18.layer.1.DenseReluDense.wo.weight": "model-00004-of-00004.safetensors",
105
- "encoder.block.18.layer.1.layer_norm.weight": "model-00004-of-00004.safetensors",
106
- "encoder.block.19.layer.0.SelfAttention.k.weight": "model-00004-of-00004.safetensors",
107
- "encoder.block.19.layer.0.SelfAttention.o.weight": "model-00004-of-00004.safetensors",
108
- "encoder.block.19.layer.0.SelfAttention.q.weight": "model-00004-of-00004.safetensors",
109
- "encoder.block.19.layer.0.SelfAttention.v.weight": "model-00004-of-00004.safetensors",
110
- "encoder.block.19.layer.0.layer_norm.weight": "model-00004-of-00004.safetensors",
111
- "encoder.block.19.layer.1.DenseReluDense.wi_0.weight": "model-00004-of-00004.safetensors",
112
- "encoder.block.19.layer.1.DenseReluDense.wi_1.weight": "model-00004-of-00004.safetensors",
113
- "encoder.block.19.layer.1.DenseReluDense.wo.weight": "model-00004-of-00004.safetensors",
114
- "encoder.block.19.layer.1.layer_norm.weight": "model-00004-of-00004.safetensors",
115
- "encoder.block.2.layer.0.SelfAttention.k.weight": "model-00001-of-00004.safetensors",
116
- "encoder.block.2.layer.0.SelfAttention.o.weight": "model-00001-of-00004.safetensors",
117
- "encoder.block.2.layer.0.SelfAttention.q.weight": "model-00001-of-00004.safetensors",
118
- "encoder.block.2.layer.0.SelfAttention.v.weight": "model-00001-of-00004.safetensors",
119
- "encoder.block.2.layer.0.layer_norm.weight": "model-00001-of-00004.safetensors",
120
- "encoder.block.2.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00004.safetensors",
121
- "encoder.block.2.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00004.safetensors",
122
- "encoder.block.2.layer.1.DenseReluDense.wo.weight": "model-00001-of-00004.safetensors",
123
- "encoder.block.2.layer.1.layer_norm.weight": "model-00001-of-00004.safetensors",
124
- "encoder.block.20.layer.0.SelfAttention.k.weight": "model-00004-of-00004.safetensors",
125
- "encoder.block.20.layer.0.SelfAttention.o.weight": "model-00004-of-00004.safetensors",
126
- "encoder.block.20.layer.0.SelfAttention.q.weight": "model-00004-of-00004.safetensors",
127
- "encoder.block.20.layer.0.SelfAttention.v.weight": "model-00004-of-00004.safetensors",
128
- "encoder.block.20.layer.0.layer_norm.weight": "model-00004-of-00004.safetensors",
129
- "encoder.block.20.layer.1.DenseReluDense.wi_0.weight": "model-00004-of-00004.safetensors",
130
- "encoder.block.20.layer.1.DenseReluDense.wi_1.weight": "model-00004-of-00004.safetensors",
131
- "encoder.block.20.layer.1.DenseReluDense.wo.weight": "model-00004-of-00004.safetensors",
132
- "encoder.block.20.layer.1.layer_norm.weight": "model-00004-of-00004.safetensors",
133
- "encoder.block.21.layer.0.SelfAttention.k.weight": "model-00004-of-00004.safetensors",
134
- "encoder.block.21.layer.0.SelfAttention.o.weight": "model-00004-of-00004.safetensors",
135
- "encoder.block.21.layer.0.SelfAttention.q.weight": "model-00004-of-00004.safetensors",
136
- "encoder.block.21.layer.0.SelfAttention.v.weight": "model-00004-of-00004.safetensors",
137
- "encoder.block.21.layer.0.layer_norm.weight": "model-00004-of-00004.safetensors",
138
- "encoder.block.21.layer.1.DenseReluDense.wi_0.weight": "model-00004-of-00004.safetensors",
139
- "encoder.block.21.layer.1.DenseReluDense.wi_1.weight": "model-00004-of-00004.safetensors",
140
- "encoder.block.21.layer.1.DenseReluDense.wo.weight": "model-00004-of-00004.safetensors",
141
- "encoder.block.21.layer.1.layer_norm.weight": "model-00004-of-00004.safetensors",
142
- "encoder.block.22.layer.0.SelfAttention.k.weight": "model-00004-of-00004.safetensors",
143
- "encoder.block.22.layer.0.SelfAttention.o.weight": "model-00004-of-00004.safetensors",
144
- "encoder.block.22.layer.0.SelfAttention.q.weight": "model-00004-of-00004.safetensors",
145
- "encoder.block.22.layer.0.SelfAttention.v.weight": "model-00004-of-00004.safetensors",
146
- "encoder.block.22.layer.0.layer_norm.weight": "model-00004-of-00004.safetensors",
147
- "encoder.block.22.layer.1.DenseReluDense.wi_0.weight": "model-00004-of-00004.safetensors",
148
- "encoder.block.22.layer.1.DenseReluDense.wi_1.weight": "model-00004-of-00004.safetensors",
149
- "encoder.block.22.layer.1.DenseReluDense.wo.weight": "model-00004-of-00004.safetensors",
150
- "encoder.block.22.layer.1.layer_norm.weight": "model-00004-of-00004.safetensors",
151
- "encoder.block.23.layer.0.SelfAttention.k.weight": "model-00004-of-00004.safetensors",
152
- "encoder.block.23.layer.0.SelfAttention.o.weight": "model-00004-of-00004.safetensors",
153
- "encoder.block.23.layer.0.SelfAttention.q.weight": "model-00004-of-00004.safetensors",
154
- "encoder.block.23.layer.0.SelfAttention.v.weight": "model-00004-of-00004.safetensors",
155
- "encoder.block.23.layer.0.layer_norm.weight": "model-00004-of-00004.safetensors",
156
- "encoder.block.23.layer.1.DenseReluDense.wi_0.weight": "model-00004-of-00004.safetensors",
157
- "encoder.block.23.layer.1.DenseReluDense.wi_1.weight": "model-00004-of-00004.safetensors",
158
- "encoder.block.23.layer.1.DenseReluDense.wo.weight": "model-00004-of-00004.safetensors",
159
- "encoder.block.23.layer.1.layer_norm.weight": "model-00004-of-00004.safetensors",
160
- "encoder.block.3.layer.0.SelfAttention.k.weight": "model-00001-of-00004.safetensors",
161
- "encoder.block.3.layer.0.SelfAttention.o.weight": "model-00001-of-00004.safetensors",
162
- "encoder.block.3.layer.0.SelfAttention.q.weight": "model-00001-of-00004.safetensors",
163
- "encoder.block.3.layer.0.SelfAttention.v.weight": "model-00001-of-00004.safetensors",
164
- "encoder.block.3.layer.0.layer_norm.weight": "model-00001-of-00004.safetensors",
165
- "encoder.block.3.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00004.safetensors",
166
- "encoder.block.3.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00004.safetensors",
167
- "encoder.block.3.layer.1.DenseReluDense.wo.weight": "model-00001-of-00004.safetensors",
168
- "encoder.block.3.layer.1.layer_norm.weight": "model-00001-of-00004.safetensors",
169
- "encoder.block.4.layer.0.SelfAttention.k.weight": "model-00001-of-00004.safetensors",
170
- "encoder.block.4.layer.0.SelfAttention.o.weight": "model-00001-of-00004.safetensors",
171
- "encoder.block.4.layer.0.SelfAttention.q.weight": "model-00001-of-00004.safetensors",
172
- "encoder.block.4.layer.0.SelfAttention.v.weight": "model-00001-of-00004.safetensors",
173
- "encoder.block.4.layer.0.layer_norm.weight": "model-00001-of-00004.safetensors",
174
- "encoder.block.4.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00004.safetensors",
175
- "encoder.block.4.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00004.safetensors",
176
- "encoder.block.4.layer.1.DenseReluDense.wo.weight": "model-00001-of-00004.safetensors",
177
- "encoder.block.4.layer.1.layer_norm.weight": "model-00001-of-00004.safetensors",
178
- "encoder.block.5.layer.0.SelfAttention.k.weight": "model-00001-of-00004.safetensors",
179
- "encoder.block.5.layer.0.SelfAttention.o.weight": "model-00001-of-00004.safetensors",
180
- "encoder.block.5.layer.0.SelfAttention.q.weight": "model-00001-of-00004.safetensors",
181
- "encoder.block.5.layer.0.SelfAttention.v.weight": "model-00001-of-00004.safetensors",
182
- "encoder.block.5.layer.0.layer_norm.weight": "model-00001-of-00004.safetensors",
183
- "encoder.block.5.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00004.safetensors",
184
- "encoder.block.5.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00004.safetensors",
185
- "encoder.block.5.layer.1.DenseReluDense.wo.weight": "model-00002-of-00004.safetensors",
186
- "encoder.block.5.layer.1.layer_norm.weight": "model-00002-of-00004.safetensors",
187
- "encoder.block.6.layer.0.SelfAttention.k.weight": "model-00002-of-00004.safetensors",
188
- "encoder.block.6.layer.0.SelfAttention.o.weight": "model-00002-of-00004.safetensors",
189
- "encoder.block.6.layer.0.SelfAttention.q.weight": "model-00002-of-00004.safetensors",
190
- "encoder.block.6.layer.0.SelfAttention.v.weight": "model-00002-of-00004.safetensors",
191
- "encoder.block.6.layer.0.layer_norm.weight": "model-00002-of-00004.safetensors",
192
- "encoder.block.6.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00004.safetensors",
193
- "encoder.block.6.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00004.safetensors",
194
- "encoder.block.6.layer.1.DenseReluDense.wo.weight": "model-00002-of-00004.safetensors",
195
- "encoder.block.6.layer.1.layer_norm.weight": "model-00002-of-00004.safetensors",
196
- "encoder.block.7.layer.0.SelfAttention.k.weight": "model-00002-of-00004.safetensors",
197
- "encoder.block.7.layer.0.SelfAttention.o.weight": "model-00002-of-00004.safetensors",
198
- "encoder.block.7.layer.0.SelfAttention.q.weight": "model-00002-of-00004.safetensors",
199
- "encoder.block.7.layer.0.SelfAttention.v.weight": "model-00002-of-00004.safetensors",
200
- "encoder.block.7.layer.0.layer_norm.weight": "model-00002-of-00004.safetensors",
201
- "encoder.block.7.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00004.safetensors",
202
- "encoder.block.7.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00004.safetensors",
203
- "encoder.block.7.layer.1.DenseReluDense.wo.weight": "model-00002-of-00004.safetensors",
204
- "encoder.block.7.layer.1.layer_norm.weight": "model-00002-of-00004.safetensors",
205
- "encoder.block.8.layer.0.SelfAttention.k.weight": "model-00002-of-00004.safetensors",
206
- "encoder.block.8.layer.0.SelfAttention.o.weight": "model-00002-of-00004.safetensors",
207
- "encoder.block.8.layer.0.SelfAttention.q.weight": "model-00002-of-00004.safetensors",
208
- "encoder.block.8.layer.0.SelfAttention.v.weight": "model-00002-of-00004.safetensors",
209
- "encoder.block.8.layer.0.layer_norm.weight": "model-00002-of-00004.safetensors",
210
- "encoder.block.8.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00004.safetensors",
211
- "encoder.block.8.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00004.safetensors",
212
- "encoder.block.8.layer.1.DenseReluDense.wo.weight": "model-00002-of-00004.safetensors",
213
- "encoder.block.8.layer.1.layer_norm.weight": "model-00002-of-00004.safetensors",
214
- "encoder.block.9.layer.0.SelfAttention.k.weight": "model-00002-of-00004.safetensors",
215
- "encoder.block.9.layer.0.SelfAttention.o.weight": "model-00002-of-00004.safetensors",
216
- "encoder.block.9.layer.0.SelfAttention.q.weight": "model-00002-of-00004.safetensors",
217
- "encoder.block.9.layer.0.SelfAttention.v.weight": "model-00002-of-00004.safetensors",
218
- "encoder.block.9.layer.0.layer_norm.weight": "model-00002-of-00004.safetensors",
219
- "encoder.block.9.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00004.safetensors",
220
- "encoder.block.9.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00004.safetensors",
221
- "encoder.block.9.layer.1.DenseReluDense.wo.weight": "model-00002-of-00004.safetensors",
222
- "encoder.block.9.layer.1.layer_norm.weight": "model-00002-of-00004.safetensors",
223
- "encoder.final_layer_norm.weight": "model-00004-of-00004.safetensors",
224
- "shared.weight": "model-00001-of-00004.safetensors"
225
- }
226
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
transformer/config.json CHANGED
@@ -1,6 +1,6 @@
1
  {
2
  "_class_name": "AllegroTransformer3DModel",
3
- "_diffusers_version": "0.30.3",
4
  "activation_fn": "gelu-approximate",
5
  "attention_bias": true,
6
  "attention_head_dim": 96,
 
1
  {
2
  "_class_name": "AllegroTransformer3DModel",
3
+ "_diffusers_version": "0.28.0",
4
  "activation_fn": "gelu-approximate",
5
  "attention_bias": true,
6
  "attention_head_dim": 96,
transformer/diffusion_pytorch_model-00001-of-00002.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:566c682d40b99cdf07b351a4ec57a01f5469dc6344dd1eca38939314d5f635bc
3
- size 9985256872
 
 
 
 
transformer/diffusion_pytorch_model-00002-of-00002.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:3b1cffec0f067b2fb0bc0b6f228cafddd87d04e23772c8d5a4c9f40f8c9719eb
3
- size 1102452560
 
 
 
 
transformer/diffusion_pytorch_model.safetensors.index.json DELETED
@@ -1,694 +0,0 @@
1
- {
2
- "metadata": {
3
- "total_size": 11087631424
4
- },
5
- "weight_map": {
6
- "adaln_single.emb.timestep_embedder.linear_1.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
7
- "adaln_single.emb.timestep_embedder.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
8
- "adaln_single.emb.timestep_embedder.linear_2.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
9
- "adaln_single.emb.timestep_embedder.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
10
- "adaln_single.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
11
- "adaln_single.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
12
- "caption_projection.linear_1.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
13
- "caption_projection.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
14
- "caption_projection.linear_2.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
15
- "caption_projection.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
16
- "pos_embed.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
17
- "pos_embed.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
18
- "proj_out.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
19
- "proj_out.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
20
- "scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
21
- "transformer_blocks.0.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
22
- "transformer_blocks.0.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
23
- "transformer_blocks.0.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
24
- "transformer_blocks.0.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
25
- "transformer_blocks.0.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
26
- "transformer_blocks.0.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
27
- "transformer_blocks.0.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
28
- "transformer_blocks.0.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
29
- "transformer_blocks.0.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
30
- "transformer_blocks.0.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
31
- "transformer_blocks.0.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
32
- "transformer_blocks.0.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
33
- "transformer_blocks.0.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
34
- "transformer_blocks.0.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
35
- "transformer_blocks.0.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
36
- "transformer_blocks.0.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
37
- "transformer_blocks.0.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
38
- "transformer_blocks.0.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
39
- "transformer_blocks.0.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
40
- "transformer_blocks.0.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
41
- "transformer_blocks.0.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
42
- "transformer_blocks.1.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
43
- "transformer_blocks.1.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
44
- "transformer_blocks.1.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
45
- "transformer_blocks.1.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
46
- "transformer_blocks.1.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
47
- "transformer_blocks.1.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
48
- "transformer_blocks.1.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
49
- "transformer_blocks.1.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
50
- "transformer_blocks.1.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
51
- "transformer_blocks.1.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
52
- "transformer_blocks.1.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
53
- "transformer_blocks.1.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
54
- "transformer_blocks.1.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
55
- "transformer_blocks.1.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
56
- "transformer_blocks.1.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
57
- "transformer_blocks.1.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
58
- "transformer_blocks.1.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
59
- "transformer_blocks.1.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
60
- "transformer_blocks.1.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
61
- "transformer_blocks.1.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
62
- "transformer_blocks.1.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
63
- "transformer_blocks.10.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
64
- "transformer_blocks.10.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
65
- "transformer_blocks.10.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
66
- "transformer_blocks.10.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
67
- "transformer_blocks.10.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
68
- "transformer_blocks.10.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
69
- "transformer_blocks.10.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
70
- "transformer_blocks.10.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
71
- "transformer_blocks.10.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
72
- "transformer_blocks.10.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
73
- "transformer_blocks.10.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
74
- "transformer_blocks.10.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
75
- "transformer_blocks.10.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
76
- "transformer_blocks.10.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
77
- "transformer_blocks.10.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
78
- "transformer_blocks.10.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
79
- "transformer_blocks.10.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
80
- "transformer_blocks.10.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
81
- "transformer_blocks.10.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
82
- "transformer_blocks.10.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
83
- "transformer_blocks.10.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
84
- "transformer_blocks.11.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
85
- "transformer_blocks.11.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
86
- "transformer_blocks.11.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
87
- "transformer_blocks.11.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
88
- "transformer_blocks.11.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
89
- "transformer_blocks.11.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
90
- "transformer_blocks.11.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
91
- "transformer_blocks.11.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
92
- "transformer_blocks.11.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
93
- "transformer_blocks.11.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
94
- "transformer_blocks.11.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
95
- "transformer_blocks.11.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
96
- "transformer_blocks.11.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
97
- "transformer_blocks.11.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
98
- "transformer_blocks.11.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
99
- "transformer_blocks.11.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
100
- "transformer_blocks.11.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
101
- "transformer_blocks.11.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
102
- "transformer_blocks.11.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
103
- "transformer_blocks.11.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
104
- "transformer_blocks.11.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
105
- "transformer_blocks.12.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
106
- "transformer_blocks.12.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
107
- "transformer_blocks.12.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
108
- "transformer_blocks.12.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
109
- "transformer_blocks.12.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
110
- "transformer_blocks.12.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
111
- "transformer_blocks.12.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
112
- "transformer_blocks.12.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
113
- "transformer_blocks.12.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
114
- "transformer_blocks.12.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
115
- "transformer_blocks.12.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
116
- "transformer_blocks.12.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
117
- "transformer_blocks.12.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
118
- "transformer_blocks.12.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
119
- "transformer_blocks.12.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
120
- "transformer_blocks.12.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
121
- "transformer_blocks.12.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
122
- "transformer_blocks.12.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
123
- "transformer_blocks.12.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
124
- "transformer_blocks.12.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
125
- "transformer_blocks.12.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
126
- "transformer_blocks.13.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
127
- "transformer_blocks.13.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
128
- "transformer_blocks.13.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
129
- "transformer_blocks.13.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
130
- "transformer_blocks.13.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
131
- "transformer_blocks.13.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
132
- "transformer_blocks.13.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
133
- "transformer_blocks.13.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
134
- "transformer_blocks.13.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
135
- "transformer_blocks.13.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
136
- "transformer_blocks.13.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
137
- "transformer_blocks.13.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
138
- "transformer_blocks.13.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
139
- "transformer_blocks.13.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
140
- "transformer_blocks.13.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
141
- "transformer_blocks.13.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
142
- "transformer_blocks.13.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
143
- "transformer_blocks.13.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
144
- "transformer_blocks.13.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
145
- "transformer_blocks.13.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
146
- "transformer_blocks.13.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
147
- "transformer_blocks.14.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
148
- "transformer_blocks.14.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
149
- "transformer_blocks.14.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
150
- "transformer_blocks.14.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
151
- "transformer_blocks.14.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
152
- "transformer_blocks.14.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
153
- "transformer_blocks.14.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
154
- "transformer_blocks.14.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
155
- "transformer_blocks.14.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
156
- "transformer_blocks.14.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
157
- "transformer_blocks.14.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
158
- "transformer_blocks.14.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
159
- "transformer_blocks.14.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
160
- "transformer_blocks.14.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
161
- "transformer_blocks.14.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
162
- "transformer_blocks.14.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
163
- "transformer_blocks.14.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
164
- "transformer_blocks.14.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
165
- "transformer_blocks.14.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
166
- "transformer_blocks.14.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
167
- "transformer_blocks.14.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
168
- "transformer_blocks.15.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
169
- "transformer_blocks.15.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
170
- "transformer_blocks.15.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
171
- "transformer_blocks.15.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
172
- "transformer_blocks.15.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
173
- "transformer_blocks.15.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
174
- "transformer_blocks.15.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
175
- "transformer_blocks.15.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
176
- "transformer_blocks.15.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
177
- "transformer_blocks.15.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
178
- "transformer_blocks.15.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
179
- "transformer_blocks.15.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
180
- "transformer_blocks.15.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
181
- "transformer_blocks.15.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
182
- "transformer_blocks.15.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
183
- "transformer_blocks.15.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
184
- "transformer_blocks.15.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
185
- "transformer_blocks.15.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
186
- "transformer_blocks.15.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
187
- "transformer_blocks.15.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
188
- "transformer_blocks.15.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
189
- "transformer_blocks.16.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
190
- "transformer_blocks.16.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
191
- "transformer_blocks.16.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
192
- "transformer_blocks.16.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
193
- "transformer_blocks.16.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
194
- "transformer_blocks.16.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
195
- "transformer_blocks.16.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
196
- "transformer_blocks.16.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
197
- "transformer_blocks.16.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
198
- "transformer_blocks.16.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
199
- "transformer_blocks.16.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
200
- "transformer_blocks.16.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
201
- "transformer_blocks.16.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
202
- "transformer_blocks.16.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
203
- "transformer_blocks.16.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
204
- "transformer_blocks.16.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
205
- "transformer_blocks.16.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
206
- "transformer_blocks.16.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
207
- "transformer_blocks.16.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
208
- "transformer_blocks.16.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
209
- "transformer_blocks.16.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
210
- "transformer_blocks.17.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
211
- "transformer_blocks.17.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
212
- "transformer_blocks.17.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
213
- "transformer_blocks.17.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
214
- "transformer_blocks.17.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
215
- "transformer_blocks.17.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
216
- "transformer_blocks.17.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
217
- "transformer_blocks.17.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
218
- "transformer_blocks.17.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
219
- "transformer_blocks.17.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
220
- "transformer_blocks.17.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
221
- "transformer_blocks.17.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
222
- "transformer_blocks.17.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
223
- "transformer_blocks.17.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
224
- "transformer_blocks.17.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
225
- "transformer_blocks.17.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
226
- "transformer_blocks.17.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
227
- "transformer_blocks.17.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
228
- "transformer_blocks.17.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
229
- "transformer_blocks.17.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
230
- "transformer_blocks.17.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
231
- "transformer_blocks.18.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
232
- "transformer_blocks.18.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
233
- "transformer_blocks.18.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
234
- "transformer_blocks.18.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
235
- "transformer_blocks.18.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
236
- "transformer_blocks.18.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
237
- "transformer_blocks.18.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
238
- "transformer_blocks.18.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
239
- "transformer_blocks.18.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
240
- "transformer_blocks.18.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
241
- "transformer_blocks.18.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
242
- "transformer_blocks.18.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
243
- "transformer_blocks.18.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
244
- "transformer_blocks.18.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
245
- "transformer_blocks.18.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
246
- "transformer_blocks.18.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
247
- "transformer_blocks.18.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
248
- "transformer_blocks.18.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
249
- "transformer_blocks.18.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
250
- "transformer_blocks.18.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
251
- "transformer_blocks.18.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
252
- "transformer_blocks.19.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
253
- "transformer_blocks.19.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
254
- "transformer_blocks.19.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
255
- "transformer_blocks.19.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
256
- "transformer_blocks.19.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
257
- "transformer_blocks.19.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
258
- "transformer_blocks.19.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
259
- "transformer_blocks.19.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
260
- "transformer_blocks.19.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
261
- "transformer_blocks.19.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
262
- "transformer_blocks.19.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
263
- "transformer_blocks.19.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
264
- "transformer_blocks.19.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
265
- "transformer_blocks.19.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
266
- "transformer_blocks.19.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
267
- "transformer_blocks.19.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
268
- "transformer_blocks.19.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
269
- "transformer_blocks.19.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
270
- "transformer_blocks.19.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
271
- "transformer_blocks.19.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
272
- "transformer_blocks.19.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
273
- "transformer_blocks.2.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
274
- "transformer_blocks.2.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
275
- "transformer_blocks.2.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
276
- "transformer_blocks.2.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
277
- "transformer_blocks.2.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
278
- "transformer_blocks.2.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
279
- "transformer_blocks.2.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
280
- "transformer_blocks.2.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
281
- "transformer_blocks.2.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
282
- "transformer_blocks.2.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
283
- "transformer_blocks.2.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
284
- "transformer_blocks.2.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
285
- "transformer_blocks.2.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
286
- "transformer_blocks.2.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
287
- "transformer_blocks.2.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
288
- "transformer_blocks.2.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
289
- "transformer_blocks.2.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
290
- "transformer_blocks.2.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
291
- "transformer_blocks.2.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
292
- "transformer_blocks.2.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
293
- "transformer_blocks.2.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
294
- "transformer_blocks.20.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
295
- "transformer_blocks.20.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
296
- "transformer_blocks.20.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
297
- "transformer_blocks.20.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
298
- "transformer_blocks.20.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
299
- "transformer_blocks.20.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
300
- "transformer_blocks.20.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
301
- "transformer_blocks.20.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
302
- "transformer_blocks.20.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
303
- "transformer_blocks.20.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
304
- "transformer_blocks.20.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
305
- "transformer_blocks.20.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
306
- "transformer_blocks.20.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
307
- "transformer_blocks.20.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
308
- "transformer_blocks.20.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
309
- "transformer_blocks.20.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
310
- "transformer_blocks.20.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
311
- "transformer_blocks.20.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
312
- "transformer_blocks.20.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
313
- "transformer_blocks.20.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
314
- "transformer_blocks.20.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
315
- "transformer_blocks.21.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
316
- "transformer_blocks.21.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
317
- "transformer_blocks.21.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
318
- "transformer_blocks.21.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
319
- "transformer_blocks.21.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
320
- "transformer_blocks.21.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
321
- "transformer_blocks.21.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
322
- "transformer_blocks.21.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
323
- "transformer_blocks.21.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
324
- "transformer_blocks.21.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
325
- "transformer_blocks.21.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
326
- "transformer_blocks.21.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
327
- "transformer_blocks.21.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
328
- "transformer_blocks.21.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
329
- "transformer_blocks.21.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
330
- "transformer_blocks.21.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
331
- "transformer_blocks.21.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
332
- "transformer_blocks.21.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
333
- "transformer_blocks.21.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
334
- "transformer_blocks.21.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
335
- "transformer_blocks.21.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
336
- "transformer_blocks.22.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
337
- "transformer_blocks.22.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
338
- "transformer_blocks.22.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
339
- "transformer_blocks.22.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
340
- "transformer_blocks.22.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
341
- "transformer_blocks.22.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
342
- "transformer_blocks.22.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
343
- "transformer_blocks.22.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
344
- "transformer_blocks.22.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
345
- "transformer_blocks.22.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
346
- "transformer_blocks.22.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
347
- "transformer_blocks.22.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
348
- "transformer_blocks.22.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
349
- "transformer_blocks.22.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
350
- "transformer_blocks.22.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
351
- "transformer_blocks.22.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
352
- "transformer_blocks.22.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
353
- "transformer_blocks.22.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
354
- "transformer_blocks.22.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
355
- "transformer_blocks.22.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
356
- "transformer_blocks.22.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
357
- "transformer_blocks.23.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
358
- "transformer_blocks.23.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
359
- "transformer_blocks.23.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
360
- "transformer_blocks.23.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
361
- "transformer_blocks.23.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
362
- "transformer_blocks.23.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
363
- "transformer_blocks.23.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
364
- "transformer_blocks.23.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
365
- "transformer_blocks.23.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
366
- "transformer_blocks.23.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
367
- "transformer_blocks.23.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
368
- "transformer_blocks.23.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
369
- "transformer_blocks.23.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
370
- "transformer_blocks.23.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
371
- "transformer_blocks.23.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
372
- "transformer_blocks.23.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
373
- "transformer_blocks.23.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
374
- "transformer_blocks.23.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
375
- "transformer_blocks.23.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
376
- "transformer_blocks.23.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
377
- "transformer_blocks.23.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
378
- "transformer_blocks.24.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
379
- "transformer_blocks.24.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
380
- "transformer_blocks.24.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
381
- "transformer_blocks.24.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
382
- "transformer_blocks.24.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
383
- "transformer_blocks.24.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
384
- "transformer_blocks.24.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
385
- "transformer_blocks.24.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
386
- "transformer_blocks.24.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
387
- "transformer_blocks.24.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
388
- "transformer_blocks.24.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
389
- "transformer_blocks.24.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
390
- "transformer_blocks.24.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
391
- "transformer_blocks.24.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
392
- "transformer_blocks.24.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
393
- "transformer_blocks.24.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
394
- "transformer_blocks.24.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
395
- "transformer_blocks.24.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
396
- "transformer_blocks.24.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
397
- "transformer_blocks.24.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
398
- "transformer_blocks.24.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
399
- "transformer_blocks.25.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
400
- "transformer_blocks.25.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
401
- "transformer_blocks.25.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
402
- "transformer_blocks.25.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
403
- "transformer_blocks.25.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
404
- "transformer_blocks.25.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
405
- "transformer_blocks.25.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
406
- "transformer_blocks.25.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
407
- "transformer_blocks.25.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
408
- "transformer_blocks.25.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
409
- "transformer_blocks.25.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
410
- "transformer_blocks.25.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
411
- "transformer_blocks.25.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
412
- "transformer_blocks.25.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
413
- "transformer_blocks.25.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
414
- "transformer_blocks.25.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
415
- "transformer_blocks.25.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
416
- "transformer_blocks.25.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
417
- "transformer_blocks.25.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
418
- "transformer_blocks.25.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
419
- "transformer_blocks.25.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
420
- "transformer_blocks.26.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
421
- "transformer_blocks.26.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
422
- "transformer_blocks.26.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
423
- "transformer_blocks.26.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
424
- "transformer_blocks.26.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
425
- "transformer_blocks.26.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
426
- "transformer_blocks.26.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
427
- "transformer_blocks.26.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
428
- "transformer_blocks.26.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
429
- "transformer_blocks.26.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
430
- "transformer_blocks.26.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
431
- "transformer_blocks.26.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
432
- "transformer_blocks.26.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
433
- "transformer_blocks.26.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
434
- "transformer_blocks.26.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
435
- "transformer_blocks.26.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
436
- "transformer_blocks.26.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
437
- "transformer_blocks.26.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
438
- "transformer_blocks.26.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
439
- "transformer_blocks.26.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
440
- "transformer_blocks.26.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
441
- "transformer_blocks.27.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
442
- "transformer_blocks.27.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
443
- "transformer_blocks.27.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
444
- "transformer_blocks.27.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
445
- "transformer_blocks.27.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
446
- "transformer_blocks.27.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
447
- "transformer_blocks.27.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
448
- "transformer_blocks.27.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
449
- "transformer_blocks.27.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
450
- "transformer_blocks.27.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
451
- "transformer_blocks.27.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
452
- "transformer_blocks.27.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
453
- "transformer_blocks.27.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
454
- "transformer_blocks.27.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
455
- "transformer_blocks.27.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
456
- "transformer_blocks.27.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
457
- "transformer_blocks.27.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
458
- "transformer_blocks.27.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
459
- "transformer_blocks.27.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
460
- "transformer_blocks.27.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
461
- "transformer_blocks.27.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
462
- "transformer_blocks.28.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
463
- "transformer_blocks.28.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
464
- "transformer_blocks.28.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
465
- "transformer_blocks.28.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
466
- "transformer_blocks.28.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
467
- "transformer_blocks.28.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
468
- "transformer_blocks.28.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
469
- "transformer_blocks.28.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
470
- "transformer_blocks.28.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
471
- "transformer_blocks.28.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
472
- "transformer_blocks.28.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
473
- "transformer_blocks.28.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
474
- "transformer_blocks.28.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
475
- "transformer_blocks.28.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
476
- "transformer_blocks.28.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
477
- "transformer_blocks.28.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
478
- "transformer_blocks.28.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
479
- "transformer_blocks.28.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
480
- "transformer_blocks.28.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
481
- "transformer_blocks.28.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
482
- "transformer_blocks.28.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
483
- "transformer_blocks.29.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
484
- "transformer_blocks.29.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
485
- "transformer_blocks.29.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
486
- "transformer_blocks.29.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
487
- "transformer_blocks.29.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
488
- "transformer_blocks.29.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
489
- "transformer_blocks.29.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
490
- "transformer_blocks.29.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
491
- "transformer_blocks.29.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
492
- "transformer_blocks.29.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
493
- "transformer_blocks.29.attn2.to_out.0.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
494
- "transformer_blocks.29.attn2.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
495
- "transformer_blocks.29.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
496
- "transformer_blocks.29.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
497
- "transformer_blocks.29.attn2.to_v.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
498
- "transformer_blocks.29.attn2.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
499
- "transformer_blocks.29.ff.net.0.proj.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
500
- "transformer_blocks.29.ff.net.0.proj.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
501
- "transformer_blocks.29.ff.net.2.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
502
- "transformer_blocks.29.ff.net.2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
503
- "transformer_blocks.29.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
504
- "transformer_blocks.3.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
505
- "transformer_blocks.3.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
506
- "transformer_blocks.3.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
507
- "transformer_blocks.3.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
508
- "transformer_blocks.3.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
509
- "transformer_blocks.3.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
510
- "transformer_blocks.3.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
511
- "transformer_blocks.3.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
512
- "transformer_blocks.3.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
513
- "transformer_blocks.3.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
514
- "transformer_blocks.3.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
515
- "transformer_blocks.3.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
516
- "transformer_blocks.3.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
517
- "transformer_blocks.3.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
518
- "transformer_blocks.3.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
519
- "transformer_blocks.3.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
520
- "transformer_blocks.3.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
521
- "transformer_blocks.3.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
522
- "transformer_blocks.3.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
523
- "transformer_blocks.3.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
524
- "transformer_blocks.3.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
525
- "transformer_blocks.30.attn1.to_k.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
526
- "transformer_blocks.30.attn1.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
527
- "transformer_blocks.30.attn1.to_out.0.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
528
- "transformer_blocks.30.attn1.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
529
- "transformer_blocks.30.attn1.to_q.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
530
- "transformer_blocks.30.attn1.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
531
- "transformer_blocks.30.attn1.to_v.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
532
- "transformer_blocks.30.attn1.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
533
- "transformer_blocks.30.attn2.to_k.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
534
- "transformer_blocks.30.attn2.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
535
- "transformer_blocks.30.attn2.to_out.0.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
536
- "transformer_blocks.30.attn2.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
537
- "transformer_blocks.30.attn2.to_q.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
538
- "transformer_blocks.30.attn2.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
539
- "transformer_blocks.30.attn2.to_v.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
540
- "transformer_blocks.30.attn2.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
541
- "transformer_blocks.30.ff.net.0.proj.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
542
- "transformer_blocks.30.ff.net.0.proj.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
543
- "transformer_blocks.30.ff.net.2.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
544
- "transformer_blocks.30.ff.net.2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
545
- "transformer_blocks.30.scale_shift_table": "diffusion_pytorch_model-00002-of-00002.safetensors",
546
- "transformer_blocks.31.attn1.to_k.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
547
- "transformer_blocks.31.attn1.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
548
- "transformer_blocks.31.attn1.to_out.0.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
549
- "transformer_blocks.31.attn1.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
550
- "transformer_blocks.31.attn1.to_q.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
551
- "transformer_blocks.31.attn1.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
552
- "transformer_blocks.31.attn1.to_v.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
553
- "transformer_blocks.31.attn1.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
554
- "transformer_blocks.31.attn2.to_k.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
555
- "transformer_blocks.31.attn2.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
556
- "transformer_blocks.31.attn2.to_out.0.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
557
- "transformer_blocks.31.attn2.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
558
- "transformer_blocks.31.attn2.to_q.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
559
- "transformer_blocks.31.attn2.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
560
- "transformer_blocks.31.attn2.to_v.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
561
- "transformer_blocks.31.attn2.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
562
- "transformer_blocks.31.ff.net.0.proj.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
563
- "transformer_blocks.31.ff.net.0.proj.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
564
- "transformer_blocks.31.ff.net.2.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
565
- "transformer_blocks.31.ff.net.2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
566
- "transformer_blocks.31.scale_shift_table": "diffusion_pytorch_model-00002-of-00002.safetensors",
567
- "transformer_blocks.4.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
568
- "transformer_blocks.4.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
569
- "transformer_blocks.4.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
570
- "transformer_blocks.4.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
571
- "transformer_blocks.4.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
572
- "transformer_blocks.4.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
573
- "transformer_blocks.4.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
574
- "transformer_blocks.4.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
575
- "transformer_blocks.4.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
576
- "transformer_blocks.4.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
577
- "transformer_blocks.4.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
578
- "transformer_blocks.4.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
579
- "transformer_blocks.4.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
580
- "transformer_blocks.4.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
581
- "transformer_blocks.4.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
582
- "transformer_blocks.4.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
583
- "transformer_blocks.4.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
584
- "transformer_blocks.4.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
585
- "transformer_blocks.4.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
586
- "transformer_blocks.4.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
587
- "transformer_blocks.4.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
588
- "transformer_blocks.5.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
589
- "transformer_blocks.5.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
590
- "transformer_blocks.5.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
591
- "transformer_blocks.5.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
592
- "transformer_blocks.5.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
593
- "transformer_blocks.5.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
594
- "transformer_blocks.5.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
595
- "transformer_blocks.5.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
596
- "transformer_blocks.5.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
597
- "transformer_blocks.5.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
598
- "transformer_blocks.5.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
599
- "transformer_blocks.5.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
600
- "transformer_blocks.5.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
601
- "transformer_blocks.5.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
602
- "transformer_blocks.5.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
603
- "transformer_blocks.5.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
604
- "transformer_blocks.5.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
605
- "transformer_blocks.5.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
606
- "transformer_blocks.5.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
607
- "transformer_blocks.5.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
608
- "transformer_blocks.5.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
609
- "transformer_blocks.6.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
610
- "transformer_blocks.6.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
611
- "transformer_blocks.6.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
612
- "transformer_blocks.6.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
613
- "transformer_blocks.6.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
614
- "transformer_blocks.6.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
615
- "transformer_blocks.6.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
616
- "transformer_blocks.6.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
617
- "transformer_blocks.6.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
618
- "transformer_blocks.6.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
619
- "transformer_blocks.6.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
620
- "transformer_blocks.6.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
621
- "transformer_blocks.6.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
622
- "transformer_blocks.6.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
623
- "transformer_blocks.6.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
624
- "transformer_blocks.6.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
625
- "transformer_blocks.6.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
626
- "transformer_blocks.6.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
627
- "transformer_blocks.6.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
628
- "transformer_blocks.6.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
629
- "transformer_blocks.6.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
630
- "transformer_blocks.7.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
631
- "transformer_blocks.7.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
632
- "transformer_blocks.7.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
633
- "transformer_blocks.7.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
634
- "transformer_blocks.7.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
635
- "transformer_blocks.7.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
636
- "transformer_blocks.7.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
637
- "transformer_blocks.7.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
638
- "transformer_blocks.7.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
639
- "transformer_blocks.7.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
640
- "transformer_blocks.7.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
641
- "transformer_blocks.7.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
642
- "transformer_blocks.7.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
643
- "transformer_blocks.7.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
644
- "transformer_blocks.7.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
645
- "transformer_blocks.7.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
646
- "transformer_blocks.7.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
647
- "transformer_blocks.7.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
648
- "transformer_blocks.7.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
649
- "transformer_blocks.7.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
650
- "transformer_blocks.7.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
651
- "transformer_blocks.8.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
652
- "transformer_blocks.8.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
653
- "transformer_blocks.8.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
654
- "transformer_blocks.8.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
655
- "transformer_blocks.8.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
656
- "transformer_blocks.8.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
657
- "transformer_blocks.8.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
658
- "transformer_blocks.8.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
659
- "transformer_blocks.8.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
660
- "transformer_blocks.8.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
661
- "transformer_blocks.8.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
662
- "transformer_blocks.8.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
663
- "transformer_blocks.8.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
664
- "transformer_blocks.8.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
665
- "transformer_blocks.8.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
666
- "transformer_blocks.8.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
667
- "transformer_blocks.8.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
668
- "transformer_blocks.8.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
669
- "transformer_blocks.8.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
670
- "transformer_blocks.8.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
671
- "transformer_blocks.8.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
672
- "transformer_blocks.9.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
673
- "transformer_blocks.9.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
674
- "transformer_blocks.9.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
675
- "transformer_blocks.9.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
676
- "transformer_blocks.9.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
677
- "transformer_blocks.9.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
678
- "transformer_blocks.9.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
679
- "transformer_blocks.9.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
680
- "transformer_blocks.9.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
681
- "transformer_blocks.9.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
682
- "transformer_blocks.9.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
683
- "transformer_blocks.9.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
684
- "transformer_blocks.9.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
685
- "transformer_blocks.9.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
686
- "transformer_blocks.9.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
687
- "transformer_blocks.9.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
688
- "transformer_blocks.9.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
689
- "transformer_blocks.9.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
690
- "transformer_blocks.9.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
691
- "transformer_blocks.9.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
692
- "transformer_blocks.9.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors"
693
- }
694
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
transformer/transformer_3d_allegro.py DELETED
@@ -1,1776 +0,0 @@
1
- # Adapted from Open-Sora-Plan
2
-
3
- # This source code is licensed under the license found in the
4
- # LICENSE file in the root directory of this source tree.
5
- # --------------------------------------------------------
6
- # References:
7
- # Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
8
- # --------------------------------------------------------
9
-
10
-
11
- import json
12
- import os
13
- from dataclasses import dataclass
14
- from functools import partial
15
- from importlib import import_module
16
- from typing import Any, Callable, Dict, Optional, Tuple
17
-
18
- import numpy as np
19
- import torch
20
- import collections
21
- import torch.nn.functional as F
22
- from torch.nn.attention import SDPBackend, sdpa_kernel
23
- from diffusers.configuration_utils import ConfigMixin, register_to_config
24
- from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
25
- from diffusers.models.attention_processor import (
26
- AttnAddedKVProcessor,
27
- AttnAddedKVProcessor2_0,
28
- AttnProcessor,
29
- CustomDiffusionAttnProcessor,
30
- CustomDiffusionAttnProcessor2_0,
31
- CustomDiffusionXFormersAttnProcessor,
32
- LoRAAttnAddedKVProcessor,
33
- LoRAAttnProcessor,
34
- LoRAAttnProcessor2_0,
35
- LoRAXFormersAttnProcessor,
36
- SlicedAttnAddedKVProcessor,
37
- SlicedAttnProcessor,
38
- SpatialNorm,
39
- XFormersAttnAddedKVProcessor,
40
- XFormersAttnProcessor,
41
- )
42
- from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
43
- from diffusers.models.modeling_utils import ModelMixin
44
- from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
45
- from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_xformers_available
46
- from diffusers.utils.torch_utils import maybe_allow_in_graph
47
- from einops import rearrange, repeat
48
- from torch import nn
49
- from diffusers.models.embeddings import PixArtAlphaTextProjection
50
-
51
-
52
- if is_xformers_available():
53
- import xformers
54
- import xformers.ops
55
- else:
56
- xformers = None
57
-
58
- from diffusers.utils import logging
59
-
60
- logger = logging.get_logger(__name__)
61
-
62
-
63
- def to_2tuple(x):
64
- if isinstance(x, collections.abc.Iterable):
65
- return x
66
- return (x, x)
67
-
68
- class CombinedTimestepSizeEmbeddings(nn.Module):
69
- """
70
- For PixArt-Alpha.
71
-
72
- Reference:
73
- https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
74
- """
75
-
76
- def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
77
- super().__init__()
78
-
79
- self.outdim = size_emb_dim
80
- self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
81
- self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
82
-
83
- self.use_additional_conditions = use_additional_conditions
84
- if use_additional_conditions:
85
- self.use_additional_conditions = True
86
- self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
87
- self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
88
- self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
89
-
90
- def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
91
- if size.ndim == 1:
92
- size = size[:, None]
93
-
94
- if size.shape[0] != batch_size:
95
- size = size.repeat(batch_size // size.shape[0], 1)
96
- if size.shape[0] != batch_size:
97
- raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
98
-
99
- current_batch_size, dims = size.shape[0], size.shape[1]
100
- size = size.reshape(-1)
101
- size_freq = self.additional_condition_proj(size).to(size.dtype)
102
-
103
- size_emb = embedder(size_freq)
104
- size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
105
- return size_emb
106
-
107
- def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
108
- timesteps_proj = self.time_proj(timestep)
109
- timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
110
-
111
- if self.use_additional_conditions:
112
- resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
113
- aspect_ratio = self.apply_condition(
114
- aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
115
- )
116
- conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
117
- else:
118
- conditioning = timesteps_emb
119
-
120
- return conditioning
121
-
122
-
123
- class PositionGetter3D(object):
124
- """ return positions of patches """
125
-
126
- def __init__(self, ):
127
- self.cache_positions = {}
128
-
129
- def __call__(self, b, t, h, w, device):
130
- if not (b, t,h,w) in self.cache_positions:
131
- x = torch.arange(w, device=device)
132
- y = torch.arange(h, device=device)
133
- z = torch.arange(t, device=device)
134
- pos = torch.cartesian_prod(z, y, x)
135
-
136
- pos = pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, 1, -1).contiguous().expand(3, b, -1).clone()
137
- poses = (pos[0].contiguous(), pos[1].contiguous(), pos[2].contiguous())
138
- max_poses = (int(poses[0].max()), int(poses[1].max()), int(poses[2].max()))
139
-
140
- self.cache_positions[b, t, h, w] = (poses, max_poses)
141
- pos = self.cache_positions[b, t, h, w]
142
-
143
- return pos
144
-
145
-
146
- class RoPE3D(torch.nn.Module):
147
-
148
- def __init__(self, freq=10000.0, F0=1.0, interpolation_scale_thw=(1, 1, 1)):
149
- super().__init__()
150
- self.base = freq
151
- self.F0 = F0
152
- self.interpolation_scale_t = interpolation_scale_thw[0]
153
- self.interpolation_scale_h = interpolation_scale_thw[1]
154
- self.interpolation_scale_w = interpolation_scale_thw[2]
155
- self.cache = {}
156
-
157
- def get_cos_sin(self, D, seq_len, device, dtype, interpolation_scale=1):
158
- if (D, seq_len, device, dtype) not in self.cache:
159
- inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
160
- t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) / interpolation_scale
161
- freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
162
- freqs = torch.cat((freqs, freqs), dim=-1)
163
- cos = freqs.cos() # (Seq, Dim)
164
- sin = freqs.sin()
165
- self.cache[D, seq_len, device, dtype] = (cos, sin)
166
- return self.cache[D, seq_len, device, dtype]
167
-
168
- @staticmethod
169
- def rotate_half(x):
170
- x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
171
- return torch.cat((-x2, x1), dim=-1)
172
-
173
- def apply_rope1d(self, tokens, pos1d, cos, sin):
174
- assert pos1d.ndim == 2
175
-
176
- # for (batch_size x ntokens x nheads x dim)
177
- cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
178
- sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
179
- return (tokens * cos) + (self.rotate_half(tokens) * sin)
180
-
181
- def forward(self, tokens, positions):
182
- """
183
- input:
184
- * tokens: batch_size x nheads x ntokens x dim
185
- * positions: batch_size x ntokens x 3 (t, y and x position of each token)
186
- output:
187
- * tokens after appplying RoPE3D (batch_size x nheads x ntokens x x dim)
188
- """
189
- assert tokens.size(3) % 3 == 0, "number of dimensions should be a multiple of three"
190
- D = tokens.size(3) // 3
191
- poses, max_poses = positions
192
- assert len(poses) == 3 and poses[0].ndim == 2# Batch, Seq, 3
193
- cos_t, sin_t = self.get_cos_sin(D, max_poses[0] + 1, tokens.device, tokens.dtype, self.interpolation_scale_t)
194
- cos_y, sin_y = self.get_cos_sin(D, max_poses[1] + 1, tokens.device, tokens.dtype, self.interpolation_scale_h)
195
- cos_x, sin_x = self.get_cos_sin(D, max_poses[2] + 1, tokens.device, tokens.dtype, self.interpolation_scale_w)
196
- # split features into three along the feature dimension, and apply rope1d on each half
197
- t, y, x = tokens.chunk(3, dim=-1)
198
- t = self.apply_rope1d(t, poses[0], cos_t, sin_t)
199
- y = self.apply_rope1d(y, poses[1], cos_y, sin_y)
200
- x = self.apply_rope1d(x, poses[2], cos_x, sin_x)
201
- tokens = torch.cat((t, y, x), dim=-1)
202
- return tokens
203
-
204
- class PatchEmbed2D(nn.Module):
205
- """2D Image to Patch Embedding"""
206
-
207
- def __init__(
208
- self,
209
- num_frames=1,
210
- height=224,
211
- width=224,
212
- patch_size_t=1,
213
- patch_size=16,
214
- in_channels=3,
215
- embed_dim=768,
216
- layer_norm=False,
217
- flatten=True,
218
- bias=True,
219
- interpolation_scale=(1, 1),
220
- interpolation_scale_t=1,
221
- use_abs_pos=False,
222
- ):
223
- super().__init__()
224
- self.use_abs_pos = use_abs_pos
225
- self.flatten = flatten
226
- self.layer_norm = layer_norm
227
-
228
- self.proj = nn.Conv2d(
229
- in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias
230
- )
231
- if layer_norm:
232
- self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
233
- else:
234
- self.norm = None
235
-
236
- self.patch_size_t = patch_size_t
237
- self.patch_size = patch_size
238
-
239
- def forward(self, latent):
240
- b, _, _, _, _ = latent.shape
241
- video_latent = None
242
-
243
- latent = rearrange(latent, 'b c t h w -> (b t) c h w')
244
-
245
- latent = self.proj(latent)
246
- if self.flatten:
247
- latent = latent.flatten(2).transpose(1, 2) # BT C H W -> BT N C
248
- if self.layer_norm:
249
- latent = self.norm(latent)
250
-
251
- latent = rearrange(latent, '(b t) n c -> b (t n) c', b=b)
252
- video_latent = latent
253
-
254
- return video_latent
255
-
256
-
257
- @maybe_allow_in_graph
258
- class Attention(nn.Module):
259
- r"""
260
- A cross attention layer.
261
-
262
- Parameters:
263
- query_dim (`int`):
264
- The number of channels in the query.
265
- cross_attention_dim (`int`, *optional*):
266
- The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
267
- heads (`int`, *optional*, defaults to 8):
268
- The number of heads to use for multi-head attention.
269
- dim_head (`int`, *optional*, defaults to 64):
270
- The number of channels in each head.
271
- dropout (`float`, *optional*, defaults to 0.0):
272
- The dropout probability to use.
273
- bias (`bool`, *optional*, defaults to False):
274
- Set to `True` for the query, key, and value linear layers to contain a bias parameter.
275
- upcast_attention (`bool`, *optional*, defaults to False):
276
- Set to `True` to upcast the attention computation to `float32`.
277
- upcast_softmax (`bool`, *optional*, defaults to False):
278
- Set to `True` to upcast the softmax computation to `float32`.
279
- cross_attention_norm (`str`, *optional*, defaults to `None`):
280
- The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
281
- cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
282
- The number of groups to use for the group norm in the cross attention.
283
- added_kv_proj_dim (`int`, *optional*, defaults to `None`):
284
- The number of channels to use for the added key and value projections. If `None`, no projection is used.
285
- norm_num_groups (`int`, *optional*, defaults to `None`):
286
- The number of groups to use for the group norm in the attention.
287
- spatial_norm_dim (`int`, *optional*, defaults to `None`):
288
- The number of channels to use for the spatial normalization.
289
- out_bias (`bool`, *optional*, defaults to `True`):
290
- Set to `True` to use a bias in the output linear layer.
291
- scale_qk (`bool`, *optional*, defaults to `True`):
292
- Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
293
- only_cross_attention (`bool`, *optional*, defaults to `False`):
294
- Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
295
- `added_kv_proj_dim` is not `None`.
296
- eps (`float`, *optional*, defaults to 1e-5):
297
- An additional value added to the denominator in group normalization that is used for numerical stability.
298
- rescale_output_factor (`float`, *optional*, defaults to 1.0):
299
- A factor to rescale the output by dividing it with this value.
300
- residual_connection (`bool`, *optional*, defaults to `False`):
301
- Set to `True` to add the residual connection to the output.
302
- _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
303
- Set to `True` if the attention block is loaded from a deprecated state dict.
304
- processor (`AttnProcessor`, *optional*, defaults to `None`):
305
- The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
306
- `AttnProcessor` otherwise.
307
- """
308
-
309
- def __init__(
310
- self,
311
- query_dim: int,
312
- cross_attention_dim: Optional[int] = None,
313
- heads: int = 8,
314
- dim_head: int = 64,
315
- dropout: float = 0.0,
316
- bias: bool = False,
317
- upcast_attention: bool = False,
318
- upcast_softmax: bool = False,
319
- cross_attention_norm: Optional[str] = None,
320
- cross_attention_norm_num_groups: int = 32,
321
- added_kv_proj_dim: Optional[int] = None,
322
- norm_num_groups: Optional[int] = None,
323
- spatial_norm_dim: Optional[int] = None,
324
- out_bias: bool = True,
325
- scale_qk: bool = True,
326
- only_cross_attention: bool = False,
327
- eps: float = 1e-5,
328
- rescale_output_factor: float = 1.0,
329
- residual_connection: bool = False,
330
- _from_deprecated_attn_block: bool = False,
331
- processor: Optional["AttnProcessor"] = None,
332
- attention_mode: str = "xformers",
333
- use_rope: bool = False,
334
- interpolation_scale_thw=None,
335
- ):
336
- super().__init__()
337
- self.inner_dim = dim_head * heads
338
- self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
339
- self.upcast_attention = upcast_attention
340
- self.upcast_softmax = upcast_softmax
341
- self.rescale_output_factor = rescale_output_factor
342
- self.residual_connection = residual_connection
343
- self.dropout = dropout
344
- self.use_rope = use_rope
345
-
346
- # we make use of this private variable to know whether this class is loaded
347
- # with an deprecated state dict so that we can convert it on the fly
348
- self._from_deprecated_attn_block = _from_deprecated_attn_block
349
-
350
- self.scale_qk = scale_qk
351
- self.scale = dim_head**-0.5 if self.scale_qk else 1.0
352
-
353
- self.heads = heads
354
- # for slice_size > 0 the attention score computation
355
- # is split across the batch axis to save memory
356
- # You can set slice_size with `set_attention_slice`
357
- self.sliceable_head_dim = heads
358
-
359
- self.added_kv_proj_dim = added_kv_proj_dim
360
- self.only_cross_attention = only_cross_attention
361
-
362
- if self.added_kv_proj_dim is None and self.only_cross_attention:
363
- raise ValueError(
364
- "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
365
- )
366
-
367
- if norm_num_groups is not None:
368
- self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
369
- else:
370
- self.group_norm = None
371
-
372
- if spatial_norm_dim is not None:
373
- self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
374
- else:
375
- self.spatial_norm = None
376
-
377
- if cross_attention_norm is None:
378
- self.norm_cross = None
379
- elif cross_attention_norm == "layer_norm":
380
- self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
381
- elif cross_attention_norm == "group_norm":
382
- if self.added_kv_proj_dim is not None:
383
- # The given `encoder_hidden_states` are initially of shape
384
- # (batch_size, seq_len, added_kv_proj_dim) before being projected
385
- # to (batch_size, seq_len, cross_attention_dim). The norm is applied
386
- # before the projection, so we need to use `added_kv_proj_dim` as
387
- # the number of channels for the group norm.
388
- norm_cross_num_channels = added_kv_proj_dim
389
- else:
390
- norm_cross_num_channels = self.cross_attention_dim
391
-
392
- self.norm_cross = nn.GroupNorm(
393
- num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
394
- )
395
- else:
396
- raise ValueError(
397
- f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
398
- )
399
-
400
- linear_cls = nn.Linear
401
-
402
-
403
- self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
404
-
405
- if not self.only_cross_attention:
406
- # only relevant for the `AddedKVProcessor` classes
407
- self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
408
- self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
409
- else:
410
- self.to_k = None
411
- self.to_v = None
412
-
413
- if self.added_kv_proj_dim is not None:
414
- self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
415
- self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
416
-
417
- self.to_out = nn.ModuleList([])
418
- self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
419
- self.to_out.append(nn.Dropout(dropout))
420
-
421
- # set attention processor
422
- # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
423
- # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
424
- # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
425
- if processor is None:
426
- processor = (
427
- AttnProcessor2_0(
428
- attention_mode,
429
- use_rope,
430
- interpolation_scale_thw=interpolation_scale_thw,
431
- )
432
- if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
433
- else AttnProcessor()
434
- )
435
- self.set_processor(processor)
436
-
437
- def set_use_memory_efficient_attention_xformers(
438
- self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
439
- ) -> None:
440
- r"""
441
- Set whether to use memory efficient attention from `xformers` or not.
442
-
443
- Args:
444
- use_memory_efficient_attention_xformers (`bool`):
445
- Whether to use memory efficient attention from `xformers` or not.
446
- attention_op (`Callable`, *optional*):
447
- The attention operation to use. Defaults to `None` which uses the default attention operation from
448
- `xformers`.
449
- """
450
- is_lora = hasattr(self, "processor")
451
- is_custom_diffusion = hasattr(self, "processor") and isinstance(
452
- self.processor,
453
- (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
454
- )
455
- is_added_kv_processor = hasattr(self, "processor") and isinstance(
456
- self.processor,
457
- (
458
- AttnAddedKVProcessor,
459
- AttnAddedKVProcessor2_0,
460
- SlicedAttnAddedKVProcessor,
461
- XFormersAttnAddedKVProcessor,
462
- LoRAAttnAddedKVProcessor,
463
- ),
464
- )
465
-
466
- if use_memory_efficient_attention_xformers:
467
- if is_added_kv_processor and (is_lora or is_custom_diffusion):
468
- raise NotImplementedError(
469
- f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
470
- )
471
- if not is_xformers_available():
472
- raise ModuleNotFoundError(
473
- (
474
- "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
475
- " xformers"
476
- ),
477
- name="xformers",
478
- )
479
- elif not torch.cuda.is_available():
480
- raise ValueError(
481
- "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
482
- " only available for GPU "
483
- )
484
- else:
485
- try:
486
- # Make sure we can run the memory efficient attention
487
- _ = xformers.ops.memory_efficient_attention(
488
- torch.randn((1, 2, 40), device="cuda"),
489
- torch.randn((1, 2, 40), device="cuda"),
490
- torch.randn((1, 2, 40), device="cuda"),
491
- )
492
- except Exception as e:
493
- raise e
494
-
495
- if is_lora:
496
- # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
497
- # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
498
- processor = LoRAXFormersAttnProcessor(
499
- hidden_size=self.processor.hidden_size,
500
- cross_attention_dim=self.processor.cross_attention_dim,
501
- rank=self.processor.rank,
502
- attention_op=attention_op,
503
- )
504
- processor.load_state_dict(self.processor.state_dict())
505
- processor.to(self.processor.to_q_lora.up.weight.device)
506
- elif is_custom_diffusion:
507
- processor = CustomDiffusionXFormersAttnProcessor(
508
- train_kv=self.processor.train_kv,
509
- train_q_out=self.processor.train_q_out,
510
- hidden_size=self.processor.hidden_size,
511
- cross_attention_dim=self.processor.cross_attention_dim,
512
- attention_op=attention_op,
513
- )
514
- processor.load_state_dict(self.processor.state_dict())
515
- if hasattr(self.processor, "to_k_custom_diffusion"):
516
- processor.to(self.processor.to_k_custom_diffusion.weight.device)
517
- elif is_added_kv_processor:
518
- # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
519
- # which uses this type of cross attention ONLY because the attention mask of format
520
- # [0, ..., -10.000, ..., 0, ...,] is not supported
521
- # throw warning
522
- logger.info(
523
- "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
524
- )
525
- processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
526
- else:
527
- processor = XFormersAttnProcessor(attention_op=attention_op)
528
- else:
529
- if is_lora:
530
- attn_processor_class = (
531
- LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
532
- )
533
- processor = attn_processor_class(
534
- hidden_size=self.processor.hidden_size,
535
- cross_attention_dim=self.processor.cross_attention_dim,
536
- rank=self.processor.rank,
537
- )
538
- processor.load_state_dict(self.processor.state_dict())
539
- processor.to(self.processor.to_q_lora.up.weight.device)
540
- elif is_custom_diffusion:
541
- attn_processor_class = (
542
- CustomDiffusionAttnProcessor2_0
543
- if hasattr(F, "scaled_dot_product_attention")
544
- else CustomDiffusionAttnProcessor
545
- )
546
- processor = attn_processor_class(
547
- train_kv=self.processor.train_kv,
548
- train_q_out=self.processor.train_q_out,
549
- hidden_size=self.processor.hidden_size,
550
- cross_attention_dim=self.processor.cross_attention_dim,
551
- )
552
- processor.load_state_dict(self.processor.state_dict())
553
- if hasattr(self.processor, "to_k_custom_diffusion"):
554
- processor.to(self.processor.to_k_custom_diffusion.weight.device)
555
- else:
556
- # set attention processor
557
- # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
558
- # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
559
- # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
560
- processor = (
561
- AttnProcessor2_0()
562
- if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
563
- else AttnProcessor()
564
- )
565
-
566
- self.set_processor(processor)
567
-
568
- def set_attention_slice(self, slice_size: int) -> None:
569
- r"""
570
- Set the slice size for attention computation.
571
-
572
- Args:
573
- slice_size (`int`):
574
- The slice size for attention computation.
575
- """
576
- if slice_size is not None and slice_size > self.sliceable_head_dim:
577
- raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
578
-
579
- if slice_size is not None and self.added_kv_proj_dim is not None:
580
- processor = SlicedAttnAddedKVProcessor(slice_size)
581
- elif slice_size is not None:
582
- processor = SlicedAttnProcessor(slice_size)
583
- elif self.added_kv_proj_dim is not None:
584
- processor = AttnAddedKVProcessor()
585
- else:
586
- # set attention processor
587
- # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
588
- # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
589
- # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
590
- processor = (
591
- AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
592
- )
593
-
594
- self.set_processor(processor)
595
-
596
- def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None:
597
- r"""
598
- Set the attention processor to use.
599
-
600
- Args:
601
- processor (`AttnProcessor`):
602
- The attention processor to use.
603
- _remove_lora (`bool`, *optional*, defaults to `False`):
604
- Set to `True` to remove LoRA layers from the model.
605
- """
606
- if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
607
- deprecate(
608
- "set_processor to offload LoRA",
609
- "0.26.0",
610
- "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
611
- )
612
- # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
613
- # We need to remove all LoRA layers
614
- # Don't forget to remove ALL `_remove_lora` from the codebase
615
- for module in self.modules():
616
- if hasattr(module, "set_lora_layer"):
617
- module.set_lora_layer(None)
618
-
619
- # if current processor is in `self._modules` and if passed `processor` is not, we need to
620
- # pop `processor` from `self._modules`
621
- if (
622
- hasattr(self, "processor")
623
- and isinstance(self.processor, torch.nn.Module)
624
- and not isinstance(processor, torch.nn.Module)
625
- ):
626
- logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
627
- self._modules.pop("processor")
628
-
629
- self.processor = processor
630
-
631
- def get_processor(self, return_deprecated_lora: bool = False):
632
- r"""
633
- Get the attention processor in use.
634
-
635
- Args:
636
- return_deprecated_lora (`bool`, *optional*, defaults to `False`):
637
- Set to `True` to return the deprecated LoRA attention processor.
638
-
639
- Returns:
640
- "AttentionProcessor": The attention processor in use.
641
- """
642
- if not return_deprecated_lora:
643
- return self.processor
644
-
645
- # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
646
- # serialization format for LoRA Attention Processors. It should be deleted once the integration
647
- # with PEFT is completed.
648
- is_lora_activated = {
649
- name: module.lora_layer is not None
650
- for name, module in self.named_modules()
651
- if hasattr(module, "lora_layer")
652
- }
653
-
654
- # 1. if no layer has a LoRA activated we can return the processor as usual
655
- if not any(is_lora_activated.values()):
656
- return self.processor
657
-
658
- # If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
659
- is_lora_activated.pop("add_k_proj", None)
660
- is_lora_activated.pop("add_v_proj", None)
661
- # 2. else it is not posssible that only some layers have LoRA activated
662
- if not all(is_lora_activated.values()):
663
- raise ValueError(
664
- f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
665
- )
666
-
667
- # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
668
- non_lora_processor_cls_name = self.processor.__class__.__name__
669
- lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
670
-
671
- hidden_size = self.inner_dim
672
-
673
- # now create a LoRA attention processor from the LoRA layers
674
- if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
675
- kwargs = {
676
- "cross_attention_dim": self.cross_attention_dim,
677
- "rank": self.to_q.lora_layer.rank,
678
- "network_alpha": self.to_q.lora_layer.network_alpha,
679
- "q_rank": self.to_q.lora_layer.rank,
680
- "q_hidden_size": self.to_q.lora_layer.out_features,
681
- "k_rank": self.to_k.lora_layer.rank,
682
- "k_hidden_size": self.to_k.lora_layer.out_features,
683
- "v_rank": self.to_v.lora_layer.rank,
684
- "v_hidden_size": self.to_v.lora_layer.out_features,
685
- "out_rank": self.to_out[0].lora_layer.rank,
686
- "out_hidden_size": self.to_out[0].lora_layer.out_features,
687
- }
688
-
689
- if hasattr(self.processor, "attention_op"):
690
- kwargs["attention_op"] = self.processor.attention_op
691
-
692
- lora_processor = lora_processor_cls(hidden_size, **kwargs)
693
- lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
694
- lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
695
- lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
696
- lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
697
- elif lora_processor_cls == LoRAAttnAddedKVProcessor:
698
- lora_processor = lora_processor_cls(
699
- hidden_size,
700
- cross_attention_dim=self.add_k_proj.weight.shape[0],
701
- rank=self.to_q.lora_layer.rank,
702
- network_alpha=self.to_q.lora_layer.network_alpha,
703
- )
704
- lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
705
- lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
706
- lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
707
- lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
708
-
709
- # only save if used
710
- if self.add_k_proj.lora_layer is not None:
711
- lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
712
- lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
713
- else:
714
- lora_processor.add_k_proj_lora = None
715
- lora_processor.add_v_proj_lora = None
716
- else:
717
- raise ValueError(f"{lora_processor_cls} does not exist.")
718
-
719
- return lora_processor
720
-
721
- def forward(
722
- self,
723
- hidden_states: torch.FloatTensor,
724
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
725
- attention_mask: Optional[torch.FloatTensor] = None,
726
- **cross_attention_kwargs,
727
- ) -> torch.Tensor:
728
- r"""
729
- The forward method of the `Attention` class.
730
-
731
- Args:
732
- hidden_states (`torch.Tensor`):
733
- The hidden states of the query.
734
- encoder_hidden_states (`torch.Tensor`, *optional*):
735
- The hidden states of the encoder.
736
- attention_mask (`torch.Tensor`, *optional*):
737
- The attention mask to use. If `None`, no mask is applied.
738
- **cross_attention_kwargs:
739
- Additional keyword arguments to pass along to the cross attention.
740
-
741
- Returns:
742
- `torch.Tensor`: The output of the attention layer.
743
- """
744
- # The `Attention` class can call different attention processors / attention functions
745
- # here we simply pass along all tensors to the selected processor class
746
- # For standard processors that are defined here, `**cross_attention_kwargs` is empty
747
- return self.processor(
748
- self,
749
- hidden_states,
750
- encoder_hidden_states=encoder_hidden_states,
751
- attention_mask=attention_mask,
752
- **cross_attention_kwargs,
753
- )
754
-
755
- def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
756
- r"""
757
- Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
758
- is the number of heads initialized while constructing the `Attention` class.
759
-
760
- Args:
761
- tensor (`torch.Tensor`): The tensor to reshape.
762
-
763
- Returns:
764
- `torch.Tensor`: The reshaped tensor.
765
- """
766
- head_size = self.heads
767
- batch_size, seq_len, dim = tensor.shape
768
- tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
769
- tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
770
- return tensor
771
-
772
- def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
773
- r"""
774
- Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
775
- the number of heads initialized while constructing the `Attention` class.
776
-
777
- Args:
778
- tensor (`torch.Tensor`): The tensor to reshape.
779
- out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
780
- reshaped to `[batch_size * heads, seq_len, dim // heads]`.
781
-
782
- Returns:
783
- `torch.Tensor`: The reshaped tensor.
784
- """
785
- head_size = self.heads
786
- batch_size, seq_len, dim = tensor.shape
787
- tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
788
- tensor = tensor.permute(0, 2, 1, 3)
789
-
790
- if out_dim == 3:
791
- tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
792
-
793
- return tensor
794
-
795
- def get_attention_scores(
796
- self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
797
- ) -> torch.Tensor:
798
- r"""
799
- Compute the attention scores.
800
-
801
- Args:
802
- query (`torch.Tensor`): The query tensor.
803
- key (`torch.Tensor`): The key tensor.
804
- attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
805
-
806
- Returns:
807
- `torch.Tensor`: The attention probabilities/scores.
808
- """
809
- dtype = query.dtype
810
- if self.upcast_attention:
811
- query = query.float()
812
- key = key.float()
813
-
814
- if attention_mask is None:
815
- baddbmm_input = torch.empty(
816
- query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
817
- )
818
- beta = 0
819
- else:
820
- baddbmm_input = attention_mask
821
- beta = 1
822
-
823
- attention_scores = torch.baddbmm(
824
- baddbmm_input,
825
- query,
826
- key.transpose(-1, -2),
827
- beta=beta,
828
- alpha=self.scale,
829
- )
830
- del baddbmm_input
831
-
832
- if self.upcast_softmax:
833
- attention_scores = attention_scores.float()
834
-
835
- attention_probs = attention_scores.softmax(dim=-1)
836
- del attention_scores
837
-
838
- attention_probs = attention_probs.to(dtype)
839
-
840
- return attention_probs
841
-
842
- def prepare_attention_mask(
843
- self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3, head_size = None,
844
- ) -> torch.Tensor:
845
- r"""
846
- Prepare the attention mask for the attention computation.
847
-
848
- Args:
849
- attention_mask (`torch.Tensor`):
850
- The attention mask to prepare.
851
- target_length (`int`):
852
- The target length of the attention mask. This is the length of the attention mask after padding.
853
- batch_size (`int`):
854
- The batch size, which is used to repeat the attention mask.
855
- out_dim (`int`, *optional*, defaults to `3`):
856
- The output dimension of the attention mask. Can be either `3` or `4`.
857
-
858
- Returns:
859
- `torch.Tensor`: The prepared attention mask.
860
- """
861
- head_size = head_size if head_size is not None else self.heads
862
- if attention_mask is None:
863
- return attention_mask
864
-
865
- current_length: int = attention_mask.shape[-1]
866
- if current_length != target_length:
867
- if attention_mask.device.type == "mps":
868
- # HACK: MPS: Does not support padding by greater than dimension of input tensor.
869
- # Instead, we can manually construct the padding tensor.
870
- padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
871
- padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
872
- attention_mask = torch.cat([attention_mask, padding], dim=2)
873
- else:
874
- # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
875
- # we want to instead pad by (0, remaining_length), where remaining_length is:
876
- # remaining_length: int = target_length - current_length
877
- # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
878
- attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
879
-
880
- if out_dim == 3:
881
- if attention_mask.shape[0] < batch_size * head_size:
882
- attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
883
- elif out_dim == 4:
884
- attention_mask = attention_mask.unsqueeze(1)
885
- attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
886
-
887
- return attention_mask
888
-
889
- def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
890
- r"""
891
- Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
892
- `Attention` class.
893
-
894
- Args:
895
- encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
896
-
897
- Returns:
898
- `torch.Tensor`: The normalized encoder hidden states.
899
- """
900
- assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
901
-
902
- if isinstance(self.norm_cross, nn.LayerNorm):
903
- encoder_hidden_states = self.norm_cross(encoder_hidden_states)
904
- elif isinstance(self.norm_cross, nn.GroupNorm):
905
- # Group norm norms along the channels dimension and expects
906
- # input to be in the shape of (N, C, *). In this case, we want
907
- # to norm along the hidden dimension, so we need to move
908
- # (batch_size, sequence_length, hidden_size) ->
909
- # (batch_size, hidden_size, sequence_length)
910
- encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
911
- encoder_hidden_states = self.norm_cross(encoder_hidden_states)
912
- encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
913
- else:
914
- assert False
915
-
916
- return encoder_hidden_states
917
-
918
- def _init_compress(self):
919
- self.sr.bias.data.zero_()
920
- self.norm = nn.LayerNorm(self.inner_dim)
921
-
922
-
923
- class AttnProcessor2_0(nn.Module):
924
- r"""
925
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
926
- """
927
-
928
- def __init__(self, attention_mode="xformers", use_rope=False, interpolation_scale_thw=None):
929
- super().__init__()
930
- self.attention_mode = attention_mode
931
- self.use_rope = use_rope
932
- self.interpolation_scale_thw = interpolation_scale_thw
933
-
934
- if self.use_rope:
935
- self._init_rope(interpolation_scale_thw)
936
-
937
- if not hasattr(F, "scaled_dot_product_attention"):
938
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
939
-
940
- def _init_rope(self, interpolation_scale_thw):
941
- self.rope = RoPE3D(interpolation_scale_thw=interpolation_scale_thw)
942
- self.position_getter = PositionGetter3D()
943
-
944
- def __call__(
945
- self,
946
- attn: Attention,
947
- hidden_states: torch.FloatTensor,
948
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
949
- attention_mask: Optional[torch.FloatTensor] = None,
950
- temb: Optional[torch.FloatTensor] = None,
951
- frame: int = 8,
952
- height: int = 16,
953
- width: int = 16,
954
- ) -> torch.FloatTensor:
955
-
956
- residual = hidden_states
957
-
958
- if attn.spatial_norm is not None:
959
- hidden_states = attn.spatial_norm(hidden_states, temb)
960
-
961
- input_ndim = hidden_states.ndim
962
-
963
- if input_ndim == 4:
964
- batch_size, channel, height, width = hidden_states.shape
965
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
966
-
967
-
968
- batch_size, sequence_length, _ = (
969
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
970
- )
971
-
972
- if attention_mask is not None and self.attention_mode == 'xformers':
973
- attention_heads = attn.heads
974
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, head_size=attention_heads)
975
- attention_mask = attention_mask.view(batch_size, attention_heads, -1, attention_mask.shape[-1])
976
- else:
977
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
978
- # scaled_dot_product_attention expects attention_mask shape to be
979
- # (batch, heads, source_length, target_length)
980
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
981
-
982
- if attn.group_norm is not None:
983
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
984
-
985
- query = attn.to_q(hidden_states)
986
-
987
- if encoder_hidden_states is None:
988
- encoder_hidden_states = hidden_states
989
- elif attn.norm_cross:
990
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
991
-
992
- key = attn.to_k(encoder_hidden_states)
993
- value = attn.to_v(encoder_hidden_states)
994
-
995
-
996
-
997
- attn_heads = attn.heads
998
-
999
- inner_dim = key.shape[-1]
1000
- head_dim = inner_dim // attn_heads
1001
-
1002
- query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
1003
- key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
1004
- value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
1005
-
1006
-
1007
- if self.use_rope:
1008
- # require the shape of (batch_size x nheads x ntokens x dim)
1009
- pos_thw = self.position_getter(batch_size, t=frame, h=height, w=width, device=query.device)
1010
- query = self.rope(query, pos_thw)
1011
- key = self.rope(key, pos_thw)
1012
-
1013
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
1014
- # TODO: add support for attn.scale when we move to Torch 2.1
1015
- if self.attention_mode == 'flash':
1016
- # assert attention_mask is None, 'flash-attn do not support attention_mask'
1017
- with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
1018
- hidden_states = F.scaled_dot_product_attention(
1019
- query, key, value, dropout_p=0.0, is_causal=False
1020
- )
1021
- elif self.attention_mode == 'xformers':
1022
- with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
1023
- hidden_states = F.scaled_dot_product_attention(
1024
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1025
- )
1026
-
1027
-
1028
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
1029
- hidden_states = hidden_states.to(query.dtype)
1030
-
1031
- # linear proj
1032
- hidden_states = attn.to_out[0](hidden_states)
1033
- # dropout
1034
- hidden_states = attn.to_out[1](hidden_states)
1035
-
1036
- if input_ndim == 4:
1037
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1038
-
1039
- if attn.residual_connection:
1040
- hidden_states = hidden_states + residual
1041
-
1042
- hidden_states = hidden_states / attn.rescale_output_factor
1043
-
1044
- return hidden_states
1045
-
1046
- class FeedForward(nn.Module):
1047
- r"""
1048
- A feed-forward layer.
1049
-
1050
- Parameters:
1051
- dim (`int`): The number of channels in the input.
1052
- dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
1053
- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
1054
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1055
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1056
- final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
1057
- """
1058
-
1059
- def __init__(
1060
- self,
1061
- dim: int,
1062
- dim_out: Optional[int] = None,
1063
- mult: int = 4,
1064
- dropout: float = 0.0,
1065
- activation_fn: str = "geglu",
1066
- final_dropout: bool = False,
1067
- ):
1068
- super().__init__()
1069
- inner_dim = int(dim * mult)
1070
- dim_out = dim_out if dim_out is not None else dim
1071
- linear_cls = nn.Linear
1072
-
1073
- if activation_fn == "gelu":
1074
- act_fn = GELU(dim, inner_dim)
1075
- if activation_fn == "gelu-approximate":
1076
- act_fn = GELU(dim, inner_dim, approximate="tanh")
1077
- elif activation_fn == "geglu":
1078
- act_fn = GEGLU(dim, inner_dim)
1079
- elif activation_fn == "geglu-approximate":
1080
- act_fn = ApproximateGELU(dim, inner_dim)
1081
-
1082
- self.net = nn.ModuleList([])
1083
- # project in
1084
- self.net.append(act_fn)
1085
- # project dropout
1086
- self.net.append(nn.Dropout(dropout))
1087
- # project out
1088
- self.net.append(linear_cls(inner_dim, dim_out))
1089
- # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
1090
- if final_dropout:
1091
- self.net.append(nn.Dropout(dropout))
1092
-
1093
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1094
- for module in self.net:
1095
- hidden_states = module(hidden_states)
1096
- return hidden_states
1097
-
1098
-
1099
- @maybe_allow_in_graph
1100
- class BasicTransformerBlock(nn.Module):
1101
- r"""
1102
- A basic Transformer block.
1103
-
1104
- Parameters:
1105
- dim (`int`): The number of channels in the input and output.
1106
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
1107
- attention_head_dim (`int`): The number of channels in each head.
1108
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1109
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
1110
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1111
- num_embeds_ada_norm (:
1112
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
1113
- attention_bias (:
1114
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
1115
- only_cross_attention (`bool`, *optional*):
1116
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
1117
- double_self_attention (`bool`, *optional*):
1118
- Whether to use two self-attention layers. In this case no cross attention layers are used.
1119
- upcast_attention (`bool`, *optional*):
1120
- Whether to upcast the attention computation to float32. This is useful for mixed precision training.
1121
- norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
1122
- Whether to use learnable elementwise affine parameters for normalization.
1123
- norm_type (`str`, *optional*, defaults to `"layer_norm"`):
1124
- The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
1125
- final_dropout (`bool` *optional*, defaults to False):
1126
- Whether to apply a final dropout after the last feed-forward layer.
1127
- positional_embeddings (`str`, *optional*, defaults to `None`):
1128
- The type of positional embeddings to apply to.
1129
- num_positional_embeddings (`int`, *optional*, defaults to `None`):
1130
- The maximum number of positional embeddings to apply.
1131
- """
1132
-
1133
- def __init__(
1134
- self,
1135
- dim: int,
1136
- num_attention_heads: int,
1137
- attention_head_dim: int,
1138
- dropout=0.0,
1139
- cross_attention_dim: Optional[int] = None,
1140
- activation_fn: str = "geglu",
1141
- num_embeds_ada_norm: Optional[int] = None,
1142
- attention_bias: bool = False,
1143
- only_cross_attention: bool = False,
1144
- double_self_attention: bool = False,
1145
- upcast_attention: bool = False,
1146
- norm_elementwise_affine: bool = True,
1147
- norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
1148
- norm_eps: float = 1e-5,
1149
- final_dropout: bool = False,
1150
- positional_embeddings: Optional[str] = None,
1151
- num_positional_embeddings: Optional[int] = None,
1152
- sa_attention_mode: str = "flash",
1153
- ca_attention_mode: str = "xformers",
1154
- use_rope: bool = False,
1155
- interpolation_scale_thw: Tuple[int] = (1, 1, 1),
1156
- block_idx: Optional[int] = None,
1157
- ):
1158
- super().__init__()
1159
- self.only_cross_attention = only_cross_attention
1160
-
1161
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
1162
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
1163
- self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
1164
- self.use_layer_norm = norm_type == "layer_norm"
1165
-
1166
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
1167
- raise ValueError(
1168
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
1169
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
1170
- )
1171
-
1172
- if positional_embeddings and (num_positional_embeddings is None):
1173
- raise ValueError(
1174
- "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
1175
- )
1176
-
1177
- if positional_embeddings == "sinusoidal":
1178
- self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
1179
- else:
1180
- self.pos_embed = None
1181
-
1182
- # Define 3 blocks. Each block has its own normalization layer.
1183
- # 1. Self-Attn
1184
- if self.use_ada_layer_norm:
1185
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
1186
- elif self.use_ada_layer_norm_zero:
1187
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
1188
- else:
1189
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1190
-
1191
- self.attn1 = Attention(
1192
- query_dim=dim,
1193
- heads=num_attention_heads,
1194
- dim_head=attention_head_dim,
1195
- dropout=dropout,
1196
- bias=attention_bias,
1197
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
1198
- upcast_attention=upcast_attention,
1199
- attention_mode=sa_attention_mode,
1200
- use_rope=use_rope,
1201
- interpolation_scale_thw=interpolation_scale_thw,
1202
- )
1203
-
1204
- # 2. Cross-Attn
1205
- if cross_attention_dim is not None or double_self_attention:
1206
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
1207
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
1208
- # the second cross attention block.
1209
- self.norm2 = (
1210
- AdaLayerNorm(dim, num_embeds_ada_norm)
1211
- if self.use_ada_layer_norm
1212
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1213
- )
1214
- self.attn2 = Attention(
1215
- query_dim=dim,
1216
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
1217
- heads=num_attention_heads,
1218
- dim_head=attention_head_dim,
1219
- dropout=dropout,
1220
- bias=attention_bias,
1221
- upcast_attention=upcast_attention,
1222
- attention_mode=ca_attention_mode, # only xformers support attention_mask
1223
- use_rope=False, # do not position in cross attention
1224
- interpolation_scale_thw=interpolation_scale_thw,
1225
- ) # is self-attn if encoder_hidden_states is none
1226
- else:
1227
- self.norm2 = None
1228
- self.attn2 = None
1229
-
1230
- # 3. Feed-forward
1231
-
1232
- if not self.use_ada_layer_norm_single:
1233
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1234
-
1235
- self.ff = FeedForward(
1236
- dim,
1237
- dropout=dropout,
1238
- activation_fn=activation_fn,
1239
- final_dropout=final_dropout,
1240
- )
1241
-
1242
- # 5. Scale-shift for PixArt-Alpha.
1243
- if self.use_ada_layer_norm_single:
1244
- self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
1245
-
1246
-
1247
- def forward(
1248
- self,
1249
- hidden_states: torch.FloatTensor,
1250
- attention_mask: Optional[torch.FloatTensor] = None,
1251
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
1252
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
1253
- timestep: Optional[torch.LongTensor] = None,
1254
- cross_attention_kwargs: Dict[str, Any] = None,
1255
- class_labels: Optional[torch.LongTensor] = None,
1256
- frame: int = None,
1257
- height: int = None,
1258
- width: int = None,
1259
- ) -> torch.FloatTensor:
1260
- # Notice that normalization is always applied before the real computation in the following blocks.
1261
- cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
1262
-
1263
- # 0. Self-Attention
1264
- batch_size = hidden_states.shape[0]
1265
-
1266
- if self.use_ada_layer_norm:
1267
- norm_hidden_states = self.norm1(hidden_states, timestep)
1268
- elif self.use_ada_layer_norm_zero:
1269
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
1270
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
1271
- )
1272
- elif self.use_layer_norm:
1273
- norm_hidden_states = self.norm1(hidden_states)
1274
- elif self.use_ada_layer_norm_single:
1275
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
1276
- self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
1277
- ).chunk(6, dim=1)
1278
- norm_hidden_states = self.norm1(hidden_states)
1279
- norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
1280
- norm_hidden_states = norm_hidden_states.squeeze(1)
1281
- else:
1282
- raise ValueError("Incorrect norm used")
1283
-
1284
- if self.pos_embed is not None:
1285
- norm_hidden_states = self.pos_embed(norm_hidden_states)
1286
-
1287
- attn_output = self.attn1(
1288
- norm_hidden_states,
1289
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
1290
- attention_mask=attention_mask,
1291
- frame=frame,
1292
- height=height,
1293
- width=width,
1294
- **cross_attention_kwargs,
1295
- )
1296
- if self.use_ada_layer_norm_zero:
1297
- attn_output = gate_msa.unsqueeze(1) * attn_output
1298
- elif self.use_ada_layer_norm_single:
1299
- attn_output = gate_msa * attn_output
1300
-
1301
- hidden_states = attn_output + hidden_states
1302
- if hidden_states.ndim == 4:
1303
- hidden_states = hidden_states.squeeze(1)
1304
-
1305
- # 1. Cross-Attention
1306
- if self.attn2 is not None:
1307
-
1308
- if self.use_ada_layer_norm:
1309
- norm_hidden_states = self.norm2(hidden_states, timestep)
1310
- elif self.use_ada_layer_norm_zero or self.use_layer_norm:
1311
- norm_hidden_states = self.norm2(hidden_states)
1312
- elif self.use_ada_layer_norm_single:
1313
- # For PixArt norm2 isn't applied here:
1314
- # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
1315
- norm_hidden_states = hidden_states
1316
- else:
1317
- raise ValueError("Incorrect norm")
1318
-
1319
- if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
1320
- norm_hidden_states = self.pos_embed(norm_hidden_states)
1321
-
1322
- attn_output = self.attn2(
1323
- norm_hidden_states,
1324
- encoder_hidden_states=encoder_hidden_states,
1325
- attention_mask=encoder_attention_mask,
1326
- **cross_attention_kwargs,
1327
- )
1328
- hidden_states = attn_output + hidden_states
1329
-
1330
-
1331
- # 2. Feed-forward
1332
- if not self.use_ada_layer_norm_single:
1333
- norm_hidden_states = self.norm3(hidden_states)
1334
-
1335
- if self.use_ada_layer_norm_zero:
1336
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
1337
-
1338
- if self.use_ada_layer_norm_single:
1339
- norm_hidden_states = self.norm2(hidden_states)
1340
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
1341
-
1342
- ff_output = self.ff(norm_hidden_states)
1343
-
1344
- if self.use_ada_layer_norm_zero:
1345
- ff_output = gate_mlp.unsqueeze(1) * ff_output
1346
- elif self.use_ada_layer_norm_single:
1347
- ff_output = gate_mlp * ff_output
1348
-
1349
-
1350
- hidden_states = ff_output + hidden_states
1351
- if hidden_states.ndim == 4:
1352
- hidden_states = hidden_states.squeeze(1)
1353
-
1354
- return hidden_states
1355
-
1356
-
1357
- class AdaLayerNormSingle(nn.Module):
1358
- r"""
1359
- Norm layer adaptive layer norm single (adaLN-single).
1360
-
1361
- As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
1362
-
1363
- Parameters:
1364
- embedding_dim (`int`): The size of each embedding vector.
1365
- use_additional_conditions (`bool`): To use additional conditions for normalization or not.
1366
- """
1367
-
1368
- def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
1369
- super().__init__()
1370
-
1371
- self.emb = CombinedTimestepSizeEmbeddings(
1372
- embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
1373
- )
1374
-
1375
- self.silu = nn.SiLU()
1376
- self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
1377
-
1378
- def forward(
1379
- self,
1380
- timestep: torch.Tensor,
1381
- added_cond_kwargs: Dict[str, torch.Tensor] = None,
1382
- batch_size: int = None,
1383
- hidden_dtype: Optional[torch.dtype] = None,
1384
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1385
- # No modulation happening here.
1386
- embedded_timestep = self.emb(
1387
- timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None
1388
- )
1389
- return self.linear(self.silu(embedded_timestep)), embedded_timestep
1390
-
1391
-
1392
- @dataclass
1393
- class Transformer3DModelOutput(BaseOutput):
1394
- """
1395
- The output of [`Transformer2DModel`].
1396
-
1397
- Args:
1398
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
1399
- The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
1400
- distributions for the unnoised latent pixels.
1401
- """
1402
-
1403
- sample: torch.FloatTensor
1404
-
1405
-
1406
- class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
1407
- _supports_gradient_checkpointing = True
1408
-
1409
- """
1410
- A 2D Transformer model for image-like data.
1411
-
1412
- Parameters:
1413
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
1414
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
1415
- in_channels (`int`, *optional*):
1416
- The number of channels in the input and output (specify if the input is **continuous**).
1417
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
1418
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1419
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
1420
- sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
1421
- This is fixed during training since it is used to learn a number of position embeddings.
1422
- num_vector_embeds (`int`, *optional*):
1423
- The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
1424
- Includes the class for the masked latent pixel.
1425
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
1426
- num_embeds_ada_norm ( `int`, *optional*):
1427
- The number of diffusion steps used during training. Pass if at least one of the norm_layers is
1428
- `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
1429
- added to the hidden states.
1430
-
1431
- During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
1432
- attention_bias (`bool`, *optional*):
1433
- Configure if the `TransformerBlocks` attention should contain a bias parameter.
1434
- """
1435
-
1436
- @register_to_config
1437
- def __init__(
1438
- self,
1439
- num_attention_heads: int = 16,
1440
- attention_head_dim: int = 88,
1441
- in_channels: Optional[int] = None,
1442
- out_channels: Optional[int] = None,
1443
- num_layers: int = 1,
1444
- dropout: float = 0.0,
1445
- cross_attention_dim: Optional[int] = None,
1446
- attention_bias: bool = False,
1447
- sample_size: Optional[int] = None,
1448
- sample_size_t: Optional[int] = None,
1449
- patch_size: Optional[int] = None,
1450
- patch_size_t: Optional[int] = None,
1451
- activation_fn: str = "geglu",
1452
- num_embeds_ada_norm: Optional[int] = None,
1453
- use_linear_projection: bool = False,
1454
- only_cross_attention: bool = False,
1455
- double_self_attention: bool = False,
1456
- upcast_attention: bool = False,
1457
- norm_type: str = "ada_norm",
1458
- norm_elementwise_affine: bool = True,
1459
- norm_eps: float = 1e-5,
1460
- caption_channels: int = None,
1461
- interpolation_scale_h: float = None,
1462
- interpolation_scale_w: float = None,
1463
- interpolation_scale_t: float = None,
1464
- use_additional_conditions: Optional[bool] = None,
1465
- sa_attention_mode: str = "flash",
1466
- ca_attention_mode: str = 'xformers',
1467
- downsampler: str = None,
1468
- use_rope: bool = False,
1469
- model_max_length: int = 300,
1470
- ):
1471
- super().__init__()
1472
- self.use_linear_projection = use_linear_projection
1473
- self.interpolation_scale_t = interpolation_scale_t
1474
- self.interpolation_scale_h = interpolation_scale_h
1475
- self.interpolation_scale_w = interpolation_scale_w
1476
- self.downsampler = downsampler
1477
- self.caption_channels = caption_channels
1478
- self.num_attention_heads = num_attention_heads
1479
- self.attention_head_dim = attention_head_dim
1480
- inner_dim = num_attention_heads * attention_head_dim
1481
- self.inner_dim = inner_dim
1482
- self.in_channels = in_channels
1483
- self.out_channels = in_channels if out_channels is None else out_channels
1484
- self.use_rope = use_rope
1485
- self.model_max_length = model_max_length
1486
- self.num_layers = num_layers
1487
- self.config.hidden_size = inner_dim
1488
-
1489
-
1490
- # 1. Transformer3DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
1491
- # Define whether input is continuous or discrete depending on configuration
1492
- assert in_channels is not None and patch_size is not None
1493
-
1494
- # 2. Initialize the right blocks.
1495
- # Initialize the output blocks and other projection blocks when necessary.
1496
-
1497
- assert self.config.sample_size_t is not None, "AllegroTransformer3DModel over patched input must provide sample_size_t"
1498
- assert self.config.sample_size is not None, "AllegroTransformer3DModel over patched input must provide sample_size"
1499
- #assert not (self.config.sample_size_t == 1 and self.config.patch_size_t == 2), "Image do not need patchfy in t-dim"
1500
-
1501
- self.num_frames = self.config.sample_size_t
1502
- self.config.sample_size = to_2tuple(self.config.sample_size)
1503
- self.height = self.config.sample_size[0]
1504
- self.width = self.config.sample_size[1]
1505
- self.patch_size_t = self.config.patch_size_t
1506
- self.patch_size = self.config.patch_size
1507
- interpolation_scale_t = ((self.config.sample_size_t - 1) // 16 + 1) if self.config.sample_size_t % 2 == 1 else self.config.sample_size_t / 16
1508
- interpolation_scale_t = (
1509
- self.config.interpolation_scale_t if self.config.interpolation_scale_t is not None else interpolation_scale_t
1510
- )
1511
- interpolation_scale = (
1512
- self.config.interpolation_scale_h if self.config.interpolation_scale_h is not None else self.config.sample_size[0] / 30,
1513
- self.config.interpolation_scale_w if self.config.interpolation_scale_w is not None else self.config.sample_size[1] / 40,
1514
- )
1515
- self.pos_embed = PatchEmbed2D(
1516
- num_frames=self.config.sample_size_t,
1517
- height=self.config.sample_size[0],
1518
- width=self.config.sample_size[1],
1519
- patch_size_t=self.config.patch_size_t,
1520
- patch_size=self.config.patch_size,
1521
- in_channels=self.in_channels,
1522
- embed_dim=self.inner_dim,
1523
- interpolation_scale=interpolation_scale,
1524
- interpolation_scale_t=interpolation_scale_t,
1525
- use_abs_pos=not self.config.use_rope,
1526
- )
1527
- interpolation_scale_thw = (interpolation_scale_t, *interpolation_scale)
1528
-
1529
- # 3. Define transformers blocks, spatial attention
1530
- self.transformer_blocks = nn.ModuleList(
1531
- [
1532
- BasicTransformerBlock(
1533
- inner_dim,
1534
- num_attention_heads,
1535
- attention_head_dim,
1536
- dropout=dropout,
1537
- cross_attention_dim=cross_attention_dim,
1538
- activation_fn=activation_fn,
1539
- num_embeds_ada_norm=num_embeds_ada_norm,
1540
- attention_bias=attention_bias,
1541
- only_cross_attention=only_cross_attention,
1542
- double_self_attention=double_self_attention,
1543
- upcast_attention=upcast_attention,
1544
- norm_type=norm_type,
1545
- norm_elementwise_affine=norm_elementwise_affine,
1546
- norm_eps=norm_eps,
1547
- sa_attention_mode=sa_attention_mode,
1548
- ca_attention_mode=ca_attention_mode,
1549
- use_rope=use_rope,
1550
- interpolation_scale_thw=interpolation_scale_thw,
1551
- block_idx=d,
1552
- )
1553
- for d in range(num_layers)
1554
- ]
1555
- )
1556
-
1557
- # 4. Define output layers
1558
-
1559
- if norm_type != "ada_norm_single":
1560
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
1561
- self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
1562
- self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
1563
- elif norm_type == "ada_norm_single":
1564
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
1565
- self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
1566
- self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
1567
-
1568
- # 5. PixArt-Alpha blocks.
1569
- self.adaln_single = None
1570
- self.use_additional_conditions = False
1571
- if norm_type == "ada_norm_single":
1572
- # self.use_additional_conditions = self.config.sample_size[0] == 128 # False, 128 -> 1024
1573
- # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
1574
- # additional conditions until we find better name
1575
- self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
1576
-
1577
- self.caption_projection = None
1578
- if caption_channels is not None:
1579
- self.caption_projection = PixArtAlphaTextProjection(
1580
- in_features=caption_channels, hidden_size=inner_dim
1581
- )
1582
-
1583
- self.gradient_checkpointing = False
1584
-
1585
- def _set_gradient_checkpointing(self, module, value=False):
1586
- self.gradient_checkpointing = value
1587
-
1588
-
1589
- def forward(
1590
- self,
1591
- hidden_states: torch.Tensor,
1592
- timestep: Optional[torch.LongTensor] = None,
1593
- encoder_hidden_states: Optional[torch.Tensor] = None,
1594
- added_cond_kwargs: Dict[str, torch.Tensor] = None,
1595
- class_labels: Optional[torch.LongTensor] = None,
1596
- cross_attention_kwargs: Dict[str, Any] = None,
1597
- attention_mask: Optional[torch.Tensor] = None,
1598
- encoder_attention_mask: Optional[torch.Tensor] = None,
1599
- return_dict: bool = True,
1600
- ):
1601
- """
1602
- The [`Transformer2DModel`] forward method.
1603
-
1604
- Args:
1605
- hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous):
1606
- Input `hidden_states`.
1607
- encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
1608
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
1609
- self-attention.
1610
- timestep ( `torch.LongTensor`, *optional*):
1611
- Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
1612
- class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
1613
- Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
1614
- `AdaLayerZeroNorm`.
1615
- added_cond_kwargs ( `Dict[str, Any]`, *optional*):
1616
- A kwargs dictionary that if specified is passed along to the `AdaLayerNormSingle`
1617
- cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
1618
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1619
- `self.processor` in
1620
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1621
- attention_mask ( `torch.Tensor`, *optional*):
1622
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
1623
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
1624
- negative values to the attention scores corresponding to "discard" tokens.
1625
- encoder_attention_mask ( `torch.Tensor`, *optional*):
1626
- Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
1627
-
1628
- * Mask `(batch, sequence_length)` True = keep, False = discard.
1629
- * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
1630
-
1631
- If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
1632
- above. This bias will be added to the cross-attention scores.
1633
- return_dict (`bool`, *optional*, defaults to `True`):
1634
- Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
1635
- tuple.
1636
-
1637
- Returns:
1638
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
1639
- `tuple` where the first element is the sample tensor.
1640
- """
1641
- batch_size, c, frame, h, w = hidden_states.shape
1642
-
1643
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
1644
- # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
1645
- # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
1646
- # expects mask of shape:
1647
- # [batch, key_tokens]
1648
- # adds singleton query_tokens dimension:
1649
- # [batch, 1, key_tokens]
1650
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
1651
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
1652
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) attention_mask_vid, attention_mask_img = None, None
1653
- if attention_mask is not None and attention_mask.ndim == 4:
1654
- # assume that mask is expressed as:
1655
- # (1 = keep, 0 = discard)
1656
- # convert mask into a bias that can be added to attention scores:
1657
- # (keep = +0, discard = -10000.0)
1658
- # b, frame+use_image_num, h, w -> a video with images
1659
- # b, 1, h, w -> only images
1660
- attention_mask = attention_mask.to(self.dtype)
1661
- attention_mask_vid = attention_mask[:, :frame] # b, frame, h, w
1662
-
1663
- if attention_mask_vid.numel() > 0:
1664
- attention_mask_vid = attention_mask_vid.unsqueeze(1) # b 1 t h w
1665
- attention_mask_vid = F.max_pool3d(attention_mask_vid, kernel_size=(self.patch_size_t, self.patch_size, self.patch_size),
1666
- stride=(self.patch_size_t, self.patch_size, self.patch_size))
1667
- attention_mask_vid = rearrange(attention_mask_vid, 'b 1 t h w -> (b 1) 1 (t h w)')
1668
-
1669
- attention_mask_vid = (1 - attention_mask_vid.bool().to(self.dtype)) * -10000.0 if attention_mask_vid.numel() > 0 else None
1670
-
1671
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
1672
- if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3:
1673
- # b, 1+use_image_num, l -> a video with images
1674
- # b, 1, l -> only images
1675
- encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
1676
- encoder_attention_mask_vid = rearrange(encoder_attention_mask, 'b 1 l -> (b 1) 1 l') if encoder_attention_mask.numel() > 0 else None
1677
-
1678
- # 1. Input
1679
- frame = frame // self.patch_size_t # patchfy
1680
- # print('frame', frame)
1681
- height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
1682
-
1683
- added_cond_kwargs = {"resolution": None, "aspect_ratio": None} if added_cond_kwargs is None else added_cond_kwargs
1684
- hidden_states, encoder_hidden_states_vid, \
1685
- timestep_vid, embedded_timestep_vid = self._operate_on_patched_inputs(
1686
- hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size,
1687
- )
1688
-
1689
-
1690
- for _, block in enumerate(self.transformer_blocks):
1691
- hidden_states = block(
1692
- hidden_states,
1693
- attention_mask_vid,
1694
- encoder_hidden_states_vid,
1695
- encoder_attention_mask_vid,
1696
- timestep_vid,
1697
- cross_attention_kwargs,
1698
- class_labels,
1699
- frame=frame,
1700
- height=height,
1701
- width=width,
1702
- )
1703
-
1704
- # 3. Output
1705
- output = None
1706
- if hidden_states is not None:
1707
- output = self._get_output_for_patched_inputs(
1708
- hidden_states=hidden_states,
1709
- timestep=timestep_vid,
1710
- class_labels=class_labels,
1711
- embedded_timestep=embedded_timestep_vid,
1712
- num_frames=frame,
1713
- height=height,
1714
- width=width,
1715
- ) # b c t h w
1716
-
1717
- if not return_dict:
1718
- return (output,)
1719
-
1720
- return Transformer3DModelOutput(sample=output)
1721
-
1722
- def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size):
1723
- # batch_size = hidden_states.shape[0]
1724
- hidden_states_vid = self.pos_embed(hidden_states.to(self.dtype))
1725
- timestep_vid = None
1726
- embedded_timestep_vid = None
1727
- encoder_hidden_states_vid = None
1728
-
1729
- if self.adaln_single is not None:
1730
- if self.use_additional_conditions and added_cond_kwargs is None:
1731
- raise ValueError(
1732
- "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
1733
- )
1734
- timestep, embedded_timestep = self.adaln_single(
1735
- timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype
1736
- ) # b 6d, b d
1737
-
1738
- timestep_vid = timestep
1739
- embedded_timestep_vid = embedded_timestep
1740
-
1741
- if self.caption_projection is not None:
1742
- encoder_hidden_states = self.caption_projection(encoder_hidden_states) # b, 1+use_image_num, l, d or b, 1, l, d
1743
- encoder_hidden_states_vid = rearrange(encoder_hidden_states[:, :1], 'b 1 l d -> (b 1) l d')
1744
-
1745
- return hidden_states_vid, encoder_hidden_states_vid, timestep_vid, embedded_timestep_vid
1746
-
1747
- def _get_output_for_patched_inputs(
1748
- self, hidden_states, timestep, class_labels, embedded_timestep, num_frames, height=None, width=None
1749
- ):
1750
- # import ipdb;ipdb.set_trace()
1751
- if self.config.norm_type != "ada_norm_single":
1752
- conditioning = self.transformer_blocks[0].norm1.emb(
1753
- timestep, class_labels, hidden_dtype=self.dtype
1754
- )
1755
- shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
1756
- hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
1757
- hidden_states = self.proj_out_2(hidden_states)
1758
- elif self.config.norm_type == "ada_norm_single":
1759
- shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
1760
- hidden_states = self.norm_out(hidden_states)
1761
- # Modulation
1762
- hidden_states = hidden_states * (1 + scale) + shift
1763
- hidden_states = self.proj_out(hidden_states)
1764
- hidden_states = hidden_states.squeeze(1)
1765
-
1766
- # unpatchify
1767
- if self.adaln_single is None:
1768
- height = width = int(hidden_states.shape[1] ** 0.5)
1769
- hidden_states = hidden_states.reshape(
1770
- shape=(-1, num_frames, height, width, self.patch_size_t, self.patch_size, self.patch_size, self.out_channels)
1771
- )
1772
- hidden_states = torch.einsum("nthwopqc->nctohpwq", hidden_states)
1773
- output = hidden_states.reshape(
1774
- shape=(-1, self.out_channels, num_frames * self.patch_size_t, height * self.patch_size, width * self.patch_size)
1775
- )
1776
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
vae/config.json CHANGED
@@ -1,6 +1,6 @@
1
  {
2
  "_class_name": "AllegroAutoencoderKL3D",
3
- "_diffusers_version": "0.30.3",
4
  "act_fn": "silu",
5
  "block_out_channels": [
6
  128,
 
1
  {
2
  "_class_name": "AllegroAutoencoderKL3D",
3
+ "_diffusers_version": "0.28.0",
4
  "act_fn": "silu",
5
  "block_out_channels": [
6
  128,
vae/vae_allegro.py DELETED
@@ -1,978 +0,0 @@
1
- import math
2
- from dataclasses import dataclass
3
- import os
4
- from typing import Dict, Optional, Tuple, Union
5
- from einops import rearrange
6
-
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
-
11
- from diffusers.configuration_utils import ConfigMixin, register_to_config
12
- from diffusers.models.modeling_utils import ModelMixin
13
- from diffusers.models.modeling_outputs import AutoencoderKLOutput
14
- from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
15
- from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
16
- from diffusers.models.attention_processor import Attention
17
- from diffusers.models.resnet import ResnetBlock2D
18
- from diffusers.models.upsampling import Upsample2D
19
- from diffusers.models.downsampling import Downsample2D
20
- from diffusers.models.attention_processor import SpatialNorm
21
-
22
-
23
- class TemporalConvBlock(nn.Module):
24
- """
25
- Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
26
- https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
27
- """
28
-
29
- def __init__(self, in_dim, out_dim=None, dropout=0.0, up_sample=False, down_sample=False, spa_stride=1):
30
- super().__init__()
31
- out_dim = out_dim or in_dim
32
- self.in_dim = in_dim
33
- self.out_dim = out_dim
34
- spa_pad = int((spa_stride-1)*0.5)
35
- temp_pad = 0
36
- self.temp_pad = temp_pad
37
-
38
- if down_sample:
39
- self.conv1 = nn.Sequential(
40
- nn.GroupNorm(32, in_dim),
41
- nn.SiLU(),
42
- nn.Conv3d(in_dim, out_dim, (2, spa_stride, spa_stride), stride=(2,1,1), padding=(0, spa_pad, spa_pad))
43
- )
44
- elif up_sample:
45
- self.conv1 = nn.Sequential(
46
- nn.GroupNorm(32, in_dim),
47
- nn.SiLU(),
48
- nn.Conv3d(in_dim, out_dim*2, (1, spa_stride, spa_stride), padding=(0, spa_pad, spa_pad))
49
- )
50
- else:
51
- self.conv1 = nn.Sequential(
52
- nn.GroupNorm(32, in_dim),
53
- nn.SiLU(),
54
- nn.Conv3d(in_dim, out_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad))
55
- )
56
- self.conv2 = nn.Sequential(
57
- nn.GroupNorm(32, out_dim),
58
- nn.SiLU(),
59
- nn.Dropout(dropout),
60
- nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)),
61
- )
62
- self.conv3 = nn.Sequential(
63
- nn.GroupNorm(32, out_dim),
64
- nn.SiLU(),
65
- nn.Dropout(dropout),
66
- nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)),
67
- )
68
- self.conv4 = nn.Sequential(
69
- nn.GroupNorm(32, out_dim),
70
- nn.SiLU(),
71
- nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)),
72
- )
73
-
74
- # zero out the last layer params,so the conv block is identity
75
- nn.init.zeros_(self.conv4[-1].weight)
76
- nn.init.zeros_(self.conv4[-1].bias)
77
-
78
- self.down_sample = down_sample
79
- self.up_sample = up_sample
80
-
81
-
82
- def forward(self, hidden_states):
83
- identity = hidden_states
84
-
85
- if self.down_sample:
86
- identity = identity[:,:,::2]
87
- elif self.up_sample:
88
- hidden_states_new = torch.cat((hidden_states,hidden_states),dim=2)
89
- hidden_states_new[:, :, 0::2] = hidden_states
90
- hidden_states_new[:, :, 1::2] = hidden_states
91
- identity = hidden_states_new
92
- del hidden_states_new
93
-
94
- if self.down_sample or self.up_sample:
95
- hidden_states = self.conv1(hidden_states)
96
- else:
97
- hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2)
98
- hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2)
99
- hidden_states = self.conv1(hidden_states)
100
-
101
-
102
- if self.up_sample:
103
- hidden_states = rearrange(hidden_states, 'b (d c) f h w -> b c (f d) h w', d=2)
104
-
105
- hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2)
106
- hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2)
107
- hidden_states = self.conv2(hidden_states)
108
- hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2)
109
- hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2)
110
- hidden_states = self.conv3(hidden_states)
111
- hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2)
112
- hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2)
113
- hidden_states = self.conv4(hidden_states)
114
-
115
- hidden_states = identity + hidden_states
116
-
117
- return hidden_states
118
-
119
-
120
- class DownEncoderBlock3D(nn.Module):
121
- def __init__(
122
- self,
123
- in_channels: int,
124
- out_channels: int,
125
- dropout: float = 0.0,
126
- num_layers: int = 1,
127
- resnet_eps: float = 1e-6,
128
- resnet_time_scale_shift: str = "default",
129
- resnet_act_fn: str = "swish",
130
- resnet_groups: int = 32,
131
- resnet_pre_norm: bool = True,
132
- output_scale_factor=1.0,
133
- add_downsample=True,
134
- add_temp_downsample=False,
135
- downsample_padding=1,
136
- ):
137
- super().__init__()
138
- resnets = []
139
- temp_convs = []
140
-
141
- for i in range(num_layers):
142
- in_channels = in_channels if i == 0 else out_channels
143
- resnets.append(
144
- ResnetBlock2D(
145
- in_channels=in_channels,
146
- out_channels=out_channels,
147
- temb_channels=None,
148
- eps=resnet_eps,
149
- groups=resnet_groups,
150
- dropout=dropout,
151
- time_embedding_norm=resnet_time_scale_shift,
152
- non_linearity=resnet_act_fn,
153
- output_scale_factor=output_scale_factor,
154
- pre_norm=resnet_pre_norm,
155
- )
156
- )
157
- temp_convs.append(
158
- TemporalConvBlock(
159
- out_channels,
160
- out_channels,
161
- dropout=0.1,
162
- )
163
- )
164
-
165
- self.resnets = nn.ModuleList(resnets)
166
- self.temp_convs = nn.ModuleList(temp_convs)
167
-
168
- if add_temp_downsample:
169
- self.temp_convs_down = TemporalConvBlock(
170
- out_channels,
171
- out_channels,
172
- dropout=0.1,
173
- down_sample=True,
174
- spa_stride=3
175
- )
176
- self.add_temp_downsample = add_temp_downsample
177
-
178
- if add_downsample:
179
- self.downsamplers = nn.ModuleList(
180
- [
181
- Downsample2D(
182
- out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
183
- )
184
- ]
185
- )
186
- else:
187
- self.downsamplers = None
188
-
189
- def _set_partial_grad(self):
190
- for temp_conv in self.temp_convs:
191
- temp_conv.requires_grad_(True)
192
- if self.downsamplers:
193
- for down_layer in self.downsamplers:
194
- down_layer.requires_grad_(True)
195
-
196
- def forward(self, hidden_states):
197
- bz = hidden_states.shape[0]
198
-
199
- for resnet, temp_conv in zip(self.resnets, self.temp_convs):
200
- hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
201
- hidden_states = resnet(hidden_states, temb=None)
202
- hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
203
- hidden_states = temp_conv(hidden_states)
204
- if self.add_temp_downsample:
205
- hidden_states = self.temp_convs_down(hidden_states)
206
-
207
- if self.downsamplers is not None:
208
- hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
209
- for upsampler in self.downsamplers:
210
- hidden_states = upsampler(hidden_states)
211
- hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
212
- return hidden_states
213
-
214
-
215
- class UpDecoderBlock3D(nn.Module):
216
- def __init__(
217
- self,
218
- in_channels: int,
219
- out_channels: int,
220
- dropout: float = 0.0,
221
- num_layers: int = 1,
222
- resnet_eps: float = 1e-6,
223
- resnet_time_scale_shift: str = "default", # default, spatial
224
- resnet_act_fn: str = "swish",
225
- resnet_groups: int = 32,
226
- resnet_pre_norm: bool = True,
227
- output_scale_factor=1.0,
228
- add_upsample=True,
229
- add_temp_upsample=False,
230
- temb_channels=None,
231
- ):
232
- super().__init__()
233
- self.add_upsample = add_upsample
234
-
235
- resnets = []
236
- temp_convs = []
237
-
238
- for i in range(num_layers):
239
- input_channels = in_channels if i == 0 else out_channels
240
-
241
- resnets.append(
242
- ResnetBlock2D(
243
- in_channels=input_channels,
244
- out_channels=out_channels,
245
- temb_channels=temb_channels,
246
- eps=resnet_eps,
247
- groups=resnet_groups,
248
- dropout=dropout,
249
- time_embedding_norm=resnet_time_scale_shift,
250
- non_linearity=resnet_act_fn,
251
- output_scale_factor=output_scale_factor,
252
- pre_norm=resnet_pre_norm,
253
- )
254
- )
255
- temp_convs.append(
256
- TemporalConvBlock(
257
- out_channels,
258
- out_channels,
259
- dropout=0.1,
260
- )
261
- )
262
-
263
- self.resnets = nn.ModuleList(resnets)
264
- self.temp_convs = nn.ModuleList(temp_convs)
265
-
266
- self.add_temp_upsample = add_temp_upsample
267
- if add_temp_upsample:
268
- self.temp_conv_up = TemporalConvBlock(
269
- out_channels,
270
- out_channels,
271
- dropout=0.1,
272
- up_sample=True,
273
- spa_stride=3
274
- )
275
-
276
-
277
- if self.add_upsample:
278
- # self.upsamplers = nn.ModuleList([PSUpsample2D(out_channels, use_conv=True, use_pixel_shuffle=True, out_channels=out_channels)])
279
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
280
- else:
281
- self.upsamplers = None
282
-
283
- def _set_partial_grad(self):
284
- for temp_conv in self.temp_convs:
285
- temp_conv.requires_grad_(True)
286
- if self.add_upsample:
287
- self.upsamplers.requires_grad_(True)
288
-
289
- def forward(self, hidden_states):
290
- bz = hidden_states.shape[0]
291
-
292
- for resnet, temp_conv in zip(self.resnets, self.temp_convs):
293
- hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
294
- hidden_states = resnet(hidden_states, temb=None)
295
- hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
296
- hidden_states = temp_conv(hidden_states)
297
- if self.add_temp_upsample:
298
- hidden_states = self.temp_conv_up(hidden_states)
299
-
300
- if self.upsamplers is not None:
301
- hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
302
- for upsampler in self.upsamplers:
303
- hidden_states = upsampler(hidden_states)
304
- hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
305
- return hidden_states
306
-
307
-
308
- class UNetMidBlock3DConv(nn.Module):
309
- def __init__(
310
- self,
311
- in_channels: int,
312
- temb_channels: int,
313
- dropout: float = 0.0,
314
- num_layers: int = 1,
315
- resnet_eps: float = 1e-6,
316
- resnet_time_scale_shift: str = "default", # default, spatial
317
- resnet_act_fn: str = "swish",
318
- resnet_groups: int = 32,
319
- resnet_pre_norm: bool = True,
320
- add_attention: bool = True,
321
- attention_head_dim=1,
322
- output_scale_factor=1.0,
323
- ):
324
- super().__init__()
325
- resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
326
- self.add_attention = add_attention
327
-
328
- # there is always at least one resnet
329
- resnets = [
330
- ResnetBlock2D(
331
- in_channels=in_channels,
332
- out_channels=in_channels,
333
- temb_channels=temb_channels,
334
- eps=resnet_eps,
335
- groups=resnet_groups,
336
- dropout=dropout,
337
- time_embedding_norm=resnet_time_scale_shift,
338
- non_linearity=resnet_act_fn,
339
- output_scale_factor=output_scale_factor,
340
- pre_norm=resnet_pre_norm,
341
- )
342
- ]
343
- temp_convs = [
344
- TemporalConvBlock(
345
- in_channels,
346
- in_channels,
347
- dropout=0.1,
348
- )
349
- ]
350
- attentions = []
351
-
352
- if attention_head_dim is None:
353
- attention_head_dim = in_channels
354
-
355
- for _ in range(num_layers):
356
- if self.add_attention:
357
- attentions.append(
358
- Attention(
359
- in_channels,
360
- heads=in_channels // attention_head_dim,
361
- dim_head=attention_head_dim,
362
- rescale_output_factor=output_scale_factor,
363
- eps=resnet_eps,
364
- norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
365
- spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
366
- residual_connection=True,
367
- bias=True,
368
- upcast_softmax=True,
369
- _from_deprecated_attn_block=True,
370
- )
371
- )
372
- else:
373
- attentions.append(None)
374
-
375
- resnets.append(
376
- ResnetBlock2D(
377
- in_channels=in_channels,
378
- out_channels=in_channels,
379
- temb_channels=temb_channels,
380
- eps=resnet_eps,
381
- groups=resnet_groups,
382
- dropout=dropout,
383
- time_embedding_norm=resnet_time_scale_shift,
384
- non_linearity=resnet_act_fn,
385
- output_scale_factor=output_scale_factor,
386
- pre_norm=resnet_pre_norm,
387
- )
388
- )
389
-
390
- temp_convs.append(
391
- TemporalConvBlock(
392
- in_channels,
393
- in_channels,
394
- dropout=0.1,
395
- )
396
- )
397
-
398
- self.resnets = nn.ModuleList(resnets)
399
- self.temp_convs = nn.ModuleList(temp_convs)
400
- self.attentions = nn.ModuleList(attentions)
401
-
402
- def _set_partial_grad(self):
403
- for temp_conv in self.temp_convs:
404
- temp_conv.requires_grad_(True)
405
-
406
- def forward(
407
- self,
408
- hidden_states,
409
- ):
410
- bz = hidden_states.shape[0]
411
- hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
412
-
413
- hidden_states = self.resnets[0](hidden_states, temb=None)
414
- hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
415
- hidden_states = self.temp_convs[0](hidden_states)
416
- hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
417
-
418
- for attn, resnet, temp_conv in zip(
419
- self.attentions, self.resnets[1:], self.temp_convs[1:]
420
- ):
421
- hidden_states = attn(hidden_states)
422
- hidden_states = resnet(hidden_states, temb=None)
423
- hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
424
- hidden_states = temp_conv(hidden_states)
425
- return hidden_states
426
-
427
-
428
- class Encoder3D(nn.Module):
429
- def __init__(
430
- self,
431
- in_channels=3,
432
- out_channels=3,
433
- num_blocks=4,
434
- blocks_temp_li=[False, False, False, False],
435
- block_out_channels=(64,),
436
- layers_per_block=2,
437
- norm_num_groups=32,
438
- act_fn="silu",
439
- double_z=True,
440
- ):
441
- super().__init__()
442
- self.layers_per_block = layers_per_block
443
- self.blocks_temp_li = blocks_temp_li
444
-
445
- self.conv_in = nn.Conv2d(
446
- in_channels,
447
- block_out_channels[0],
448
- kernel_size=3,
449
- stride=1,
450
- padding=1,
451
- )
452
-
453
- self.temp_conv_in = nn.Conv3d(
454
- block_out_channels[0],
455
- block_out_channels[0],
456
- (3,1,1),
457
- padding = (1, 0, 0)
458
- )
459
-
460
- self.mid_block = None
461
- self.down_blocks = nn.ModuleList([])
462
-
463
- # down
464
- output_channel = block_out_channels[0]
465
- for i in range(num_blocks):
466
- input_channel = output_channel
467
- output_channel = block_out_channels[i]
468
- is_final_block = i == len(block_out_channels) - 1
469
-
470
- down_block = DownEncoderBlock3D(
471
- num_layers=self.layers_per_block,
472
- in_channels=input_channel,
473
- out_channels=output_channel,
474
- add_downsample=not is_final_block,
475
- add_temp_downsample=blocks_temp_li[i],
476
- resnet_eps=1e-6,
477
- downsample_padding=0,
478
- resnet_act_fn=act_fn,
479
- resnet_groups=norm_num_groups,
480
- )
481
- self.down_blocks.append(down_block)
482
-
483
- # mid
484
- self.mid_block = UNetMidBlock3DConv(
485
- in_channels=block_out_channels[-1],
486
- resnet_eps=1e-6,
487
- resnet_act_fn=act_fn,
488
- output_scale_factor=1,
489
- resnet_time_scale_shift="default",
490
- attention_head_dim=block_out_channels[-1],
491
- resnet_groups=norm_num_groups,
492
- temb_channels=None,
493
- )
494
-
495
- # out
496
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
497
- self.conv_act = nn.SiLU()
498
-
499
- conv_out_channels = 2 * out_channels if double_z else out_channels
500
-
501
- self.temp_conv_out = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3,1,1), padding = (1, 0, 0))
502
-
503
- self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
504
-
505
- nn.init.zeros_(self.temp_conv_in.weight)
506
- nn.init.zeros_(self.temp_conv_in.bias)
507
- nn.init.zeros_(self.temp_conv_out.weight)
508
- nn.init.zeros_(self.temp_conv_out.bias)
509
-
510
- self.gradient_checkpointing = False
511
-
512
- def forward(self, x):
513
- '''
514
- x: [b, c, (tb f), h, w]
515
- '''
516
- bz = x.shape[0]
517
- sample = rearrange(x, 'b c n h w -> (b n) c h w')
518
- sample = self.conv_in(sample)
519
-
520
- sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz)
521
- temp_sample = sample
522
- sample = self.temp_conv_in(sample)
523
- sample = sample+temp_sample
524
- # down
525
- for b_id, down_block in enumerate(self.down_blocks):
526
- sample = down_block(sample)
527
- # middle
528
- sample = self.mid_block(sample)
529
-
530
- # post-process
531
- sample = rearrange(sample, 'b c n h w -> (b n) c h w')
532
- sample = self.conv_norm_out(sample)
533
- sample = self.conv_act(sample)
534
- sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz)
535
-
536
- temp_sample = sample
537
- sample = self.temp_conv_out(sample)
538
- sample = sample+temp_sample
539
- sample = rearrange(sample, 'b c n h w -> (b n) c h w')
540
-
541
- sample = self.conv_out(sample)
542
- sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz)
543
- return sample
544
-
545
- class Decoder3D(nn.Module):
546
- def __init__(
547
- self,
548
- in_channels=4,
549
- out_channels=3,
550
- num_blocks=4,
551
- blocks_temp_li=[False, False, False, False],
552
- block_out_channels=(64,),
553
- layers_per_block=2,
554
- norm_num_groups=32,
555
- act_fn="silu",
556
- norm_type="group", # group, spatial
557
- ):
558
- super().__init__()
559
- self.layers_per_block = layers_per_block
560
- self.blocks_temp_li = blocks_temp_li
561
-
562
- self.conv_in = nn.Conv2d(
563
- in_channels,
564
- block_out_channels[-1],
565
- kernel_size=3,
566
- stride=1,
567
- padding=1,
568
- )
569
-
570
- self.temp_conv_in = nn.Conv3d(
571
- block_out_channels[-1],
572
- block_out_channels[-1],
573
- (3,1,1),
574
- padding = (1, 0, 0)
575
- )
576
-
577
- self.mid_block = None
578
- self.up_blocks = nn.ModuleList([])
579
-
580
- temb_channels = in_channels if norm_type == "spatial" else None
581
-
582
- # mid
583
- self.mid_block = UNetMidBlock3DConv(
584
- in_channels=block_out_channels[-1],
585
- resnet_eps=1e-6,
586
- resnet_act_fn=act_fn,
587
- output_scale_factor=1,
588
- resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
589
- attention_head_dim=block_out_channels[-1],
590
- resnet_groups=norm_num_groups,
591
- temb_channels=temb_channels,
592
- )
593
-
594
- # up
595
- reversed_block_out_channels = list(reversed(block_out_channels))
596
- output_channel = reversed_block_out_channels[0]
597
- for i in range(num_blocks):
598
- prev_output_channel = output_channel
599
- output_channel = reversed_block_out_channels[i]
600
-
601
- is_final_block = i == len(block_out_channels) - 1
602
-
603
- up_block = UpDecoderBlock3D(
604
- num_layers=self.layers_per_block + 1,
605
- in_channels=prev_output_channel,
606
- out_channels=output_channel,
607
- add_upsample=not is_final_block,
608
- add_temp_upsample=blocks_temp_li[i],
609
- resnet_eps=1e-6,
610
- resnet_act_fn=act_fn,
611
- resnet_groups=norm_num_groups,
612
- temb_channels=temb_channels,
613
- resnet_time_scale_shift=norm_type,
614
- )
615
- self.up_blocks.append(up_block)
616
- prev_output_channel = output_channel
617
-
618
- # out
619
- if norm_type == "spatial":
620
- self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
621
- else:
622
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
623
- self.conv_act = nn.SiLU()
624
-
625
- self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3,1,1), padding = (1, 0, 0))
626
- self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
627
-
628
- nn.init.zeros_(self.temp_conv_in.weight)
629
- nn.init.zeros_(self.temp_conv_in.bias)
630
- nn.init.zeros_(self.temp_conv_out.weight)
631
- nn.init.zeros_(self.temp_conv_out.bias)
632
-
633
- self.gradient_checkpointing = False
634
-
635
- def forward(self, z):
636
- bz = z.shape[0]
637
- sample = rearrange(z, 'b c n h w -> (b n) c h w')
638
- sample = self.conv_in(sample)
639
-
640
- sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz)
641
- temp_sample = sample
642
- sample = self.temp_conv_in(sample)
643
- sample = sample+temp_sample
644
-
645
- upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
646
- # middle
647
- sample = self.mid_block(sample)
648
- sample = sample.to(upscale_dtype)
649
-
650
- # up
651
- for b_id, up_block in enumerate(self.up_blocks):
652
- sample = up_block(sample)
653
-
654
- # post-process
655
- sample = rearrange(sample, 'b c n h w -> (b n) c h w')
656
- sample = self.conv_norm_out(sample)
657
- sample = self.conv_act(sample)
658
-
659
- sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz)
660
- temp_sample = sample
661
- sample = self.temp_conv_out(sample)
662
- sample = sample+temp_sample
663
- sample = rearrange(sample, 'b c n h w -> (b n) c h w')
664
-
665
- sample = self.conv_out(sample)
666
- sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz)
667
- return sample
668
-
669
-
670
-
671
- class AllegroAutoencoderKL3D(ModelMixin, ConfigMixin):
672
- r"""
673
- A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
674
-
675
- This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
676
- for all models (such as downloading or saving).
677
-
678
- Parameters:
679
- in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
680
- out_channels (int, *optional*, defaults to 3): Number of channels in the output.
681
- down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
682
- Tuple of downsample block types.
683
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
684
- Tuple of upsample block types.
685
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
686
- Tuple of block output channels.
687
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
688
- latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
689
- sample_size (`int`, *optional*, defaults to `256`): Spatial Tiling Size.
690
- tile_overlap (`tuple`, *optional*, defaults to `(120, 80`): Spatial overlapping size while tiling (height, width)
691
- chunk_len (`int`, *optional*, defaults to `24`): Temporal Tiling Size.
692
- t_over (`int`, *optional*, defaults to `8`): Temporal overlapping size while tiling
693
- scaling_factor (`float`, *optional*, defaults to 0.13235):
694
- The component-wise standard deviation of the trained latent space computed using the first batch of the
695
- training set. This is used to scale the latent space to have unit variance when training the diffusion
696
- model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
697
- diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
698
- / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
699
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
700
- force_upcast (`bool`, *optional*, default to `True`):
701
- If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
702
- can be fine-tuned / trained to a lower range without loosing too much precision in which case
703
- `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
704
- blocks_tempdown_li (`List`, *optional*, defaults to `[True, True, False, False]`): Each item indicates whether each TemporalBlock in the Encoder performs temporal downsampling.
705
- blocks_tempup_li (`List`, *optional*, defaults to `[False, True, True, False]`): Each item indicates whether each TemporalBlock in the Decoder performs temporal upsampling.
706
- load_mode (`str`, *optional*, defaults to `full`): Load mode for the model. Can be one of `full`, `encoder_only`, `decoder_only`. which corresponds to loading the full model state dicts, only the encoder state dicts, or only the decoder state dicts.
707
- """
708
-
709
- _supports_gradient_checkpointing = True
710
-
711
- @register_to_config
712
- def __init__(
713
- self,
714
- in_channels: int = 3,
715
- out_channels: int = 3,
716
- down_block_num: int = 4,
717
- up_block_num: int = 4,
718
- block_out_channels: Tuple[int] = (128,256,512,512),
719
- layers_per_block: int = 2,
720
- act_fn: str = "silu",
721
- latent_channels: int = 4,
722
- norm_num_groups: int = 32,
723
- sample_size: int = 320,
724
- tile_overlap: tuple = (120, 80),
725
- force_upcast: bool = True,
726
- chunk_len: int = 24,
727
- t_over: int = 8,
728
- scale_factor: float = 0.13235,
729
- blocks_tempdown_li=[True, True, False, False],
730
- blocks_tempup_li=[False, True, True, False],
731
- load_mode = 'full',
732
- ):
733
- super().__init__()
734
-
735
- self.blocks_tempdown_li = blocks_tempdown_li
736
- self.blocks_tempup_li = blocks_tempup_li
737
- # pass init params to Encoder
738
- self.load_mode = load_mode
739
- if load_mode in ['full', 'encoder_only']:
740
- self.encoder = Encoder3D(
741
- in_channels=in_channels,
742
- out_channels=latent_channels,
743
- num_blocks=down_block_num,
744
- blocks_temp_li=blocks_tempdown_li,
745
- block_out_channels=block_out_channels,
746
- layers_per_block=layers_per_block,
747
- act_fn=act_fn,
748
- norm_num_groups=norm_num_groups,
749
- double_z=True,
750
- )
751
- self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
752
-
753
- if load_mode in ['full', 'decoder_only']:
754
- # pass init params to Decoder
755
- self.decoder = Decoder3D(
756
- in_channels=latent_channels,
757
- out_channels=out_channels,
758
- num_blocks=up_block_num,
759
- blocks_temp_li=blocks_tempup_li,
760
- block_out_channels=block_out_channels,
761
- layers_per_block=layers_per_block,
762
- norm_num_groups=norm_num_groups,
763
- act_fn=act_fn,
764
- )
765
- self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
766
-
767
-
768
- # only relevant if vae tiling is enabled
769
- sample_size = (
770
- sample_size[0]
771
- if isinstance(sample_size, (list, tuple))
772
- else sample_size
773
- )
774
- self.tile_overlap = tile_overlap
775
- self.vae_scale_factor=[4, 8, 8]
776
- self.scale_factor = scale_factor
777
- self.sample_size = sample_size
778
- self.chunk_len = chunk_len
779
- self.t_over = t_over
780
-
781
- self.latent_chunk_len = self.chunk_len//4
782
- self.latent_t_over = self.t_over//4
783
- self.kernel = (self.chunk_len, self.sample_size, self.sample_size) #(24, 256, 256)
784
- self.stride = (self.chunk_len - self.t_over, self.sample_size-self.tile_overlap[0], self.sample_size-self.tile_overlap[1]) # (16, 112, 192)
785
-
786
-
787
- def encode(self, input_imgs: torch.Tensor, return_dict: bool = True, local_batch_size=1) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
788
- KERNEL = self.kernel
789
- STRIDE = self.stride
790
- LOCAL_BS = local_batch_size
791
- OUT_C = 8
792
-
793
- B, C, N, H, W = input_imgs.shape
794
-
795
-
796
- out_n = math.floor((N - KERNEL[0]) / STRIDE[0]) + 1
797
- out_h = math.floor((H - KERNEL[1]) / STRIDE[1]) + 1
798
- out_w = math.floor((W - KERNEL[2]) / STRIDE[2]) + 1
799
-
800
- ## cut video into overlapped small cubes and batch forward
801
- num = 0
802
-
803
- out_latent = torch.zeros((out_n*out_h*out_w, OUT_C, KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8), device=input_imgs.device, dtype=input_imgs.dtype)
804
- vae_batch_input = torch.zeros((LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_imgs.device, dtype=input_imgs.dtype)
805
-
806
- for i in range(out_n):
807
- for j in range(out_h):
808
- for k in range(out_w):
809
- n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0]
810
- h_start, h_end = j * STRIDE[1], j * STRIDE[1] + KERNEL[1]
811
- w_start, w_end = k * STRIDE[2], k * STRIDE[2] + KERNEL[2]
812
- video_cube = input_imgs[:, :, n_start:n_end, h_start:h_end, w_start:w_end]
813
- vae_batch_input[num%LOCAL_BS] = video_cube
814
-
815
- if num%LOCAL_BS == LOCAL_BS-1 or num == out_n*out_h*out_w-1:
816
- latent = self.encoder(vae_batch_input)
817
-
818
- if num == out_n*out_h*out_w-1 and num%LOCAL_BS != LOCAL_BS-1:
819
- out_latent[num-num%LOCAL_BS:] = latent[:num%LOCAL_BS+1]
820
- else:
821
- out_latent[num-LOCAL_BS+1:num+1] = latent
822
- vae_batch_input = torch.zeros((LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_imgs.device, dtype=input_imgs.dtype)
823
- num+=1
824
-
825
- ## flatten the batched out latent to videos and supress the overlapped parts
826
- B, C, N, H, W = input_imgs.shape
827
-
828
- out_video_cube = torch.zeros((B, OUT_C, N//4, H//8, W//8), device=input_imgs.device, dtype=input_imgs.dtype)
829
- OUT_KERNEL = KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8
830
- OUT_STRIDE = STRIDE[0]//4, STRIDE[1]//8, STRIDE[2]//8
831
- OVERLAP = OUT_KERNEL[0]-OUT_STRIDE[0], OUT_KERNEL[1]-OUT_STRIDE[1], OUT_KERNEL[2]-OUT_STRIDE[2]
832
-
833
- for i in range(out_n):
834
- n_start, n_end = i * OUT_STRIDE[0], i * OUT_STRIDE[0] + OUT_KERNEL[0]
835
- for j in range(out_h):
836
- h_start, h_end = j * OUT_STRIDE[1], j * OUT_STRIDE[1] + OUT_KERNEL[1]
837
- for k in range(out_w):
838
- w_start, w_end = k * OUT_STRIDE[2], k * OUT_STRIDE[2] + OUT_KERNEL[2]
839
- latent_mean_blend = prepare_for_blend((i, out_n, OVERLAP[0]), (j, out_h, OVERLAP[1]), (k, out_w, OVERLAP[2]), out_latent[i*out_h*out_w+j*out_w+k].unsqueeze(0))
840
- out_video_cube[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean_blend
841
-
842
- ## final conv
843
- out_video_cube = rearrange(out_video_cube, 'b c n h w -> (b n) c h w')
844
- out_video_cube = self.quant_conv(out_video_cube)
845
- out_video_cube = rearrange(out_video_cube, '(b n) c h w -> b c n h w', b=B)
846
-
847
- posterior = DiagonalGaussianDistribution(out_video_cube)
848
-
849
- if not return_dict:
850
- return (posterior,)
851
-
852
- return AutoencoderKLOutput(latent_dist=posterior)
853
-
854
-
855
- def decode(self, input_latents: torch.Tensor, return_dict: bool = True, local_batch_size=1) -> Union[DecoderOutput, torch.Tensor]:
856
- KERNEL = self.kernel
857
- STRIDE = self.stride
858
-
859
- LOCAL_BS = local_batch_size
860
- OUT_C = 3
861
- IN_KERNEL = KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8
862
- IN_STRIDE = STRIDE[0]//4, STRIDE[1]//8, STRIDE[2]//8
863
-
864
- B, C, N, H, W = input_latents.shape
865
-
866
- ## post quant conv (a mapping)
867
- input_latents = rearrange(input_latents, 'b c n h w -> (b n) c h w')
868
- input_latents = self.post_quant_conv(input_latents)
869
- input_latents = rearrange(input_latents, '(b n) c h w -> b c n h w', b=B)
870
-
871
- ## out tensor shape
872
- out_n = math.floor((N - IN_KERNEL[0]) / IN_STRIDE[0]) + 1
873
- out_h = math.floor((H - IN_KERNEL[1]) / IN_STRIDE[1]) + 1
874
- out_w = math.floor((W - IN_KERNEL[2]) / IN_STRIDE[2]) + 1
875
-
876
- ## cut latent into overlapped small cubes and batch forward
877
- num = 0
878
- decoded_cube = torch.zeros((out_n*out_h*out_w, OUT_C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype)
879
- vae_batch_input = torch.zeros((LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype)
880
- for i in range(out_n):
881
- for j in range(out_h):
882
- for k in range(out_w):
883
- n_start, n_end = i * IN_STRIDE[0], i * IN_STRIDE[0] + IN_KERNEL[0]
884
- h_start, h_end = j * IN_STRIDE[1], j * IN_STRIDE[1] + IN_KERNEL[1]
885
- w_start, w_end = k * IN_STRIDE[2], k * IN_STRIDE[2] + IN_KERNEL[2]
886
- latent_cube = input_latents[:, :, n_start:n_end, h_start:h_end, w_start:w_end]
887
- vae_batch_input[num%LOCAL_BS] = latent_cube
888
- if num%LOCAL_BS == LOCAL_BS-1 or num == out_n*out_h*out_w-1:
889
-
890
- latent = self.decoder(vae_batch_input)
891
-
892
- if num == out_n*out_h*out_w-1 and num%LOCAL_BS != LOCAL_BS-1:
893
- decoded_cube[num-num%LOCAL_BS:] = latent[:num%LOCAL_BS+1]
894
- else:
895
- decoded_cube[num-LOCAL_BS+1:num+1] = latent
896
- vae_batch_input = torch.zeros((LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype)
897
- num+=1
898
- B, C, N, H, W = input_latents.shape
899
-
900
- out_video = torch.zeros((B, OUT_C, N*4, H*8, W*8), device=input_latents.device, dtype=input_latents.dtype)
901
- OVERLAP = KERNEL[0]-STRIDE[0], KERNEL[1]-STRIDE[1], KERNEL[2]-STRIDE[2]
902
- for i in range(out_n):
903
- n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0]
904
- for j in range(out_h):
905
- h_start, h_end = j * STRIDE[1], j * STRIDE[1] + KERNEL[1]
906
- for k in range(out_w):
907
- w_start, w_end = k * STRIDE[2], k * STRIDE[2] + KERNEL[2]
908
- out_video_blend = prepare_for_blend((i, out_n, OVERLAP[0]), (j, out_h, OVERLAP[1]), (k, out_w, OVERLAP[2]), decoded_cube[i*out_h*out_w+j*out_w+k].unsqueeze(0))
909
- out_video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend
910
-
911
- out_video = rearrange(out_video, 'b c t h w -> b t c h w').contiguous()
912
-
913
- decoded = out_video
914
- if not return_dict:
915
- return (decoded,)
916
-
917
- return DecoderOutput(sample=decoded)
918
-
919
- def forward(
920
- self,
921
- sample: torch.Tensor,
922
- sample_posterior: bool = False,
923
- return_dict: bool = True,
924
- generator: Optional[torch.Generator] = None,
925
- encoder_local_batch_size: int = 2,
926
- decoder_local_batch_size: int = 2,
927
- ) -> Union[DecoderOutput, torch.Tensor]:
928
- r"""
929
- Args:
930
- sample (`torch.Tensor`): Input sample.
931
- sample_posterior (`bool`, *optional*, defaults to `False`):
932
- Whether to sample from the posterior.
933
- return_dict (`bool`, *optional*, defaults to `True`):
934
- Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
935
- generator (`torch.Generator`, *optional*):
936
- PyTorch random number generator.
937
- encoder_local_batch_size (`int`, *optional*, defaults to 2):
938
- Local batch size for the encoder's batch inference.
939
- decoder_local_batch_size (`int`, *optional*, defaults to 2):
940
- Local batch size for the decoder's batch inference.
941
- """
942
- x = sample
943
- posterior = self.encode(x, local_batch_size=encoder_local_batch_size).latent_dist
944
- if sample_posterior:
945
- z = posterior.sample(generator=generator)
946
- else:
947
- z = posterior.mode()
948
- dec = self.decode(z, local_batch_size=decoder_local_batch_size).sample
949
-
950
- if not return_dict:
951
- return (dec,)
952
-
953
- return DecoderOutput(sample=dec)
954
-
955
- @classmethod
956
- def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
957
- kwargs["torch_type"] = torch.float32
958
- return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
959
-
960
-
961
- def prepare_for_blend(n_param, h_param, w_param, x):
962
- n, n_max, overlap_n = n_param
963
- h, h_max, overlap_h = h_param
964
- w, w_max, overlap_w = w_param
965
- if overlap_n > 0:
966
- if n > 0: # the head overlap part decays from 0 to 1
967
- x[:,:,0:overlap_n,:,:] = x[:,:,0:overlap_n,:,:] * (torch.arange(0, overlap_n).float().to(x.device) / overlap_n).reshape(overlap_n,1,1)
968
- if n < n_max-1: # the tail overlap part decays from 1 to 0
969
- x[:,:,-overlap_n:,:,:] = x[:,:,-overlap_n:,:,:] * (1 - torch.arange(0, overlap_n).float().to(x.device) / overlap_n).reshape(overlap_n,1,1)
970
- if h > 0:
971
- x[:,:,:,0:overlap_h,:] = x[:,:,:,0:overlap_h,:] * (torch.arange(0, overlap_h).float().to(x.device) / overlap_h).reshape(overlap_h,1)
972
- if h < h_max-1:
973
- x[:,:,:,-overlap_h:,:] = x[:,:,:,-overlap_h:,:] * (1 - torch.arange(0, overlap_h).float().to(x.device) / overlap_h).reshape(overlap_h,1)
974
- if w > 0:
975
- x[:,:,:,:,0:overlap_w] = x[:,:,:,:,0:overlap_w] * (torch.arange(0, overlap_w).float().to(x.device) / overlap_w)
976
- if w < w_max-1:
977
- x[:,:,:,:,-overlap_w:] = x[:,:,:,:,-overlap_w:] * (1 - torch.arange(0, overlap_w).float().to(x.device) / overlap_w)
978
- return x