LarryTsai commited on
Commit
547b60e
1 Parent(s): 3f99bb7

Upload folder using huggingface_hub

Browse files
model_index.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline_allegro",
4
+ "AllegroPipeline"
5
+ ],
6
+ "_diffusers_version": "0.28.0",
7
+ "scheduler": [
8
+ "diffusers",
9
+ "EulerAncestralDiscreteScheduler"
10
+ ],
11
+ "text_encoder": [
12
+ "transformers",
13
+ "T5EncoderModel"
14
+ ],
15
+ "tokenizer": [
16
+ "transformers",
17
+ "T5Tokenizer"
18
+ ],
19
+ "transformer": [
20
+ "transformer_3d_allegro",
21
+ "AllegroTransformer3DModel"
22
+ ],
23
+ "vae": [
24
+ "vae_allegro",
25
+ "AllegroAutoencoderKL3D"
26
+ ]
27
+ }
pipeline_allegro.py ADDED
@@ -0,0 +1,832 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from Open-Sora-Plan
2
+
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # --------------------------------------------------------
6
+ # References:
7
+ # Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
8
+ # --------------------------------------------------------
9
+
10
+ import html
11
+ import inspect
12
+ import math
13
+ import re
14
+ import urllib.parse as ul
15
+ from typing import Callable, List, Optional, Tuple, Union
16
+ from einops import rearrange
17
+ import ftfy
18
+ import torch
19
+ from dataclasses import dataclass
20
+ import tqdm
21
+ from bs4 import BeautifulSoup
22
+
23
+ from diffusers import DiffusionPipeline, ModelMixin
24
+ from diffusers.schedulers import EulerAncestralDiscreteScheduler
25
+ from diffusers.utils import (
26
+ BACKENDS_MAPPING,
27
+ is_bs4_available,
28
+ is_ftfy_available,
29
+ logging,
30
+ replace_example_docstring,
31
+ BaseOutput
32
+ )
33
+ from diffusers.utils.torch_utils import randn_tensor
34
+ from transformers import T5EncoderModel, T5Tokenizer
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+ # from transformer_3d_allegro import AllegroTransformer3DModel
39
+ # from vae_allegro import AllegroAutoencoderKL3D
40
+ @dataclass
41
+ class AllegroPipelineOutput(BaseOutput):
42
+ r"""
43
+ Output class for Allegro pipelines.
44
+
45
+ Args:
46
+ video (`torch.Tensor`):
47
+ Torch tensor with shape `(batch_size, num_frames, channels, height, width)`.
48
+ """
49
+ video: torch.Tensor
50
+
51
+
52
+ EXAMPLE_DOC_STRING = """
53
+ Examples:
54
+ ```py
55
+ >>> import torch
56
+
57
+ >>> # You can replace the your_path_to_model with your own path.
58
+ >>> pipe = AllegroPipeline.from_pretrained(your_path_to_model, torch_dtype=torch.float16, trust_remote_code=True)
59
+
60
+ >>> prompt = "A small cactus with a happy face in the Sahara desert."
61
+ >>> image = pipe(prompt).video[0]
62
+ ```
63
+ """
64
+
65
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
66
+ def retrieve_timesteps(
67
+ scheduler,
68
+ num_inference_steps: Optional[int] = None,
69
+ device: Optional[Union[str, torch.device]] = None,
70
+ timesteps: Optional[List[int]] = None,
71
+ **kwargs,
72
+ ):
73
+ """
74
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
75
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
76
+
77
+ Args:
78
+ scheduler (`SchedulerMixin`):
79
+ The scheduler to get timesteps from.
80
+ num_inference_steps (`int`):
81
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
82
+ must be `None`.
83
+ device (`str` or `torch.device`, *optional*):
84
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
85
+ timesteps (`List[int]`, *optional*):
86
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
87
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
88
+ must be `None`.
89
+
90
+ Returns:
91
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
92
+ second element is the number of inference steps.
93
+ """
94
+ if timesteps is not None:
95
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
96
+ if not accepts_timesteps:
97
+ raise ValueError(
98
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
99
+ f" timestep schedules. Please check whether you are using the correct scheduler."
100
+ )
101
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
102
+ timesteps = scheduler.timesteps
103
+ num_inference_steps = len(timesteps)
104
+ else:
105
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
106
+ timesteps = scheduler.timesteps
107
+ return timesteps, num_inference_steps
108
+
109
+
110
+ class AllegroPipeline(DiffusionPipeline):
111
+ r"""
112
+ Pipeline for text-to-image generation using Allegro.
113
+
114
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
115
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
116
+
117
+ Args:
118
+ vae ([`AllegroAutoEncoderKL3D`]):
119
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
120
+ text_encoder ([`T5EncoderModel`]):
121
+ Frozen text-encoder. PixArt-Alpha uses
122
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
123
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
124
+ tokenizer (`T5Tokenizer`):
125
+ Tokenizer of class
126
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
127
+ transformer ([`AllegroTransformer3DModel`]):
128
+ A text conditioned `AllegroTransformer3DModel` to denoise the encoded image latents.
129
+ scheduler ([`SchedulerMixin`]):
130
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
131
+ """
132
+ bad_punct_regex = re.compile(
133
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
134
+ ) # noqa
135
+
136
+ _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"]
137
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
138
+
139
+ def __init__(
140
+ self,
141
+ tokenizer: Optional[T5Tokenizer] = None,
142
+ text_encoder: Optional[T5EncoderModel] = None,
143
+ vae: Optional[ModelMixin] = None,
144
+ transformer: Optional[ModelMixin] = None,
145
+ scheduler: Optional[EulerAncestralDiscreteScheduler] = None,
146
+ device: torch.device = torch.device("cuda"),
147
+ dtype: torch.dtype = torch.float16,
148
+ ):
149
+ super().__init__()
150
+ # # init
151
+ # if tokenizer is None:
152
+ # tokenizer = T5Tokenizer.from_pretrained(tokenizer)
153
+ # if text_encoder is None:
154
+ # text_encoder = T5EncoderModel.from_pretrained(text_encoder, torch_dtype=torch.float16)
155
+ # if vae is None:
156
+ # vae = AllegroAutoencoderKL3D.from_pretrained(vae).to(dtype=torch.float32)
157
+ # if transformer is None:
158
+ # transformer = AllegroTransformer3DModel.from_pretrained(transformer, torch_dtype=dtype)
159
+ # if scheduler is None:
160
+ # scheduler = EulerAncestralDiscreteScheduler()
161
+ self.register_modules(
162
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
163
+ )
164
+
165
+
166
+ # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
167
+ def encode_prompt(
168
+ self,
169
+ prompt: Union[str, List[str]],
170
+ do_classifier_free_guidance: bool = True,
171
+ negative_prompt: str = "",
172
+ num_images_per_prompt: int = 1,
173
+ device: Optional[torch.device] = None,
174
+ prompt_embeds: Optional[torch.FloatTensor] = None,
175
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
176
+ prompt_attention_mask: Optional[torch.FloatTensor] = None,
177
+ negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
178
+ clean_caption: bool = False,
179
+ max_sequence_length: int = 120,
180
+ **kwargs,
181
+ ):
182
+ r"""
183
+ Encodes the prompt into text encoder hidden states.
184
+
185
+ Args:
186
+ prompt (`str` or `List[str]`, *optional*):
187
+ prompt to be encoded
188
+ negative_prompt (`str` or `List[str]`, *optional*):
189
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
190
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
191
+ PixArt-Alpha, this should be "".
192
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
193
+ whether to use classifier free guidance or not
194
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
195
+ number of images that should be generated per prompt
196
+ device: (`torch.device`, *optional*):
197
+ torch device to place the resulting embeddings on
198
+ prompt_embeds (`torch.FloatTensor`, *optional*):
199
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
200
+ provided, text embeddings will be generated from `prompt` input argument.
201
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
202
+ Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
203
+ string.
204
+ clean_caption (`bool`, defaults to `False`):
205
+ If `True`, the function will preprocess and clean the provided caption before encoding.
206
+ max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
207
+ """
208
+ embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
209
+
210
+ if device is None:
211
+ device = self._execution_device
212
+
213
+ if prompt is not None and isinstance(prompt, str):
214
+ batch_size = 1
215
+ elif prompt is not None and isinstance(prompt, list):
216
+ batch_size = len(prompt)
217
+ else:
218
+ batch_size = prompt_embeds.shape[0]
219
+
220
+ # See Section 3.1. of the paper.
221
+ max_length = max_sequence_length
222
+
223
+ if prompt_embeds is None:
224
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
225
+ text_inputs = self.tokenizer(
226
+ prompt,
227
+ padding="max_length",
228
+ max_length=max_length,
229
+ truncation=True,
230
+ add_special_tokens=True,
231
+ return_tensors="pt",
232
+ )
233
+ text_input_ids = text_inputs.input_ids
234
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
235
+
236
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
237
+ text_input_ids, untruncated_ids
238
+ ):
239
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
240
+ logger.warning(
241
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
242
+ f" {max_length} tokens: {removed_text}"
243
+ )
244
+
245
+ prompt_attention_mask = text_inputs.attention_mask
246
+ prompt_attention_mask = prompt_attention_mask.to(device)
247
+
248
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
249
+ prompt_embeds = prompt_embeds[0]
250
+
251
+ if self.text_encoder is not None:
252
+ dtype = self.text_encoder.dtype
253
+ elif self.transformer is not None:
254
+ dtype = self.transformer.dtype
255
+ else:
256
+ dtype = None
257
+
258
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
259
+
260
+ bs_embed, seq_len, _ = prompt_embeds.shape
261
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
262
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
263
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
264
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
265
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
266
+
267
+ # get unconditional embeddings for classifier free guidance
268
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
269
+ uncond_tokens = [negative_prompt] * batch_size
270
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
271
+ max_length = prompt_embeds.shape[1]
272
+ uncond_input = self.tokenizer(
273
+ uncond_tokens,
274
+ padding="max_length",
275
+ max_length=max_length,
276
+ truncation=True,
277
+ return_attention_mask=True,
278
+ add_special_tokens=True,
279
+ return_tensors="pt",
280
+ )
281
+ negative_prompt_attention_mask = uncond_input.attention_mask
282
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
283
+
284
+ negative_prompt_embeds = self.text_encoder(
285
+ uncond_input.input_ids.to(device),
286
+ attention_mask=negative_prompt_attention_mask,
287
+ )
288
+ negative_prompt_embeds = negative_prompt_embeds[0]
289
+
290
+ if do_classifier_free_guidance:
291
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
292
+ seq_len = negative_prompt_embeds.shape[1]
293
+
294
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
295
+
296
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
297
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
298
+
299
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
300
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
301
+ else:
302
+ negative_prompt_embeds = None
303
+ negative_prompt_attention_mask = None
304
+
305
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
306
+
307
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
308
+ def prepare_extra_step_kwargs(self, generator, eta):
309
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
310
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
311
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
312
+ # and should be between [0, 1]
313
+
314
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
315
+ extra_step_kwargs = {}
316
+ if accepts_eta:
317
+ extra_step_kwargs["eta"] = eta
318
+
319
+ # check if the scheduler accepts generator
320
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
321
+ if accepts_generator:
322
+ extra_step_kwargs["generator"] = generator
323
+ return extra_step_kwargs
324
+
325
+ def check_inputs(
326
+ self,
327
+ prompt,
328
+ num_frames,
329
+ height,
330
+ width,
331
+ negative_prompt,
332
+ callback_steps,
333
+ prompt_embeds=None,
334
+ negative_prompt_embeds=None,
335
+ prompt_attention_mask=None,
336
+ negative_prompt_attention_mask=None,
337
+ ):
338
+
339
+ if num_frames <= 0:
340
+ raise ValueError(f"`num_frames` have to be positive but is {num_frames}.")
341
+ if height % 8 != 0 or width % 8 != 0:
342
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
343
+
344
+ if (callback_steps is None) or (
345
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
346
+ ):
347
+ raise ValueError(
348
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
349
+ f" {type(callback_steps)}."
350
+ )
351
+
352
+ if prompt is not None and prompt_embeds is not None:
353
+ raise ValueError(
354
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
355
+ " only forward one of the two."
356
+ )
357
+ elif prompt is None and prompt_embeds is None:
358
+ raise ValueError(
359
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
360
+ )
361
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
362
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
363
+
364
+ if prompt is not None and negative_prompt_embeds is not None:
365
+ raise ValueError(
366
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
367
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
368
+ )
369
+
370
+ if negative_prompt is not None and negative_prompt_embeds is not None:
371
+ raise ValueError(
372
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
373
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
374
+ )
375
+
376
+ if prompt_embeds is not None and prompt_attention_mask is None:
377
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
378
+
379
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
380
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
381
+
382
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
383
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
384
+ raise ValueError(
385
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
386
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
387
+ f" {negative_prompt_embeds.shape}."
388
+ )
389
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
390
+ raise ValueError(
391
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
392
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
393
+ f" {negative_prompt_attention_mask.shape}."
394
+ )
395
+
396
+
397
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
398
+ def _text_preprocessing(self, text, clean_caption=False):
399
+ if clean_caption and not is_bs4_available():
400
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
401
+ logger.warning("Setting `clean_caption` to False...")
402
+ clean_caption = False
403
+
404
+ if clean_caption and not is_ftfy_available():
405
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
406
+ logger.warning("Setting `clean_caption` to False...")
407
+ clean_caption = False
408
+
409
+ if not isinstance(text, (tuple, list)):
410
+ text = [text]
411
+
412
+ def process(text: str):
413
+ if clean_caption:
414
+ text = self._clean_caption(text)
415
+ text = self._clean_caption(text)
416
+ else:
417
+ text = text.lower().strip()
418
+ return text
419
+
420
+ return [process(t) for t in text]
421
+
422
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
423
+ def _clean_caption(self, caption):
424
+ caption = str(caption)
425
+ caption = ul.unquote_plus(caption)
426
+ caption = caption.strip().lower()
427
+ caption = re.sub("<person>", "person", caption)
428
+ # urls:
429
+ caption = re.sub(
430
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",
431
+ # noqa
432
+ "",
433
+ caption,
434
+ ) # regex for urls
435
+ caption = re.sub(
436
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",
437
+ # noqa
438
+ "",
439
+ caption,
440
+ ) # regex for urls
441
+ # html:
442
+ caption = BeautifulSoup(caption, features="html.parser").text
443
+
444
+ # @<nickname>
445
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
446
+
447
+ # 31C0—31EF CJK Strokes
448
+ # 31F0—31FF Katakana Phonetic Extensions
449
+ # 3200—32FF Enclosed CJK Letters and Months
450
+ # 3300—33FF CJK Compatibility
451
+ # 3400—4DBF CJK Unified Ideographs Extension A
452
+ # 4DC0—4DFF Yijing Hexagram Symbols
453
+ # 4E00—9FFF CJK Unified Ideographs
454
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
455
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
456
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
457
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
458
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
459
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
460
+ # caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
461
+ #######################################################
462
+
463
+ # все виды тире / all types of dash --> "-"
464
+ caption = re.sub(
465
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",
466
+ # noqa
467
+ "-",
468
+ caption,
469
+ )
470
+
471
+ # кавычки к одному стандарту
472
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
473
+ caption = re.sub(r"[‘’]", "'", caption)
474
+
475
+ # &quot;
476
+ caption = re.sub(r"&quot;?", "", caption)
477
+ # &amp
478
+ caption = re.sub(r"&amp", "", caption)
479
+
480
+ # ip adresses:
481
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
482
+
483
+ # article ids:
484
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
485
+
486
+ # \n
487
+ caption = re.sub(r"\\n", " ", caption)
488
+
489
+ # "#123"
490
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
491
+ # "#12345.."
492
+ caption = re.sub(r"#\d{5,}\b", "", caption)
493
+ # "123456.."
494
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
495
+ # filenames:
496
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
497
+
498
+ #
499
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
500
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
501
+
502
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
503
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
504
+
505
+ # this-is-my-cute-cat / this_is_my_cute_cat
506
+ regex2 = re.compile(r"(?:\-|\_)")
507
+ if len(re.findall(regex2, caption)) > 3:
508
+ caption = re.sub(regex2, " ", caption)
509
+
510
+ caption = ftfy.fix_text(caption)
511
+ caption = html.unescape(html.unescape(caption))
512
+
513
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
514
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
515
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
516
+
517
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
518
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
519
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
520
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
521
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
522
+
523
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
524
+
525
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
526
+
527
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
528
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
529
+ caption = re.sub(r"\s+", " ", caption)
530
+
531
+ caption.strip()
532
+
533
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
534
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
535
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
536
+ caption = re.sub(r"^\.\S+$", "", caption)
537
+ return caption.strip()
538
+
539
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
540
+ def prepare_latents(
541
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
542
+ ):
543
+ shape = (
544
+ batch_size,
545
+ num_channels_latents,
546
+ (math.ceil((int(num_frames) - 1) / self.vae.vae_scale_factor[0]) + 1)
547
+ if int(num_frames) % 2 == 1
548
+ else math.ceil(int(num_frames) / self.vae.vae_scale_factor[0]),
549
+ math.ceil(int(height) / self.vae.vae_scale_factor[1]),
550
+ math.ceil(int(width) / self.vae.vae_scale_factor[2]),
551
+ )
552
+ if isinstance(generator, list) and len(generator) != batch_size:
553
+ raise ValueError(
554
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
555
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
556
+ )
557
+
558
+ if latents is None:
559
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
560
+ else:
561
+ latents = latents.to(device)
562
+
563
+ # scale the initial noise by the standard deviation required by the scheduler
564
+ latents = latents * self.scheduler.init_noise_sigma
565
+
566
+
567
+ return latents
568
+
569
+ @torch.no_grad()
570
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
571
+ def __call__(
572
+ self,
573
+ prompt: Union[str, List[str]] = None,
574
+ negative_prompt: str = "",
575
+ num_inference_steps: int = 100,
576
+ timesteps: List[int] = None,
577
+ guidance_scale: float = 7.5,
578
+ num_images_per_prompt: Optional[int] = 1,
579
+ num_frames: Optional[int] = None,
580
+ height: Optional[int] = None,
581
+ width: Optional[int] = None,
582
+ eta: float = 0.0,
583
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
584
+ latents: Optional[torch.FloatTensor] = None,
585
+ prompt_embeds: Optional[torch.FloatTensor] = None,
586
+ prompt_attention_mask: Optional[torch.FloatTensor] = None,
587
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
588
+ negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
589
+ output_type: Optional[str] = "pil",
590
+ return_dict: bool = True,
591
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
592
+ callback_steps: int = 1,
593
+ clean_caption: bool = True,
594
+ max_sequence_length: int = 512,
595
+ verbose: bool = True,
596
+ ) -> Union[AllegroPipelineOutput, Tuple]:
597
+ """
598
+ Function invoked when calling the pipeline for generation.
599
+
600
+ Args:
601
+ prompt (`str` or `List[str]`, *optional*):
602
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
603
+ instead.
604
+ negative_prompt (`str` or `List[str]`, *optional*):
605
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
606
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
607
+ less than `1`).
608
+ num_inference_steps (`int`, *optional*, defaults to 100):
609
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
610
+ expense of slower inference.
611
+ timesteps (`List[int]`, *optional*):
612
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
613
+ timesteps are used. Must be in descending order.
614
+ guidance_scale (`float`, *optional*, defaults to 7.0):
615
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
616
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
617
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
618
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
619
+ usually at the expense of lower image quality.
620
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
621
+ The number of images to generate per prompt.
622
+ num_frames: (`int`, *optional*, defaults to 88):
623
+ The number controls the generated video frames.
624
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
625
+ The height in pixels of the generated image.
626
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
627
+ The width in pixels of the generated image.
628
+ eta (`float`, *optional*, defaults to 0.0):
629
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
630
+ [`schedulers.DDIMScheduler`], will be ignored for others.
631
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
632
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
633
+ to make generation deterministic.
634
+ latents (`torch.FloatTensor`, *optional*):
635
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
636
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
637
+ tensor will ge generated by sampling using the supplied random `generator`.
638
+ prompt_embeds (`torch.FloatTensor`, *optional*):
639
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
640
+ provided, text embeddings will be generated from `prompt` input argument.
641
+ prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
642
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
643
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
644
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
645
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
646
+ Pre-generated attention mask for negative text embeddings.
647
+ output_type (`str`, *optional*, defaults to `"pil"`):
648
+ The output format of the generate image. Choose between
649
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
650
+ return_dict (`bool`, *optional*, defaults to `True`):
651
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
652
+ callback (`Callable`, *optional*):
653
+ A function that will be called every `callback_steps` steps during inference. The function will be
654
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
655
+ callback_steps (`int`, *optional*, defaults to 1):
656
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
657
+ called at every step.
658
+ clean_caption (`bool`, *optional*, defaults to `True`):
659
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
660
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
661
+ prompt.
662
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
663
+
664
+ Examples:
665
+
666
+ Returns:
667
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
668
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
669
+ returned where the first element is a list with the generated images
670
+ """
671
+ # 1. Check inputs. Raise error if not correct
672
+ num_frames = num_frames or self.transformer.config.sample_size_t * self.vae.vae_scale_factor[0]
673
+ height = height or self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1]
674
+ width = width or self.transformer.config.sample_size[1] * self.vae.vae_scale_factor[2]
675
+
676
+ self.check_inputs(
677
+ prompt,
678
+ num_frames,
679
+ height,
680
+ width,
681
+ negative_prompt,
682
+ callback_steps,
683
+ prompt_embeds,
684
+ negative_prompt_embeds,
685
+ prompt_attention_mask,
686
+ negative_prompt_attention_mask,
687
+ )
688
+
689
+ # 2. Default height and width to transformer
690
+ if prompt is not None and isinstance(prompt, str):
691
+ batch_size = 1
692
+ elif prompt is not None and isinstance(prompt, list):
693
+ batch_size = len(prompt)
694
+ else:
695
+ batch_size = prompt_embeds.shape[0]
696
+
697
+ device = self._execution_device
698
+
699
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
700
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
701
+ # corresponds to doing no classifier free guidance.
702
+ do_classifier_free_guidance = guidance_scale > 1.0
703
+
704
+ # 3. Encode input prompt
705
+ (
706
+ prompt_embeds,
707
+ prompt_attention_mask,
708
+ negative_prompt_embeds,
709
+ negative_prompt_attention_mask,
710
+ ) = self.encode_prompt(
711
+ prompt,
712
+ do_classifier_free_guidance,
713
+ negative_prompt=negative_prompt,
714
+ num_images_per_prompt=num_images_per_prompt,
715
+ device=device,
716
+ prompt_embeds=prompt_embeds,
717
+ negative_prompt_embeds=negative_prompt_embeds,
718
+ prompt_attention_mask=prompt_attention_mask,
719
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
720
+ clean_caption=clean_caption,
721
+ max_sequence_length=max_sequence_length,
722
+ )
723
+ if do_classifier_free_guidance:
724
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
725
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
726
+
727
+ # 4. Prepare timesteps
728
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
729
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
730
+
731
+ # 5. Prepare latents.
732
+ latent_channels = self.transformer.config.in_channels
733
+ latents = self.prepare_latents(
734
+ batch_size * num_images_per_prompt,
735
+ latent_channels,
736
+ num_frames,
737
+ height,
738
+ width,
739
+ prompt_embeds.dtype,
740
+ device,
741
+ generator,
742
+ latents,
743
+ )
744
+
745
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
746
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
747
+
748
+ # 6.1 Prepare micro-conditions.
749
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
750
+
751
+ # 7. Denoising loop
752
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
753
+
754
+ progress_wrap = tqdm.tqdm if verbose else (lambda x: x)
755
+ for i, t in progress_wrap(list(enumerate(timesteps))):
756
+
757
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
758
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
759
+
760
+ current_timestep = t
761
+ if not torch.is_tensor(current_timestep):
762
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
763
+ # This would be a good case for the `match` statement (Python 3.10+)
764
+ is_mps = latent_model_input.device.type == "mps"
765
+ if isinstance(current_timestep, float):
766
+ dtype = torch.float32 if is_mps else torch.float64
767
+ else:
768
+ dtype = torch.int32 if is_mps else torch.int64
769
+ current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
770
+ elif len(current_timestep.shape) == 0:
771
+ current_timestep = current_timestep[None].to(latent_model_input.device)
772
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
773
+ current_timestep = current_timestep.expand(latent_model_input.shape[0])
774
+
775
+ if prompt_embeds.ndim == 3:
776
+ prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d
777
+ if prompt_attention_mask.ndim == 2:
778
+ prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l
779
+ # prepare attention_mask.
780
+ # b c t h w -> b t h w
781
+ attention_mask = torch.ones_like(latent_model_input)[:, 0]
782
+ # predict noise model_output
783
+ noise_pred = self.transformer(
784
+ latent_model_input,
785
+ attention_mask=attention_mask,
786
+ encoder_hidden_states=prompt_embeds,
787
+ encoder_attention_mask=prompt_attention_mask,
788
+ timestep=current_timestep,
789
+ added_cond_kwargs=added_cond_kwargs,
790
+ return_dict=False,
791
+ )[0]
792
+
793
+ # perform guidance
794
+ if do_classifier_free_guidance:
795
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
796
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
797
+
798
+ # learned sigma
799
+ if self.transformer.config.out_channels // 2 == latent_channels:
800
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
801
+ else:
802
+ noise_pred = noise_pred
803
+
804
+ # compute previous image: x_t -> x_t-1
805
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
806
+
807
+ # call the callback, if provided
808
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
809
+ if callback is not None and i % callback_steps == 0:
810
+ step_idx = i // getattr(self.scheduler, "order", 1)
811
+ callback(step_idx, t, latents)
812
+
813
+ if not output_type == "latents":
814
+ video = self.decode_latents(latents)
815
+ video = video[:, :num_frames, :height, :width]
816
+ else:
817
+ video = latents
818
+ return AllegroPipelineOutput(video=video)
819
+
820
+ # Offload all models
821
+ self.maybe_free_model_hooks()
822
+
823
+ if not return_dict:
824
+ return (video,)
825
+
826
+ return AllegroPipelineOutput(video=video)
827
+
828
+ def decode_latents(self, latents):
829
+ video = self.vae.decode(latents.to(self.vae.dtype) / self.vae.scale_factor).sample
830
+ # b t c h w -> b t h w c
831
+ video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous()
832
+ return video
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "EulerAncestralDiscreteScheduler",
3
+ "_diffusers_version": "0.28.0",
4
+ "beta_end": 0.02,
5
+ "beta_schedule": "linear",
6
+ "beta_start": 0.0001,
7
+ "num_train_timesteps": 1000,
8
+ "prediction_type": "epsilon",
9
+ "rescale_betas_zero_snr": false,
10
+ "steps_offset": 0,
11
+ "timestep_spacing": "linspace",
12
+ "trained_betas": null
13
+ }
text_encoder/config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "T5EncoderModel"
4
+ ],
5
+ "d_ff": 10240,
6
+ "d_kv": 64,
7
+ "d_model": 4096,
8
+ "decoder_start_token_id": 0,
9
+ "dense_act_fn": "gelu_new",
10
+ "dropout_rate": 0.1,
11
+ "eos_token_id": 1,
12
+ "feed_forward_proj": "gated-gelu",
13
+ "initializer_factor": 1.0,
14
+ "is_encoder_decoder": true,
15
+ "is_gated_act": true,
16
+ "layer_norm_epsilon": 1e-06,
17
+ "model_type": "t5",
18
+ "num_decoder_layers": 24,
19
+ "num_heads": 64,
20
+ "num_layers": 24,
21
+ "output_past": true,
22
+ "pad_token_id": 0,
23
+ "relative_attention_max_distance": 128,
24
+ "relative_attention_num_buckets": 32,
25
+ "tie_word_embeddings": false,
26
+ "torch_dtype": "float32",
27
+ "transformers_version": "4.21.1",
28
+ "use_cache": true,
29
+ "vocab_size": 32128
30
+ }
text_encoder/pytorch_model-00001-of-00002.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f71ad0624095dae788b1023081dda1b4040bd24f7244a5b5b46eebc09825839
3
+ size 9452285635
text_encoder/pytorch_model-00002-of-00002.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f68f80678299ac59f69b3550ebd47b966571920d8f9e71f42ab61fabaaed868
3
+ size 9597031749
text_encoder/pytorch_model.bin.index.json ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 19575627776
4
+ },
5
+ "weight_map": {
6
+ "encoder.block.0.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
7
+ "encoder.block.0.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
8
+ "encoder.block.0.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
9
+ "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight": "pytorch_model-00001-of-00002.bin",
10
+ "encoder.block.0.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
11
+ "encoder.block.0.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
12
+ "encoder.block.0.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
13
+ "encoder.block.0.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
14
+ "encoder.block.0.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
15
+ "encoder.block.0.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
16
+ "encoder.block.1.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
17
+ "encoder.block.1.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
18
+ "encoder.block.1.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
19
+ "encoder.block.1.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
20
+ "encoder.block.1.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
21
+ "encoder.block.1.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
22
+ "encoder.block.1.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
23
+ "encoder.block.1.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
24
+ "encoder.block.1.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
25
+ "encoder.block.10.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
26
+ "encoder.block.10.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
27
+ "encoder.block.10.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
28
+ "encoder.block.10.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
29
+ "encoder.block.10.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
30
+ "encoder.block.10.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
31
+ "encoder.block.10.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
32
+ "encoder.block.10.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
33
+ "encoder.block.10.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
34
+ "encoder.block.11.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
35
+ "encoder.block.11.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
36
+ "encoder.block.11.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
37
+ "encoder.block.11.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
38
+ "encoder.block.11.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
39
+ "encoder.block.11.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
40
+ "encoder.block.11.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
41
+ "encoder.block.11.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
42
+ "encoder.block.11.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
43
+ "encoder.block.12.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
44
+ "encoder.block.12.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
45
+ "encoder.block.12.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
46
+ "encoder.block.12.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
47
+ "encoder.block.12.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
48
+ "encoder.block.12.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
49
+ "encoder.block.12.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
50
+ "encoder.block.12.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
51
+ "encoder.block.12.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
52
+ "encoder.block.13.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
53
+ "encoder.block.13.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
54
+ "encoder.block.13.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
55
+ "encoder.block.13.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
56
+ "encoder.block.13.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
57
+ "encoder.block.13.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
58
+ "encoder.block.13.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
59
+ "encoder.block.13.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
60
+ "encoder.block.13.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
61
+ "encoder.block.14.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
62
+ "encoder.block.14.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
63
+ "encoder.block.14.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
64
+ "encoder.block.14.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
65
+ "encoder.block.14.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
66
+ "encoder.block.14.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
67
+ "encoder.block.14.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
68
+ "encoder.block.14.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
69
+ "encoder.block.14.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
70
+ "encoder.block.15.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
71
+ "encoder.block.15.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
72
+ "encoder.block.15.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
73
+ "encoder.block.15.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
74
+ "encoder.block.15.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
75
+ "encoder.block.15.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
76
+ "encoder.block.15.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
77
+ "encoder.block.15.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
78
+ "encoder.block.15.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
79
+ "encoder.block.16.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
80
+ "encoder.block.16.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
81
+ "encoder.block.16.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
82
+ "encoder.block.16.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
83
+ "encoder.block.16.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
84
+ "encoder.block.16.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
85
+ "encoder.block.16.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
86
+ "encoder.block.16.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
87
+ "encoder.block.16.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
88
+ "encoder.block.17.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
89
+ "encoder.block.17.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
90
+ "encoder.block.17.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
91
+ "encoder.block.17.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
92
+ "encoder.block.17.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
93
+ "encoder.block.17.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
94
+ "encoder.block.17.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
95
+ "encoder.block.17.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
96
+ "encoder.block.17.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
97
+ "encoder.block.18.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
98
+ "encoder.block.18.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
99
+ "encoder.block.18.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
100
+ "encoder.block.18.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
101
+ "encoder.block.18.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
102
+ "encoder.block.18.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
103
+ "encoder.block.18.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
104
+ "encoder.block.18.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
105
+ "encoder.block.18.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
106
+ "encoder.block.19.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
107
+ "encoder.block.19.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
108
+ "encoder.block.19.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
109
+ "encoder.block.19.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
110
+ "encoder.block.19.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
111
+ "encoder.block.19.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
112
+ "encoder.block.19.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
113
+ "encoder.block.19.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
114
+ "encoder.block.19.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
115
+ "encoder.block.2.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
116
+ "encoder.block.2.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
117
+ "encoder.block.2.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
118
+ "encoder.block.2.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
119
+ "encoder.block.2.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
120
+ "encoder.block.2.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
121
+ "encoder.block.2.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
122
+ "encoder.block.2.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
123
+ "encoder.block.2.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
124
+ "encoder.block.20.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
125
+ "encoder.block.20.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
126
+ "encoder.block.20.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
127
+ "encoder.block.20.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
128
+ "encoder.block.20.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
129
+ "encoder.block.20.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
130
+ "encoder.block.20.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
131
+ "encoder.block.20.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
132
+ "encoder.block.20.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
133
+ "encoder.block.21.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
134
+ "encoder.block.21.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
135
+ "encoder.block.21.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
136
+ "encoder.block.21.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
137
+ "encoder.block.21.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
138
+ "encoder.block.21.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
139
+ "encoder.block.21.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
140
+ "encoder.block.21.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
141
+ "encoder.block.21.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
142
+ "encoder.block.22.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
143
+ "encoder.block.22.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
144
+ "encoder.block.22.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
145
+ "encoder.block.22.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
146
+ "encoder.block.22.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
147
+ "encoder.block.22.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
148
+ "encoder.block.22.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
149
+ "encoder.block.22.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
150
+ "encoder.block.22.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
151
+ "encoder.block.23.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
152
+ "encoder.block.23.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
153
+ "encoder.block.23.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
154
+ "encoder.block.23.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
155
+ "encoder.block.23.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
156
+ "encoder.block.23.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
157
+ "encoder.block.23.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
158
+ "encoder.block.23.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
159
+ "encoder.block.23.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
160
+ "encoder.block.3.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
161
+ "encoder.block.3.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
162
+ "encoder.block.3.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
163
+ "encoder.block.3.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
164
+ "encoder.block.3.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
165
+ "encoder.block.3.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
166
+ "encoder.block.3.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
167
+ "encoder.block.3.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
168
+ "encoder.block.3.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
169
+ "encoder.block.4.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
170
+ "encoder.block.4.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
171
+ "encoder.block.4.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
172
+ "encoder.block.4.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
173
+ "encoder.block.4.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
174
+ "encoder.block.4.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
175
+ "encoder.block.4.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
176
+ "encoder.block.4.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
177
+ "encoder.block.4.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
178
+ "encoder.block.5.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
179
+ "encoder.block.5.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
180
+ "encoder.block.5.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
181
+ "encoder.block.5.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
182
+ "encoder.block.5.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
183
+ "encoder.block.5.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
184
+ "encoder.block.5.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
185
+ "encoder.block.5.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
186
+ "encoder.block.5.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
187
+ "encoder.block.6.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
188
+ "encoder.block.6.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
189
+ "encoder.block.6.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
190
+ "encoder.block.6.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
191
+ "encoder.block.6.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
192
+ "encoder.block.6.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
193
+ "encoder.block.6.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
194
+ "encoder.block.6.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
195
+ "encoder.block.6.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
196
+ "encoder.block.7.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
197
+ "encoder.block.7.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
198
+ "encoder.block.7.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
199
+ "encoder.block.7.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
200
+ "encoder.block.7.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
201
+ "encoder.block.7.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
202
+ "encoder.block.7.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
203
+ "encoder.block.7.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
204
+ "encoder.block.7.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
205
+ "encoder.block.8.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
206
+ "encoder.block.8.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
207
+ "encoder.block.8.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
208
+ "encoder.block.8.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
209
+ "encoder.block.8.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
210
+ "encoder.block.8.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
211
+ "encoder.block.8.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
212
+ "encoder.block.8.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
213
+ "encoder.block.8.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
214
+ "encoder.block.9.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
215
+ "encoder.block.9.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
216
+ "encoder.block.9.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
217
+ "encoder.block.9.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
218
+ "encoder.block.9.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
219
+ "encoder.block.9.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
220
+ "encoder.block.9.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
221
+ "encoder.block.9.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
222
+ "encoder.block.9.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
223
+ "encoder.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
224
+ "encoder.final_layer_norm.weight": "pytorch_model-00002-of-00002.bin",
225
+ "shared.weight": "pytorch_model-00001-of-00002.bin"
226
+ }
227
+ }
tokenizer/added_tokens.json ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "<extra_id_0>": 32099,
3
+ "<extra_id_10>": 32089,
4
+ "<extra_id_11>": 32088,
5
+ "<extra_id_12>": 32087,
6
+ "<extra_id_13>": 32086,
7
+ "<extra_id_14>": 32085,
8
+ "<extra_id_15>": 32084,
9
+ "<extra_id_16>": 32083,
10
+ "<extra_id_17>": 32082,
11
+ "<extra_id_18>": 32081,
12
+ "<extra_id_19>": 32080,
13
+ "<extra_id_1>": 32098,
14
+ "<extra_id_20>": 32079,
15
+ "<extra_id_21>": 32078,
16
+ "<extra_id_22>": 32077,
17
+ "<extra_id_23>": 32076,
18
+ "<extra_id_24>": 32075,
19
+ "<extra_id_25>": 32074,
20
+ "<extra_id_26>": 32073,
21
+ "<extra_id_27>": 32072,
22
+ "<extra_id_28>": 32071,
23
+ "<extra_id_29>": 32070,
24
+ "<extra_id_2>": 32097,
25
+ "<extra_id_30>": 32069,
26
+ "<extra_id_31>": 32068,
27
+ "<extra_id_32>": 32067,
28
+ "<extra_id_33>": 32066,
29
+ "<extra_id_34>": 32065,
30
+ "<extra_id_35>": 32064,
31
+ "<extra_id_36>": 32063,
32
+ "<extra_id_37>": 32062,
33
+ "<extra_id_38>": 32061,
34
+ "<extra_id_39>": 32060,
35
+ "<extra_id_3>": 32096,
36
+ "<extra_id_40>": 32059,
37
+ "<extra_id_41>": 32058,
38
+ "<extra_id_42>": 32057,
39
+ "<extra_id_43>": 32056,
40
+ "<extra_id_44>": 32055,
41
+ "<extra_id_45>": 32054,
42
+ "<extra_id_46>": 32053,
43
+ "<extra_id_47>": 32052,
44
+ "<extra_id_48>": 32051,
45
+ "<extra_id_49>": 32050,
46
+ "<extra_id_4>": 32095,
47
+ "<extra_id_50>": 32049,
48
+ "<extra_id_51>": 32048,
49
+ "<extra_id_52>": 32047,
50
+ "<extra_id_53>": 32046,
51
+ "<extra_id_54>": 32045,
52
+ "<extra_id_55>": 32044,
53
+ "<extra_id_56>": 32043,
54
+ "<extra_id_57>": 32042,
55
+ "<extra_id_58>": 32041,
56
+ "<extra_id_59>": 32040,
57
+ "<extra_id_5>": 32094,
58
+ "<extra_id_60>": 32039,
59
+ "<extra_id_61>": 32038,
60
+ "<extra_id_62>": 32037,
61
+ "<extra_id_63>": 32036,
62
+ "<extra_id_64>": 32035,
63
+ "<extra_id_65>": 32034,
64
+ "<extra_id_66>": 32033,
65
+ "<extra_id_67>": 32032,
66
+ "<extra_id_68>": 32031,
67
+ "<extra_id_69>": 32030,
68
+ "<extra_id_6>": 32093,
69
+ "<extra_id_70>": 32029,
70
+ "<extra_id_71>": 32028,
71
+ "<extra_id_72>": 32027,
72
+ "<extra_id_73>": 32026,
73
+ "<extra_id_74>": 32025,
74
+ "<extra_id_75>": 32024,
75
+ "<extra_id_76>": 32023,
76
+ "<extra_id_77>": 32022,
77
+ "<extra_id_78>": 32021,
78
+ "<extra_id_79>": 32020,
79
+ "<extra_id_7>": 32092,
80
+ "<extra_id_80>": 32019,
81
+ "<extra_id_81>": 32018,
82
+ "<extra_id_82>": 32017,
83
+ "<extra_id_83>": 32016,
84
+ "<extra_id_84>": 32015,
85
+ "<extra_id_85>": 32014,
86
+ "<extra_id_86>": 32013,
87
+ "<extra_id_87>": 32012,
88
+ "<extra_id_88>": 32011,
89
+ "<extra_id_89>": 32010,
90
+ "<extra_id_8>": 32091,
91
+ "<extra_id_90>": 32009,
92
+ "<extra_id_91>": 32008,
93
+ "<extra_id_92>": 32007,
94
+ "<extra_id_93>": 32006,
95
+ "<extra_id_94>": 32005,
96
+ "<extra_id_95>": 32004,
97
+ "<extra_id_96>": 32003,
98
+ "<extra_id_97>": 32002,
99
+ "<extra_id_98>": 32001,
100
+ "<extra_id_99>": 32000,
101
+ "<extra_id_9>": 32090
102
+ }
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<extra_id_0>",
4
+ "<extra_id_1>",
5
+ "<extra_id_2>",
6
+ "<extra_id_3>",
7
+ "<extra_id_4>",
8
+ "<extra_id_5>",
9
+ "<extra_id_6>",
10
+ "<extra_id_7>",
11
+ "<extra_id_8>",
12
+ "<extra_id_9>",
13
+ "<extra_id_10>",
14
+ "<extra_id_11>",
15
+ "<extra_id_12>",
16
+ "<extra_id_13>",
17
+ "<extra_id_14>",
18
+ "<extra_id_15>",
19
+ "<extra_id_16>",
20
+ "<extra_id_17>",
21
+ "<extra_id_18>",
22
+ "<extra_id_19>",
23
+ "<extra_id_20>",
24
+ "<extra_id_21>",
25
+ "<extra_id_22>",
26
+ "<extra_id_23>",
27
+ "<extra_id_24>",
28
+ "<extra_id_25>",
29
+ "<extra_id_26>",
30
+ "<extra_id_27>",
31
+ "<extra_id_28>",
32
+ "<extra_id_29>",
33
+ "<extra_id_30>",
34
+ "<extra_id_31>",
35
+ "<extra_id_32>",
36
+ "<extra_id_33>",
37
+ "<extra_id_34>",
38
+ "<extra_id_35>",
39
+ "<extra_id_36>",
40
+ "<extra_id_37>",
41
+ "<extra_id_38>",
42
+ "<extra_id_39>",
43
+ "<extra_id_40>",
44
+ "<extra_id_41>",
45
+ "<extra_id_42>",
46
+ "<extra_id_43>",
47
+ "<extra_id_44>",
48
+ "<extra_id_45>",
49
+ "<extra_id_46>",
50
+ "<extra_id_47>",
51
+ "<extra_id_48>",
52
+ "<extra_id_49>",
53
+ "<extra_id_50>",
54
+ "<extra_id_51>",
55
+ "<extra_id_52>",
56
+ "<extra_id_53>",
57
+ "<extra_id_54>",
58
+ "<extra_id_55>",
59
+ "<extra_id_56>",
60
+ "<extra_id_57>",
61
+ "<extra_id_58>",
62
+ "<extra_id_59>",
63
+ "<extra_id_60>",
64
+ "<extra_id_61>",
65
+ "<extra_id_62>",
66
+ "<extra_id_63>",
67
+ "<extra_id_64>",
68
+ "<extra_id_65>",
69
+ "<extra_id_66>",
70
+ "<extra_id_67>",
71
+ "<extra_id_68>",
72
+ "<extra_id_69>",
73
+ "<extra_id_70>",
74
+ "<extra_id_71>",
75
+ "<extra_id_72>",
76
+ "<extra_id_73>",
77
+ "<extra_id_74>",
78
+ "<extra_id_75>",
79
+ "<extra_id_76>",
80
+ "<extra_id_77>",
81
+ "<extra_id_78>",
82
+ "<extra_id_79>",
83
+ "<extra_id_80>",
84
+ "<extra_id_81>",
85
+ "<extra_id_82>",
86
+ "<extra_id_83>",
87
+ "<extra_id_84>",
88
+ "<extra_id_85>",
89
+ "<extra_id_86>",
90
+ "<extra_id_87>",
91
+ "<extra_id_88>",
92
+ "<extra_id_89>",
93
+ "<extra_id_90>",
94
+ "<extra_id_91>",
95
+ "<extra_id_92>",
96
+ "<extra_id_93>",
97
+ "<extra_id_94>",
98
+ "<extra_id_95>",
99
+ "<extra_id_96>",
100
+ "<extra_id_97>",
101
+ "<extra_id_98>",
102
+ "<extra_id_99>"
103
+ ],
104
+ "eos_token": {
105
+ "content": "</s>",
106
+ "lstrip": false,
107
+ "normalized": false,
108
+ "rstrip": false,
109
+ "single_word": false
110
+ },
111
+ "pad_token": {
112
+ "content": "<pad>",
113
+ "lstrip": false,
114
+ "normalized": false,
115
+ "rstrip": false,
116
+ "single_word": false
117
+ },
118
+ "unk_token": {
119
+ "content": "<unk>",
120
+ "lstrip": false,
121
+ "normalized": false,
122
+ "rstrip": false,
123
+ "single_word": false
124
+ }
125
+ }
tokenizer/spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
3
+ size 791656
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,940 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": true,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<pad>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "</s>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "<unk>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "32000": {
29
+ "content": "<extra_id_99>",
30
+ "lstrip": true,
31
+ "normalized": false,
32
+ "rstrip": true,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "32001": {
37
+ "content": "<extra_id_98>",
38
+ "lstrip": true,
39
+ "normalized": false,
40
+ "rstrip": true,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "32002": {
45
+ "content": "<extra_id_97>",
46
+ "lstrip": true,
47
+ "normalized": false,
48
+ "rstrip": true,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "32003": {
53
+ "content": "<extra_id_96>",
54
+ "lstrip": true,
55
+ "normalized": false,
56
+ "rstrip": true,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "32004": {
61
+ "content": "<extra_id_95>",
62
+ "lstrip": true,
63
+ "normalized": false,
64
+ "rstrip": true,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "32005": {
69
+ "content": "<extra_id_94>",
70
+ "lstrip": true,
71
+ "normalized": false,
72
+ "rstrip": true,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "32006": {
77
+ "content": "<extra_id_93>",
78
+ "lstrip": true,
79
+ "normalized": false,
80
+ "rstrip": true,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "32007": {
85
+ "content": "<extra_id_92>",
86
+ "lstrip": true,
87
+ "normalized": false,
88
+ "rstrip": true,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "32008": {
93
+ "content": "<extra_id_91>",
94
+ "lstrip": true,
95
+ "normalized": false,
96
+ "rstrip": true,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "32009": {
101
+ "content": "<extra_id_90>",
102
+ "lstrip": true,
103
+ "normalized": false,
104
+ "rstrip": true,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "32010": {
109
+ "content": "<extra_id_89>",
110
+ "lstrip": true,
111
+ "normalized": false,
112
+ "rstrip": true,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "32011": {
117
+ "content": "<extra_id_88>",
118
+ "lstrip": true,
119
+ "normalized": false,
120
+ "rstrip": true,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "32012": {
125
+ "content": "<extra_id_87>",
126
+ "lstrip": true,
127
+ "normalized": false,
128
+ "rstrip": true,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "32013": {
133
+ "content": "<extra_id_86>",
134
+ "lstrip": true,
135
+ "normalized": false,
136
+ "rstrip": true,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "32014": {
141
+ "content": "<extra_id_85>",
142
+ "lstrip": true,
143
+ "normalized": false,
144
+ "rstrip": true,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "32015": {
149
+ "content": "<extra_id_84>",
150
+ "lstrip": true,
151
+ "normalized": false,
152
+ "rstrip": true,
153
+ "single_word": false,
154
+ "special": true
155
+ },
156
+ "32016": {
157
+ "content": "<extra_id_83>",
158
+ "lstrip": true,
159
+ "normalized": false,
160
+ "rstrip": true,
161
+ "single_word": false,
162
+ "special": true
163
+ },
164
+ "32017": {
165
+ "content": "<extra_id_82>",
166
+ "lstrip": true,
167
+ "normalized": false,
168
+ "rstrip": true,
169
+ "single_word": false,
170
+ "special": true
171
+ },
172
+ "32018": {
173
+ "content": "<extra_id_81>",
174
+ "lstrip": true,
175
+ "normalized": false,
176
+ "rstrip": true,
177
+ "single_word": false,
178
+ "special": true
179
+ },
180
+ "32019": {
181
+ "content": "<extra_id_80>",
182
+ "lstrip": true,
183
+ "normalized": false,
184
+ "rstrip": true,
185
+ "single_word": false,
186
+ "special": true
187
+ },
188
+ "32020": {
189
+ "content": "<extra_id_79>",
190
+ "lstrip": true,
191
+ "normalized": false,
192
+ "rstrip": true,
193
+ "single_word": false,
194
+ "special": true
195
+ },
196
+ "32021": {
197
+ "content": "<extra_id_78>",
198
+ "lstrip": true,
199
+ "normalized": false,
200
+ "rstrip": true,
201
+ "single_word": false,
202
+ "special": true
203
+ },
204
+ "32022": {
205
+ "content": "<extra_id_77>",
206
+ "lstrip": true,
207
+ "normalized": false,
208
+ "rstrip": true,
209
+ "single_word": false,
210
+ "special": true
211
+ },
212
+ "32023": {
213
+ "content": "<extra_id_76>",
214
+ "lstrip": true,
215
+ "normalized": false,
216
+ "rstrip": true,
217
+ "single_word": false,
218
+ "special": true
219
+ },
220
+ "32024": {
221
+ "content": "<extra_id_75>",
222
+ "lstrip": true,
223
+ "normalized": false,
224
+ "rstrip": true,
225
+ "single_word": false,
226
+ "special": true
227
+ },
228
+ "32025": {
229
+ "content": "<extra_id_74>",
230
+ "lstrip": true,
231
+ "normalized": false,
232
+ "rstrip": true,
233
+ "single_word": false,
234
+ "special": true
235
+ },
236
+ "32026": {
237
+ "content": "<extra_id_73>",
238
+ "lstrip": true,
239
+ "normalized": false,
240
+ "rstrip": true,
241
+ "single_word": false,
242
+ "special": true
243
+ },
244
+ "32027": {
245
+ "content": "<extra_id_72>",
246
+ "lstrip": true,
247
+ "normalized": false,
248
+ "rstrip": true,
249
+ "single_word": false,
250
+ "special": true
251
+ },
252
+ "32028": {
253
+ "content": "<extra_id_71>",
254
+ "lstrip": true,
255
+ "normalized": false,
256
+ "rstrip": true,
257
+ "single_word": false,
258
+ "special": true
259
+ },
260
+ "32029": {
261
+ "content": "<extra_id_70>",
262
+ "lstrip": true,
263
+ "normalized": false,
264
+ "rstrip": true,
265
+ "single_word": false,
266
+ "special": true
267
+ },
268
+ "32030": {
269
+ "content": "<extra_id_69>",
270
+ "lstrip": true,
271
+ "normalized": false,
272
+ "rstrip": true,
273
+ "single_word": false,
274
+ "special": true
275
+ },
276
+ "32031": {
277
+ "content": "<extra_id_68>",
278
+ "lstrip": true,
279
+ "normalized": false,
280
+ "rstrip": true,
281
+ "single_word": false,
282
+ "special": true
283
+ },
284
+ "32032": {
285
+ "content": "<extra_id_67>",
286
+ "lstrip": true,
287
+ "normalized": false,
288
+ "rstrip": true,
289
+ "single_word": false,
290
+ "special": true
291
+ },
292
+ "32033": {
293
+ "content": "<extra_id_66>",
294
+ "lstrip": true,
295
+ "normalized": false,
296
+ "rstrip": true,
297
+ "single_word": false,
298
+ "special": true
299
+ },
300
+ "32034": {
301
+ "content": "<extra_id_65>",
302
+ "lstrip": true,
303
+ "normalized": false,
304
+ "rstrip": true,
305
+ "single_word": false,
306
+ "special": true
307
+ },
308
+ "32035": {
309
+ "content": "<extra_id_64>",
310
+ "lstrip": true,
311
+ "normalized": false,
312
+ "rstrip": true,
313
+ "single_word": false,
314
+ "special": true
315
+ },
316
+ "32036": {
317
+ "content": "<extra_id_63>",
318
+ "lstrip": true,
319
+ "normalized": false,
320
+ "rstrip": true,
321
+ "single_word": false,
322
+ "special": true
323
+ },
324
+ "32037": {
325
+ "content": "<extra_id_62>",
326
+ "lstrip": true,
327
+ "normalized": false,
328
+ "rstrip": true,
329
+ "single_word": false,
330
+ "special": true
331
+ },
332
+ "32038": {
333
+ "content": "<extra_id_61>",
334
+ "lstrip": true,
335
+ "normalized": false,
336
+ "rstrip": true,
337
+ "single_word": false,
338
+ "special": true
339
+ },
340
+ "32039": {
341
+ "content": "<extra_id_60>",
342
+ "lstrip": true,
343
+ "normalized": false,
344
+ "rstrip": true,
345
+ "single_word": false,
346
+ "special": true
347
+ },
348
+ "32040": {
349
+ "content": "<extra_id_59>",
350
+ "lstrip": true,
351
+ "normalized": false,
352
+ "rstrip": true,
353
+ "single_word": false,
354
+ "special": true
355
+ },
356
+ "32041": {
357
+ "content": "<extra_id_58>",
358
+ "lstrip": true,
359
+ "normalized": false,
360
+ "rstrip": true,
361
+ "single_word": false,
362
+ "special": true
363
+ },
364
+ "32042": {
365
+ "content": "<extra_id_57>",
366
+ "lstrip": true,
367
+ "normalized": false,
368
+ "rstrip": true,
369
+ "single_word": false,
370
+ "special": true
371
+ },
372
+ "32043": {
373
+ "content": "<extra_id_56>",
374
+ "lstrip": true,
375
+ "normalized": false,
376
+ "rstrip": true,
377
+ "single_word": false,
378
+ "special": true
379
+ },
380
+ "32044": {
381
+ "content": "<extra_id_55>",
382
+ "lstrip": true,
383
+ "normalized": false,
384
+ "rstrip": true,
385
+ "single_word": false,
386
+ "special": true
387
+ },
388
+ "32045": {
389
+ "content": "<extra_id_54>",
390
+ "lstrip": true,
391
+ "normalized": false,
392
+ "rstrip": true,
393
+ "single_word": false,
394
+ "special": true
395
+ },
396
+ "32046": {
397
+ "content": "<extra_id_53>",
398
+ "lstrip": true,
399
+ "normalized": false,
400
+ "rstrip": true,
401
+ "single_word": false,
402
+ "special": true
403
+ },
404
+ "32047": {
405
+ "content": "<extra_id_52>",
406
+ "lstrip": true,
407
+ "normalized": false,
408
+ "rstrip": true,
409
+ "single_word": false,
410
+ "special": true
411
+ },
412
+ "32048": {
413
+ "content": "<extra_id_51>",
414
+ "lstrip": true,
415
+ "normalized": false,
416
+ "rstrip": true,
417
+ "single_word": false,
418
+ "special": true
419
+ },
420
+ "32049": {
421
+ "content": "<extra_id_50>",
422
+ "lstrip": true,
423
+ "normalized": false,
424
+ "rstrip": true,
425
+ "single_word": false,
426
+ "special": true
427
+ },
428
+ "32050": {
429
+ "content": "<extra_id_49>",
430
+ "lstrip": true,
431
+ "normalized": false,
432
+ "rstrip": true,
433
+ "single_word": false,
434
+ "special": true
435
+ },
436
+ "32051": {
437
+ "content": "<extra_id_48>",
438
+ "lstrip": true,
439
+ "normalized": false,
440
+ "rstrip": true,
441
+ "single_word": false,
442
+ "special": true
443
+ },
444
+ "32052": {
445
+ "content": "<extra_id_47>",
446
+ "lstrip": true,
447
+ "normalized": false,
448
+ "rstrip": true,
449
+ "single_word": false,
450
+ "special": true
451
+ },
452
+ "32053": {
453
+ "content": "<extra_id_46>",
454
+ "lstrip": true,
455
+ "normalized": false,
456
+ "rstrip": true,
457
+ "single_word": false,
458
+ "special": true
459
+ },
460
+ "32054": {
461
+ "content": "<extra_id_45>",
462
+ "lstrip": true,
463
+ "normalized": false,
464
+ "rstrip": true,
465
+ "single_word": false,
466
+ "special": true
467
+ },
468
+ "32055": {
469
+ "content": "<extra_id_44>",
470
+ "lstrip": true,
471
+ "normalized": false,
472
+ "rstrip": true,
473
+ "single_word": false,
474
+ "special": true
475
+ },
476
+ "32056": {
477
+ "content": "<extra_id_43>",
478
+ "lstrip": true,
479
+ "normalized": false,
480
+ "rstrip": true,
481
+ "single_word": false,
482
+ "special": true
483
+ },
484
+ "32057": {
485
+ "content": "<extra_id_42>",
486
+ "lstrip": true,
487
+ "normalized": false,
488
+ "rstrip": true,
489
+ "single_word": false,
490
+ "special": true
491
+ },
492
+ "32058": {
493
+ "content": "<extra_id_41>",
494
+ "lstrip": true,
495
+ "normalized": false,
496
+ "rstrip": true,
497
+ "single_word": false,
498
+ "special": true
499
+ },
500
+ "32059": {
501
+ "content": "<extra_id_40>",
502
+ "lstrip": true,
503
+ "normalized": false,
504
+ "rstrip": true,
505
+ "single_word": false,
506
+ "special": true
507
+ },
508
+ "32060": {
509
+ "content": "<extra_id_39>",
510
+ "lstrip": true,
511
+ "normalized": false,
512
+ "rstrip": true,
513
+ "single_word": false,
514
+ "special": true
515
+ },
516
+ "32061": {
517
+ "content": "<extra_id_38>",
518
+ "lstrip": true,
519
+ "normalized": false,
520
+ "rstrip": true,
521
+ "single_word": false,
522
+ "special": true
523
+ },
524
+ "32062": {
525
+ "content": "<extra_id_37>",
526
+ "lstrip": true,
527
+ "normalized": false,
528
+ "rstrip": true,
529
+ "single_word": false,
530
+ "special": true
531
+ },
532
+ "32063": {
533
+ "content": "<extra_id_36>",
534
+ "lstrip": true,
535
+ "normalized": false,
536
+ "rstrip": true,
537
+ "single_word": false,
538
+ "special": true
539
+ },
540
+ "32064": {
541
+ "content": "<extra_id_35>",
542
+ "lstrip": true,
543
+ "normalized": false,
544
+ "rstrip": true,
545
+ "single_word": false,
546
+ "special": true
547
+ },
548
+ "32065": {
549
+ "content": "<extra_id_34>",
550
+ "lstrip": true,
551
+ "normalized": false,
552
+ "rstrip": true,
553
+ "single_word": false,
554
+ "special": true
555
+ },
556
+ "32066": {
557
+ "content": "<extra_id_33>",
558
+ "lstrip": true,
559
+ "normalized": false,
560
+ "rstrip": true,
561
+ "single_word": false,
562
+ "special": true
563
+ },
564
+ "32067": {
565
+ "content": "<extra_id_32>",
566
+ "lstrip": true,
567
+ "normalized": false,
568
+ "rstrip": true,
569
+ "single_word": false,
570
+ "special": true
571
+ },
572
+ "32068": {
573
+ "content": "<extra_id_31>",
574
+ "lstrip": true,
575
+ "normalized": false,
576
+ "rstrip": true,
577
+ "single_word": false,
578
+ "special": true
579
+ },
580
+ "32069": {
581
+ "content": "<extra_id_30>",
582
+ "lstrip": true,
583
+ "normalized": false,
584
+ "rstrip": true,
585
+ "single_word": false,
586
+ "special": true
587
+ },
588
+ "32070": {
589
+ "content": "<extra_id_29>",
590
+ "lstrip": true,
591
+ "normalized": false,
592
+ "rstrip": true,
593
+ "single_word": false,
594
+ "special": true
595
+ },
596
+ "32071": {
597
+ "content": "<extra_id_28>",
598
+ "lstrip": true,
599
+ "normalized": false,
600
+ "rstrip": true,
601
+ "single_word": false,
602
+ "special": true
603
+ },
604
+ "32072": {
605
+ "content": "<extra_id_27>",
606
+ "lstrip": true,
607
+ "normalized": false,
608
+ "rstrip": true,
609
+ "single_word": false,
610
+ "special": true
611
+ },
612
+ "32073": {
613
+ "content": "<extra_id_26>",
614
+ "lstrip": true,
615
+ "normalized": false,
616
+ "rstrip": true,
617
+ "single_word": false,
618
+ "special": true
619
+ },
620
+ "32074": {
621
+ "content": "<extra_id_25>",
622
+ "lstrip": true,
623
+ "normalized": false,
624
+ "rstrip": true,
625
+ "single_word": false,
626
+ "special": true
627
+ },
628
+ "32075": {
629
+ "content": "<extra_id_24>",
630
+ "lstrip": true,
631
+ "normalized": false,
632
+ "rstrip": true,
633
+ "single_word": false,
634
+ "special": true
635
+ },
636
+ "32076": {
637
+ "content": "<extra_id_23>",
638
+ "lstrip": true,
639
+ "normalized": false,
640
+ "rstrip": true,
641
+ "single_word": false,
642
+ "special": true
643
+ },
644
+ "32077": {
645
+ "content": "<extra_id_22>",
646
+ "lstrip": true,
647
+ "normalized": false,
648
+ "rstrip": true,
649
+ "single_word": false,
650
+ "special": true
651
+ },
652
+ "32078": {
653
+ "content": "<extra_id_21>",
654
+ "lstrip": true,
655
+ "normalized": false,
656
+ "rstrip": true,
657
+ "single_word": false,
658
+ "special": true
659
+ },
660
+ "32079": {
661
+ "content": "<extra_id_20>",
662
+ "lstrip": true,
663
+ "normalized": false,
664
+ "rstrip": true,
665
+ "single_word": false,
666
+ "special": true
667
+ },
668
+ "32080": {
669
+ "content": "<extra_id_19>",
670
+ "lstrip": true,
671
+ "normalized": false,
672
+ "rstrip": true,
673
+ "single_word": false,
674
+ "special": true
675
+ },
676
+ "32081": {
677
+ "content": "<extra_id_18>",
678
+ "lstrip": true,
679
+ "normalized": false,
680
+ "rstrip": true,
681
+ "single_word": false,
682
+ "special": true
683
+ },
684
+ "32082": {
685
+ "content": "<extra_id_17>",
686
+ "lstrip": true,
687
+ "normalized": false,
688
+ "rstrip": true,
689
+ "single_word": false,
690
+ "special": true
691
+ },
692
+ "32083": {
693
+ "content": "<extra_id_16>",
694
+ "lstrip": true,
695
+ "normalized": false,
696
+ "rstrip": true,
697
+ "single_word": false,
698
+ "special": true
699
+ },
700
+ "32084": {
701
+ "content": "<extra_id_15>",
702
+ "lstrip": true,
703
+ "normalized": false,
704
+ "rstrip": true,
705
+ "single_word": false,
706
+ "special": true
707
+ },
708
+ "32085": {
709
+ "content": "<extra_id_14>",
710
+ "lstrip": true,
711
+ "normalized": false,
712
+ "rstrip": true,
713
+ "single_word": false,
714
+ "special": true
715
+ },
716
+ "32086": {
717
+ "content": "<extra_id_13>",
718
+ "lstrip": true,
719
+ "normalized": false,
720
+ "rstrip": true,
721
+ "single_word": false,
722
+ "special": true
723
+ },
724
+ "32087": {
725
+ "content": "<extra_id_12>",
726
+ "lstrip": true,
727
+ "normalized": false,
728
+ "rstrip": true,
729
+ "single_word": false,
730
+ "special": true
731
+ },
732
+ "32088": {
733
+ "content": "<extra_id_11>",
734
+ "lstrip": true,
735
+ "normalized": false,
736
+ "rstrip": true,
737
+ "single_word": false,
738
+ "special": true
739
+ },
740
+ "32089": {
741
+ "content": "<extra_id_10>",
742
+ "lstrip": true,
743
+ "normalized": false,
744
+ "rstrip": true,
745
+ "single_word": false,
746
+ "special": true
747
+ },
748
+ "32090": {
749
+ "content": "<extra_id_9>",
750
+ "lstrip": true,
751
+ "normalized": false,
752
+ "rstrip": true,
753
+ "single_word": false,
754
+ "special": true
755
+ },
756
+ "32091": {
757
+ "content": "<extra_id_8>",
758
+ "lstrip": true,
759
+ "normalized": false,
760
+ "rstrip": true,
761
+ "single_word": false,
762
+ "special": true
763
+ },
764
+ "32092": {
765
+ "content": "<extra_id_7>",
766
+ "lstrip": true,
767
+ "normalized": false,
768
+ "rstrip": true,
769
+ "single_word": false,
770
+ "special": true
771
+ },
772
+ "32093": {
773
+ "content": "<extra_id_6>",
774
+ "lstrip": true,
775
+ "normalized": false,
776
+ "rstrip": true,
777
+ "single_word": false,
778
+ "special": true
779
+ },
780
+ "32094": {
781
+ "content": "<extra_id_5>",
782
+ "lstrip": true,
783
+ "normalized": false,
784
+ "rstrip": true,
785
+ "single_word": false,
786
+ "special": true
787
+ },
788
+ "32095": {
789
+ "content": "<extra_id_4>",
790
+ "lstrip": true,
791
+ "normalized": false,
792
+ "rstrip": true,
793
+ "single_word": false,
794
+ "special": true
795
+ },
796
+ "32096": {
797
+ "content": "<extra_id_3>",
798
+ "lstrip": true,
799
+ "normalized": false,
800
+ "rstrip": true,
801
+ "single_word": false,
802
+ "special": true
803
+ },
804
+ "32097": {
805
+ "content": "<extra_id_2>",
806
+ "lstrip": true,
807
+ "normalized": false,
808
+ "rstrip": true,
809
+ "single_word": false,
810
+ "special": true
811
+ },
812
+ "32098": {
813
+ "content": "<extra_id_1>",
814
+ "lstrip": true,
815
+ "normalized": false,
816
+ "rstrip": true,
817
+ "single_word": false,
818
+ "special": true
819
+ },
820
+ "32099": {
821
+ "content": "<extra_id_0>",
822
+ "lstrip": true,
823
+ "normalized": false,
824
+ "rstrip": true,
825
+ "single_word": false,
826
+ "special": true
827
+ }
828
+ },
829
+ "additional_special_tokens": [
830
+ "<extra_id_0>",
831
+ "<extra_id_1>",
832
+ "<extra_id_2>",
833
+ "<extra_id_3>",
834
+ "<extra_id_4>",
835
+ "<extra_id_5>",
836
+ "<extra_id_6>",
837
+ "<extra_id_7>",
838
+ "<extra_id_8>",
839
+ "<extra_id_9>",
840
+ "<extra_id_10>",
841
+ "<extra_id_11>",
842
+ "<extra_id_12>",
843
+ "<extra_id_13>",
844
+ "<extra_id_14>",
845
+ "<extra_id_15>",
846
+ "<extra_id_16>",
847
+ "<extra_id_17>",
848
+ "<extra_id_18>",
849
+ "<extra_id_19>",
850
+ "<extra_id_20>",
851
+ "<extra_id_21>",
852
+ "<extra_id_22>",
853
+ "<extra_id_23>",
854
+ "<extra_id_24>",
855
+ "<extra_id_25>",
856
+ "<extra_id_26>",
857
+ "<extra_id_27>",
858
+ "<extra_id_28>",
859
+ "<extra_id_29>",
860
+ "<extra_id_30>",
861
+ "<extra_id_31>",
862
+ "<extra_id_32>",
863
+ "<extra_id_33>",
864
+ "<extra_id_34>",
865
+ "<extra_id_35>",
866
+ "<extra_id_36>",
867
+ "<extra_id_37>",
868
+ "<extra_id_38>",
869
+ "<extra_id_39>",
870
+ "<extra_id_40>",
871
+ "<extra_id_41>",
872
+ "<extra_id_42>",
873
+ "<extra_id_43>",
874
+ "<extra_id_44>",
875
+ "<extra_id_45>",
876
+ "<extra_id_46>",
877
+ "<extra_id_47>",
878
+ "<extra_id_48>",
879
+ "<extra_id_49>",
880
+ "<extra_id_50>",
881
+ "<extra_id_51>",
882
+ "<extra_id_52>",
883
+ "<extra_id_53>",
884
+ "<extra_id_54>",
885
+ "<extra_id_55>",
886
+ "<extra_id_56>",
887
+ "<extra_id_57>",
888
+ "<extra_id_58>",
889
+ "<extra_id_59>",
890
+ "<extra_id_60>",
891
+ "<extra_id_61>",
892
+ "<extra_id_62>",
893
+ "<extra_id_63>",
894
+ "<extra_id_64>",
895
+ "<extra_id_65>",
896
+ "<extra_id_66>",
897
+ "<extra_id_67>",
898
+ "<extra_id_68>",
899
+ "<extra_id_69>",
900
+ "<extra_id_70>",
901
+ "<extra_id_71>",
902
+ "<extra_id_72>",
903
+ "<extra_id_73>",
904
+ "<extra_id_74>",
905
+ "<extra_id_75>",
906
+ "<extra_id_76>",
907
+ "<extra_id_77>",
908
+ "<extra_id_78>",
909
+ "<extra_id_79>",
910
+ "<extra_id_80>",
911
+ "<extra_id_81>",
912
+ "<extra_id_82>",
913
+ "<extra_id_83>",
914
+ "<extra_id_84>",
915
+ "<extra_id_85>",
916
+ "<extra_id_86>",
917
+ "<extra_id_87>",
918
+ "<extra_id_88>",
919
+ "<extra_id_89>",
920
+ "<extra_id_90>",
921
+ "<extra_id_91>",
922
+ "<extra_id_92>",
923
+ "<extra_id_93>",
924
+ "<extra_id_94>",
925
+ "<extra_id_95>",
926
+ "<extra_id_96>",
927
+ "<extra_id_97>",
928
+ "<extra_id_98>",
929
+ "<extra_id_99>"
930
+ ],
931
+ "clean_up_tokenization_spaces": true,
932
+ "eos_token": "</s>",
933
+ "extra_ids": 100,
934
+ "legacy": true,
935
+ "model_max_length": 512,
936
+ "pad_token": "<pad>",
937
+ "sp_model_kwargs": {},
938
+ "tokenizer_class": "T5Tokenizer",
939
+ "unk_token": "<unk>"
940
+ }
transformer/config.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AllegroTransformer3DModel",
3
+ "_diffusers_version": "0.28.0",
4
+ "_name_or_path": "/cpfs/data/user/yanghuan/expr/rsora/RSoraT2V_L32AH24AD96_122_20240918_88x720x1280_fps15_t5/checkpoint-38000/model",
5
+ "activation_fn": "gelu-approximate",
6
+ "attention_bias": true,
7
+ "attention_head_dim": 96,
8
+ "ca_attention_mode": "xformers",
9
+ "caption_channels": 4096,
10
+ "cross_attention_dim": 2304,
11
+ "double_self_attention": false,
12
+ "downsampler": null,
13
+ "dropout": 0.0,
14
+ "in_channels": 4,
15
+ "interpolation_scale_h": 2.0,
16
+ "interpolation_scale_t": 2.2,
17
+ "interpolation_scale_w": 2.0,
18
+ "model_max_length": 300,
19
+ "norm_elementwise_affine": false,
20
+ "norm_eps": 1e-06,
21
+ "norm_type": "ada_norm_single",
22
+ "num_attention_heads": 24,
23
+ "num_embeds_ada_norm": 1000,
24
+ "num_layers": 32,
25
+ "only_cross_attention": false,
26
+ "out_channels": 4,
27
+ "patch_size": 2,
28
+ "patch_size_t": 1,
29
+ "sa_attention_mode": "flash",
30
+ "sample_size": [
31
+ 90,
32
+ 160
33
+ ],
34
+ "sample_size_t": 22,
35
+ "upcast_attention": false,
36
+ "use_additional_conditions": null,
37
+ "use_linear_projection": false,
38
+ "use_rope": true
39
+ }
transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6927dcc812841c1da549bf11c97ddf30532aee0e708a6642fa64cf8e0dfcdef7
3
+ size 5543894392
transformer/transformer_3d_allegro.py ADDED
@@ -0,0 +1,1776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from Open-Sora-Plan
2
+
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # --------------------------------------------------------
6
+ # References:
7
+ # Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
8
+ # --------------------------------------------------------
9
+
10
+
11
+ import json
12
+ import os
13
+ from dataclasses import dataclass
14
+ from functools import partial
15
+ from importlib import import_module
16
+ from typing import Any, Callable, Dict, Optional, Tuple
17
+
18
+ import numpy as np
19
+ import torch
20
+ import collections
21
+ import torch.nn.functional as F
22
+ from torch.nn.attention import SDPBackend, sdpa_kernel
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
25
+ from diffusers.models.attention_processor import (
26
+ AttnAddedKVProcessor,
27
+ AttnAddedKVProcessor2_0,
28
+ AttnProcessor,
29
+ CustomDiffusionAttnProcessor,
30
+ CustomDiffusionAttnProcessor2_0,
31
+ CustomDiffusionXFormersAttnProcessor,
32
+ LoRAAttnAddedKVProcessor,
33
+ LoRAAttnProcessor,
34
+ LoRAAttnProcessor2_0,
35
+ LoRAXFormersAttnProcessor,
36
+ SlicedAttnAddedKVProcessor,
37
+ SlicedAttnProcessor,
38
+ SpatialNorm,
39
+ XFormersAttnAddedKVProcessor,
40
+ XFormersAttnProcessor,
41
+ )
42
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
43
+ from diffusers.models.modeling_utils import ModelMixin
44
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
45
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_xformers_available
46
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
47
+ from einops import rearrange, repeat
48
+ from torch import nn
49
+ from diffusers.models.embeddings import PixArtAlphaTextProjection
50
+
51
+
52
+ if is_xformers_available():
53
+ import xformers
54
+ import xformers.ops
55
+ else:
56
+ xformers = None
57
+
58
+ from diffusers.utils import logging
59
+
60
+ logger = logging.get_logger(__name__)
61
+
62
+
63
+ def to_2tuple(x):
64
+ if isinstance(x, collections.abc.Iterable):
65
+ return x
66
+ return (x, x)
67
+
68
+ class CombinedTimestepSizeEmbeddings(nn.Module):
69
+ """
70
+ For PixArt-Alpha.
71
+
72
+ Reference:
73
+ https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
74
+ """
75
+
76
+ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
77
+ super().__init__()
78
+
79
+ self.outdim = size_emb_dim
80
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
81
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
82
+
83
+ self.use_additional_conditions = use_additional_conditions
84
+ if use_additional_conditions:
85
+ self.use_additional_conditions = True
86
+ self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
87
+ self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
88
+ self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
89
+
90
+ def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
91
+ if size.ndim == 1:
92
+ size = size[:, None]
93
+
94
+ if size.shape[0] != batch_size:
95
+ size = size.repeat(batch_size // size.shape[0], 1)
96
+ if size.shape[0] != batch_size:
97
+ raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
98
+
99
+ current_batch_size, dims = size.shape[0], size.shape[1]
100
+ size = size.reshape(-1)
101
+ size_freq = self.additional_condition_proj(size).to(size.dtype)
102
+
103
+ size_emb = embedder(size_freq)
104
+ size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
105
+ return size_emb
106
+
107
+ def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
108
+ timesteps_proj = self.time_proj(timestep)
109
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
110
+
111
+ if self.use_additional_conditions:
112
+ resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
113
+ aspect_ratio = self.apply_condition(
114
+ aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
115
+ )
116
+ conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
117
+ else:
118
+ conditioning = timesteps_emb
119
+
120
+ return conditioning
121
+
122
+
123
+ class PositionGetter3D(object):
124
+ """ return positions of patches """
125
+
126
+ def __init__(self, ):
127
+ self.cache_positions = {}
128
+
129
+ def __call__(self, b, t, h, w, device):
130
+ if not (b, t,h,w) in self.cache_positions:
131
+ x = torch.arange(w, device=device)
132
+ y = torch.arange(h, device=device)
133
+ z = torch.arange(t, device=device)
134
+ pos = torch.cartesian_prod(z, y, x)
135
+
136
+ pos = pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, 1, -1).contiguous().expand(3, b, -1).clone()
137
+ poses = (pos[0].contiguous(), pos[1].contiguous(), pos[2].contiguous())
138
+ max_poses = (int(poses[0].max()), int(poses[1].max()), int(poses[2].max()))
139
+
140
+ self.cache_positions[b, t, h, w] = (poses, max_poses)
141
+ pos = self.cache_positions[b, t, h, w]
142
+
143
+ return pos
144
+
145
+
146
+ class RoPE3D(torch.nn.Module):
147
+
148
+ def __init__(self, freq=10000.0, F0=1.0, interpolation_scale_thw=(1, 1, 1)):
149
+ super().__init__()
150
+ self.base = freq
151
+ self.F0 = F0
152
+ self.interpolation_scale_t = interpolation_scale_thw[0]
153
+ self.interpolation_scale_h = interpolation_scale_thw[1]
154
+ self.interpolation_scale_w = interpolation_scale_thw[2]
155
+ self.cache = {}
156
+
157
+ def get_cos_sin(self, D, seq_len, device, dtype, interpolation_scale=1):
158
+ if (D, seq_len, device, dtype) not in self.cache:
159
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
160
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) / interpolation_scale
161
+ freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
162
+ freqs = torch.cat((freqs, freqs), dim=-1)
163
+ cos = freqs.cos() # (Seq, Dim)
164
+ sin = freqs.sin()
165
+ self.cache[D, seq_len, device, dtype] = (cos, sin)
166
+ return self.cache[D, seq_len, device, dtype]
167
+
168
+ @staticmethod
169
+ def rotate_half(x):
170
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
171
+ return torch.cat((-x2, x1), dim=-1)
172
+
173
+ def apply_rope1d(self, tokens, pos1d, cos, sin):
174
+ assert pos1d.ndim == 2
175
+
176
+ # for (batch_size x ntokens x nheads x dim)
177
+ cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
178
+ sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
179
+ return (tokens * cos) + (self.rotate_half(tokens) * sin)
180
+
181
+ def forward(self, tokens, positions):
182
+ """
183
+ input:
184
+ * tokens: batch_size x nheads x ntokens x dim
185
+ * positions: batch_size x ntokens x 3 (t, y and x position of each token)
186
+ output:
187
+ * tokens after appplying RoPE3D (batch_size x nheads x ntokens x x dim)
188
+ """
189
+ assert tokens.size(3) % 3 == 0, "number of dimensions should be a multiple of three"
190
+ D = tokens.size(3) // 3
191
+ poses, max_poses = positions
192
+ assert len(poses) == 3 and poses[0].ndim == 2# Batch, Seq, 3
193
+ cos_t, sin_t = self.get_cos_sin(D, max_poses[0] + 1, tokens.device, tokens.dtype, self.interpolation_scale_t)
194
+ cos_y, sin_y = self.get_cos_sin(D, max_poses[1] + 1, tokens.device, tokens.dtype, self.interpolation_scale_h)
195
+ cos_x, sin_x = self.get_cos_sin(D, max_poses[2] + 1, tokens.device, tokens.dtype, self.interpolation_scale_w)
196
+ # split features into three along the feature dimension, and apply rope1d on each half
197
+ t, y, x = tokens.chunk(3, dim=-1)
198
+ t = self.apply_rope1d(t, poses[0], cos_t, sin_t)
199
+ y = self.apply_rope1d(y, poses[1], cos_y, sin_y)
200
+ x = self.apply_rope1d(x, poses[2], cos_x, sin_x)
201
+ tokens = torch.cat((t, y, x), dim=-1)
202
+ return tokens
203
+
204
+ class PatchEmbed2D(nn.Module):
205
+ """2D Image to Patch Embedding"""
206
+
207
+ def __init__(
208
+ self,
209
+ num_frames=1,
210
+ height=224,
211
+ width=224,
212
+ patch_size_t=1,
213
+ patch_size=16,
214
+ in_channels=3,
215
+ embed_dim=768,
216
+ layer_norm=False,
217
+ flatten=True,
218
+ bias=True,
219
+ interpolation_scale=(1, 1),
220
+ interpolation_scale_t=1,
221
+ use_abs_pos=False,
222
+ ):
223
+ super().__init__()
224
+ self.use_abs_pos = use_abs_pos
225
+ self.flatten = flatten
226
+ self.layer_norm = layer_norm
227
+
228
+ self.proj = nn.Conv2d(
229
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias
230
+ )
231
+ if layer_norm:
232
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
233
+ else:
234
+ self.norm = None
235
+
236
+ self.patch_size_t = patch_size_t
237
+ self.patch_size = patch_size
238
+
239
+ def forward(self, latent):
240
+ b, _, _, _, _ = latent.shape
241
+ video_latent = None
242
+
243
+ latent = rearrange(latent, 'b c t h w -> (b t) c h w')
244
+
245
+ latent = self.proj(latent)
246
+ if self.flatten:
247
+ latent = latent.flatten(2).transpose(1, 2) # BT C H W -> BT N C
248
+ if self.layer_norm:
249
+ latent = self.norm(latent)
250
+
251
+ latent = rearrange(latent, '(b t) n c -> b (t n) c', b=b)
252
+ video_latent = latent
253
+
254
+ return video_latent
255
+
256
+
257
+ @maybe_allow_in_graph
258
+ class Attention(nn.Module):
259
+ r"""
260
+ A cross attention layer.
261
+
262
+ Parameters:
263
+ query_dim (`int`):
264
+ The number of channels in the query.
265
+ cross_attention_dim (`int`, *optional*):
266
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
267
+ heads (`int`, *optional*, defaults to 8):
268
+ The number of heads to use for multi-head attention.
269
+ dim_head (`int`, *optional*, defaults to 64):
270
+ The number of channels in each head.
271
+ dropout (`float`, *optional*, defaults to 0.0):
272
+ The dropout probability to use.
273
+ bias (`bool`, *optional*, defaults to False):
274
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
275
+ upcast_attention (`bool`, *optional*, defaults to False):
276
+ Set to `True` to upcast the attention computation to `float32`.
277
+ upcast_softmax (`bool`, *optional*, defaults to False):
278
+ Set to `True` to upcast the softmax computation to `float32`.
279
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
280
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
281
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
282
+ The number of groups to use for the group norm in the cross attention.
283
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
284
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
285
+ norm_num_groups (`int`, *optional*, defaults to `None`):
286
+ The number of groups to use for the group norm in the attention.
287
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
288
+ The number of channels to use for the spatial normalization.
289
+ out_bias (`bool`, *optional*, defaults to `True`):
290
+ Set to `True` to use a bias in the output linear layer.
291
+ scale_qk (`bool`, *optional*, defaults to `True`):
292
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
293
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
294
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
295
+ `added_kv_proj_dim` is not `None`.
296
+ eps (`float`, *optional*, defaults to 1e-5):
297
+ An additional value added to the denominator in group normalization that is used for numerical stability.
298
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
299
+ A factor to rescale the output by dividing it with this value.
300
+ residual_connection (`bool`, *optional*, defaults to `False`):
301
+ Set to `True` to add the residual connection to the output.
302
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
303
+ Set to `True` if the attention block is loaded from a deprecated state dict.
304
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
305
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
306
+ `AttnProcessor` otherwise.
307
+ """
308
+
309
+ def __init__(
310
+ self,
311
+ query_dim: int,
312
+ cross_attention_dim: Optional[int] = None,
313
+ heads: int = 8,
314
+ dim_head: int = 64,
315
+ dropout: float = 0.0,
316
+ bias: bool = False,
317
+ upcast_attention: bool = False,
318
+ upcast_softmax: bool = False,
319
+ cross_attention_norm: Optional[str] = None,
320
+ cross_attention_norm_num_groups: int = 32,
321
+ added_kv_proj_dim: Optional[int] = None,
322
+ norm_num_groups: Optional[int] = None,
323
+ spatial_norm_dim: Optional[int] = None,
324
+ out_bias: bool = True,
325
+ scale_qk: bool = True,
326
+ only_cross_attention: bool = False,
327
+ eps: float = 1e-5,
328
+ rescale_output_factor: float = 1.0,
329
+ residual_connection: bool = False,
330
+ _from_deprecated_attn_block: bool = False,
331
+ processor: Optional["AttnProcessor"] = None,
332
+ attention_mode: str = "xformers",
333
+ use_rope: bool = False,
334
+ interpolation_scale_thw=None,
335
+ ):
336
+ super().__init__()
337
+ self.inner_dim = dim_head * heads
338
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
339
+ self.upcast_attention = upcast_attention
340
+ self.upcast_softmax = upcast_softmax
341
+ self.rescale_output_factor = rescale_output_factor
342
+ self.residual_connection = residual_connection
343
+ self.dropout = dropout
344
+ self.use_rope = use_rope
345
+
346
+ # we make use of this private variable to know whether this class is loaded
347
+ # with an deprecated state dict so that we can convert it on the fly
348
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
349
+
350
+ self.scale_qk = scale_qk
351
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
352
+
353
+ self.heads = heads
354
+ # for slice_size > 0 the attention score computation
355
+ # is split across the batch axis to save memory
356
+ # You can set slice_size with `set_attention_slice`
357
+ self.sliceable_head_dim = heads
358
+
359
+ self.added_kv_proj_dim = added_kv_proj_dim
360
+ self.only_cross_attention = only_cross_attention
361
+
362
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
363
+ raise ValueError(
364
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
365
+ )
366
+
367
+ if norm_num_groups is not None:
368
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
369
+ else:
370
+ self.group_norm = None
371
+
372
+ if spatial_norm_dim is not None:
373
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
374
+ else:
375
+ self.spatial_norm = None
376
+
377
+ if cross_attention_norm is None:
378
+ self.norm_cross = None
379
+ elif cross_attention_norm == "layer_norm":
380
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
381
+ elif cross_attention_norm == "group_norm":
382
+ if self.added_kv_proj_dim is not None:
383
+ # The given `encoder_hidden_states` are initially of shape
384
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
385
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
386
+ # before the projection, so we need to use `added_kv_proj_dim` as
387
+ # the number of channels for the group norm.
388
+ norm_cross_num_channels = added_kv_proj_dim
389
+ else:
390
+ norm_cross_num_channels = self.cross_attention_dim
391
+
392
+ self.norm_cross = nn.GroupNorm(
393
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
394
+ )
395
+ else:
396
+ raise ValueError(
397
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
398
+ )
399
+
400
+ linear_cls = nn.Linear
401
+
402
+
403
+ self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
404
+
405
+ if not self.only_cross_attention:
406
+ # only relevant for the `AddedKVProcessor` classes
407
+ self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
408
+ self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
409
+ else:
410
+ self.to_k = None
411
+ self.to_v = None
412
+
413
+ if self.added_kv_proj_dim is not None:
414
+ self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
415
+ self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
416
+
417
+ self.to_out = nn.ModuleList([])
418
+ self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
419
+ self.to_out.append(nn.Dropout(dropout))
420
+
421
+ # set attention processor
422
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
423
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
424
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
425
+ if processor is None:
426
+ processor = (
427
+ AttnProcessor2_0(
428
+ attention_mode,
429
+ use_rope,
430
+ interpolation_scale_thw=interpolation_scale_thw,
431
+ )
432
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
433
+ else AttnProcessor()
434
+ )
435
+ self.set_processor(processor)
436
+
437
+ def set_use_memory_efficient_attention_xformers(
438
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
439
+ ) -> None:
440
+ r"""
441
+ Set whether to use memory efficient attention from `xformers` or not.
442
+
443
+ Args:
444
+ use_memory_efficient_attention_xformers (`bool`):
445
+ Whether to use memory efficient attention from `xformers` or not.
446
+ attention_op (`Callable`, *optional*):
447
+ The attention operation to use. Defaults to `None` which uses the default attention operation from
448
+ `xformers`.
449
+ """
450
+ is_lora = hasattr(self, "processor")
451
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
452
+ self.processor,
453
+ (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
454
+ )
455
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
456
+ self.processor,
457
+ (
458
+ AttnAddedKVProcessor,
459
+ AttnAddedKVProcessor2_0,
460
+ SlicedAttnAddedKVProcessor,
461
+ XFormersAttnAddedKVProcessor,
462
+ LoRAAttnAddedKVProcessor,
463
+ ),
464
+ )
465
+
466
+ if use_memory_efficient_attention_xformers:
467
+ if is_added_kv_processor and (is_lora or is_custom_diffusion):
468
+ raise NotImplementedError(
469
+ f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
470
+ )
471
+ if not is_xformers_available():
472
+ raise ModuleNotFoundError(
473
+ (
474
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
475
+ " xformers"
476
+ ),
477
+ name="xformers",
478
+ )
479
+ elif not torch.cuda.is_available():
480
+ raise ValueError(
481
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
482
+ " only available for GPU "
483
+ )
484
+ else:
485
+ try:
486
+ # Make sure we can run the memory efficient attention
487
+ _ = xformers.ops.memory_efficient_attention(
488
+ torch.randn((1, 2, 40), device="cuda"),
489
+ torch.randn((1, 2, 40), device="cuda"),
490
+ torch.randn((1, 2, 40), device="cuda"),
491
+ )
492
+ except Exception as e:
493
+ raise e
494
+
495
+ if is_lora:
496
+ # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
497
+ # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
498
+ processor = LoRAXFormersAttnProcessor(
499
+ hidden_size=self.processor.hidden_size,
500
+ cross_attention_dim=self.processor.cross_attention_dim,
501
+ rank=self.processor.rank,
502
+ attention_op=attention_op,
503
+ )
504
+ processor.load_state_dict(self.processor.state_dict())
505
+ processor.to(self.processor.to_q_lora.up.weight.device)
506
+ elif is_custom_diffusion:
507
+ processor = CustomDiffusionXFormersAttnProcessor(
508
+ train_kv=self.processor.train_kv,
509
+ train_q_out=self.processor.train_q_out,
510
+ hidden_size=self.processor.hidden_size,
511
+ cross_attention_dim=self.processor.cross_attention_dim,
512
+ attention_op=attention_op,
513
+ )
514
+ processor.load_state_dict(self.processor.state_dict())
515
+ if hasattr(self.processor, "to_k_custom_diffusion"):
516
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
517
+ elif is_added_kv_processor:
518
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
519
+ # which uses this type of cross attention ONLY because the attention mask of format
520
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
521
+ # throw warning
522
+ logger.info(
523
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
524
+ )
525
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
526
+ else:
527
+ processor = XFormersAttnProcessor(attention_op=attention_op)
528
+ else:
529
+ if is_lora:
530
+ attn_processor_class = (
531
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
532
+ )
533
+ processor = attn_processor_class(
534
+ hidden_size=self.processor.hidden_size,
535
+ cross_attention_dim=self.processor.cross_attention_dim,
536
+ rank=self.processor.rank,
537
+ )
538
+ processor.load_state_dict(self.processor.state_dict())
539
+ processor.to(self.processor.to_q_lora.up.weight.device)
540
+ elif is_custom_diffusion:
541
+ attn_processor_class = (
542
+ CustomDiffusionAttnProcessor2_0
543
+ if hasattr(F, "scaled_dot_product_attention")
544
+ else CustomDiffusionAttnProcessor
545
+ )
546
+ processor = attn_processor_class(
547
+ train_kv=self.processor.train_kv,
548
+ train_q_out=self.processor.train_q_out,
549
+ hidden_size=self.processor.hidden_size,
550
+ cross_attention_dim=self.processor.cross_attention_dim,
551
+ )
552
+ processor.load_state_dict(self.processor.state_dict())
553
+ if hasattr(self.processor, "to_k_custom_diffusion"):
554
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
555
+ else:
556
+ # set attention processor
557
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
558
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
559
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
560
+ processor = (
561
+ AttnProcessor2_0()
562
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
563
+ else AttnProcessor()
564
+ )
565
+
566
+ self.set_processor(processor)
567
+
568
+ def set_attention_slice(self, slice_size: int) -> None:
569
+ r"""
570
+ Set the slice size for attention computation.
571
+
572
+ Args:
573
+ slice_size (`int`):
574
+ The slice size for attention computation.
575
+ """
576
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
577
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
578
+
579
+ if slice_size is not None and self.added_kv_proj_dim is not None:
580
+ processor = SlicedAttnAddedKVProcessor(slice_size)
581
+ elif slice_size is not None:
582
+ processor = SlicedAttnProcessor(slice_size)
583
+ elif self.added_kv_proj_dim is not None:
584
+ processor = AttnAddedKVProcessor()
585
+ else:
586
+ # set attention processor
587
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
588
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
589
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
590
+ processor = (
591
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
592
+ )
593
+
594
+ self.set_processor(processor)
595
+
596
+ def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None:
597
+ r"""
598
+ Set the attention processor to use.
599
+
600
+ Args:
601
+ processor (`AttnProcessor`):
602
+ The attention processor to use.
603
+ _remove_lora (`bool`, *optional*, defaults to `False`):
604
+ Set to `True` to remove LoRA layers from the model.
605
+ """
606
+ if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
607
+ deprecate(
608
+ "set_processor to offload LoRA",
609
+ "0.26.0",
610
+ "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
611
+ )
612
+ # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
613
+ # We need to remove all LoRA layers
614
+ # Don't forget to remove ALL `_remove_lora` from the codebase
615
+ for module in self.modules():
616
+ if hasattr(module, "set_lora_layer"):
617
+ module.set_lora_layer(None)
618
+
619
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
620
+ # pop `processor` from `self._modules`
621
+ if (
622
+ hasattr(self, "processor")
623
+ and isinstance(self.processor, torch.nn.Module)
624
+ and not isinstance(processor, torch.nn.Module)
625
+ ):
626
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
627
+ self._modules.pop("processor")
628
+
629
+ self.processor = processor
630
+
631
+ def get_processor(self, return_deprecated_lora: bool = False):
632
+ r"""
633
+ Get the attention processor in use.
634
+
635
+ Args:
636
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
637
+ Set to `True` to return the deprecated LoRA attention processor.
638
+
639
+ Returns:
640
+ "AttentionProcessor": The attention processor in use.
641
+ """
642
+ if not return_deprecated_lora:
643
+ return self.processor
644
+
645
+ # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
646
+ # serialization format for LoRA Attention Processors. It should be deleted once the integration
647
+ # with PEFT is completed.
648
+ is_lora_activated = {
649
+ name: module.lora_layer is not None
650
+ for name, module in self.named_modules()
651
+ if hasattr(module, "lora_layer")
652
+ }
653
+
654
+ # 1. if no layer has a LoRA activated we can return the processor as usual
655
+ if not any(is_lora_activated.values()):
656
+ return self.processor
657
+
658
+ # If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
659
+ is_lora_activated.pop("add_k_proj", None)
660
+ is_lora_activated.pop("add_v_proj", None)
661
+ # 2. else it is not posssible that only some layers have LoRA activated
662
+ if not all(is_lora_activated.values()):
663
+ raise ValueError(
664
+ f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
665
+ )
666
+
667
+ # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
668
+ non_lora_processor_cls_name = self.processor.__class__.__name__
669
+ lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
670
+
671
+ hidden_size = self.inner_dim
672
+
673
+ # now create a LoRA attention processor from the LoRA layers
674
+ if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
675
+ kwargs = {
676
+ "cross_attention_dim": self.cross_attention_dim,
677
+ "rank": self.to_q.lora_layer.rank,
678
+ "network_alpha": self.to_q.lora_layer.network_alpha,
679
+ "q_rank": self.to_q.lora_layer.rank,
680
+ "q_hidden_size": self.to_q.lora_layer.out_features,
681
+ "k_rank": self.to_k.lora_layer.rank,
682
+ "k_hidden_size": self.to_k.lora_layer.out_features,
683
+ "v_rank": self.to_v.lora_layer.rank,
684
+ "v_hidden_size": self.to_v.lora_layer.out_features,
685
+ "out_rank": self.to_out[0].lora_layer.rank,
686
+ "out_hidden_size": self.to_out[0].lora_layer.out_features,
687
+ }
688
+
689
+ if hasattr(self.processor, "attention_op"):
690
+ kwargs["attention_op"] = self.processor.attention_op
691
+
692
+ lora_processor = lora_processor_cls(hidden_size, **kwargs)
693
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
694
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
695
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
696
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
697
+ elif lora_processor_cls == LoRAAttnAddedKVProcessor:
698
+ lora_processor = lora_processor_cls(
699
+ hidden_size,
700
+ cross_attention_dim=self.add_k_proj.weight.shape[0],
701
+ rank=self.to_q.lora_layer.rank,
702
+ network_alpha=self.to_q.lora_layer.network_alpha,
703
+ )
704
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
705
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
706
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
707
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
708
+
709
+ # only save if used
710
+ if self.add_k_proj.lora_layer is not None:
711
+ lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
712
+ lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
713
+ else:
714
+ lora_processor.add_k_proj_lora = None
715
+ lora_processor.add_v_proj_lora = None
716
+ else:
717
+ raise ValueError(f"{lora_processor_cls} does not exist.")
718
+
719
+ return lora_processor
720
+
721
+ def forward(
722
+ self,
723
+ hidden_states: torch.FloatTensor,
724
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
725
+ attention_mask: Optional[torch.FloatTensor] = None,
726
+ **cross_attention_kwargs,
727
+ ) -> torch.Tensor:
728
+ r"""
729
+ The forward method of the `Attention` class.
730
+
731
+ Args:
732
+ hidden_states (`torch.Tensor`):
733
+ The hidden states of the query.
734
+ encoder_hidden_states (`torch.Tensor`, *optional*):
735
+ The hidden states of the encoder.
736
+ attention_mask (`torch.Tensor`, *optional*):
737
+ The attention mask to use. If `None`, no mask is applied.
738
+ **cross_attention_kwargs:
739
+ Additional keyword arguments to pass along to the cross attention.
740
+
741
+ Returns:
742
+ `torch.Tensor`: The output of the attention layer.
743
+ """
744
+ # The `Attention` class can call different attention processors / attention functions
745
+ # here we simply pass along all tensors to the selected processor class
746
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
747
+ return self.processor(
748
+ self,
749
+ hidden_states,
750
+ encoder_hidden_states=encoder_hidden_states,
751
+ attention_mask=attention_mask,
752
+ **cross_attention_kwargs,
753
+ )
754
+
755
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
756
+ r"""
757
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
758
+ is the number of heads initialized while constructing the `Attention` class.
759
+
760
+ Args:
761
+ tensor (`torch.Tensor`): The tensor to reshape.
762
+
763
+ Returns:
764
+ `torch.Tensor`: The reshaped tensor.
765
+ """
766
+ head_size = self.heads
767
+ batch_size, seq_len, dim = tensor.shape
768
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
769
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
770
+ return tensor
771
+
772
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
773
+ r"""
774
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
775
+ the number of heads initialized while constructing the `Attention` class.
776
+
777
+ Args:
778
+ tensor (`torch.Tensor`): The tensor to reshape.
779
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
780
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
781
+
782
+ Returns:
783
+ `torch.Tensor`: The reshaped tensor.
784
+ """
785
+ head_size = self.heads
786
+ batch_size, seq_len, dim = tensor.shape
787
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
788
+ tensor = tensor.permute(0, 2, 1, 3)
789
+
790
+ if out_dim == 3:
791
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
792
+
793
+ return tensor
794
+
795
+ def get_attention_scores(
796
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
797
+ ) -> torch.Tensor:
798
+ r"""
799
+ Compute the attention scores.
800
+
801
+ Args:
802
+ query (`torch.Tensor`): The query tensor.
803
+ key (`torch.Tensor`): The key tensor.
804
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
805
+
806
+ Returns:
807
+ `torch.Tensor`: The attention probabilities/scores.
808
+ """
809
+ dtype = query.dtype
810
+ if self.upcast_attention:
811
+ query = query.float()
812
+ key = key.float()
813
+
814
+ if attention_mask is None:
815
+ baddbmm_input = torch.empty(
816
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
817
+ )
818
+ beta = 0
819
+ else:
820
+ baddbmm_input = attention_mask
821
+ beta = 1
822
+
823
+ attention_scores = torch.baddbmm(
824
+ baddbmm_input,
825
+ query,
826
+ key.transpose(-1, -2),
827
+ beta=beta,
828
+ alpha=self.scale,
829
+ )
830
+ del baddbmm_input
831
+
832
+ if self.upcast_softmax:
833
+ attention_scores = attention_scores.float()
834
+
835
+ attention_probs = attention_scores.softmax(dim=-1)
836
+ del attention_scores
837
+
838
+ attention_probs = attention_probs.to(dtype)
839
+
840
+ return attention_probs
841
+
842
+ def prepare_attention_mask(
843
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3, head_size = None,
844
+ ) -> torch.Tensor:
845
+ r"""
846
+ Prepare the attention mask for the attention computation.
847
+
848
+ Args:
849
+ attention_mask (`torch.Tensor`):
850
+ The attention mask to prepare.
851
+ target_length (`int`):
852
+ The target length of the attention mask. This is the length of the attention mask after padding.
853
+ batch_size (`int`):
854
+ The batch size, which is used to repeat the attention mask.
855
+ out_dim (`int`, *optional*, defaults to `3`):
856
+ The output dimension of the attention mask. Can be either `3` or `4`.
857
+
858
+ Returns:
859
+ `torch.Tensor`: The prepared attention mask.
860
+ """
861
+ head_size = head_size if head_size is not None else self.heads
862
+ if attention_mask is None:
863
+ return attention_mask
864
+
865
+ current_length: int = attention_mask.shape[-1]
866
+ if current_length != target_length:
867
+ if attention_mask.device.type == "mps":
868
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
869
+ # Instead, we can manually construct the padding tensor.
870
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
871
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
872
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
873
+ else:
874
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
875
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
876
+ # remaining_length: int = target_length - current_length
877
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
878
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
879
+
880
+ if out_dim == 3:
881
+ if attention_mask.shape[0] < batch_size * head_size:
882
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
883
+ elif out_dim == 4:
884
+ attention_mask = attention_mask.unsqueeze(1)
885
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
886
+
887
+ return attention_mask
888
+
889
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
890
+ r"""
891
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
892
+ `Attention` class.
893
+
894
+ Args:
895
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
896
+
897
+ Returns:
898
+ `torch.Tensor`: The normalized encoder hidden states.
899
+ """
900
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
901
+
902
+ if isinstance(self.norm_cross, nn.LayerNorm):
903
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
904
+ elif isinstance(self.norm_cross, nn.GroupNorm):
905
+ # Group norm norms along the channels dimension and expects
906
+ # input to be in the shape of (N, C, *). In this case, we want
907
+ # to norm along the hidden dimension, so we need to move
908
+ # (batch_size, sequence_length, hidden_size) ->
909
+ # (batch_size, hidden_size, sequence_length)
910
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
911
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
912
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
913
+ else:
914
+ assert False
915
+
916
+ return encoder_hidden_states
917
+
918
+ def _init_compress(self):
919
+ self.sr.bias.data.zero_()
920
+ self.norm = nn.LayerNorm(self.inner_dim)
921
+
922
+
923
+ class AttnProcessor2_0(nn.Module):
924
+ r"""
925
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
926
+ """
927
+
928
+ def __init__(self, attention_mode="xformers", use_rope=False, interpolation_scale_thw=None):
929
+ super().__init__()
930
+ self.attention_mode = attention_mode
931
+ self.use_rope = use_rope
932
+ self.interpolation_scale_thw = interpolation_scale_thw
933
+
934
+ if self.use_rope:
935
+ self._init_rope(interpolation_scale_thw)
936
+
937
+ if not hasattr(F, "scaled_dot_product_attention"):
938
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
939
+
940
+ def _init_rope(self, interpolation_scale_thw):
941
+ self.rope = RoPE3D(interpolation_scale_thw=interpolation_scale_thw)
942
+ self.position_getter = PositionGetter3D()
943
+
944
+ def __call__(
945
+ self,
946
+ attn: Attention,
947
+ hidden_states: torch.FloatTensor,
948
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
949
+ attention_mask: Optional[torch.FloatTensor] = None,
950
+ temb: Optional[torch.FloatTensor] = None,
951
+ frame: int = 8,
952
+ height: int = 16,
953
+ width: int = 16,
954
+ ) -> torch.FloatTensor:
955
+
956
+ residual = hidden_states
957
+
958
+ if attn.spatial_norm is not None:
959
+ hidden_states = attn.spatial_norm(hidden_states, temb)
960
+
961
+ input_ndim = hidden_states.ndim
962
+
963
+ if input_ndim == 4:
964
+ batch_size, channel, height, width = hidden_states.shape
965
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
966
+
967
+
968
+ batch_size, sequence_length, _ = (
969
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
970
+ )
971
+
972
+ if attention_mask is not None and self.attention_mode == 'xformers':
973
+ attention_heads = attn.heads
974
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, head_size=attention_heads)
975
+ attention_mask = attention_mask.view(batch_size, attention_heads, -1, attention_mask.shape[-1])
976
+ else:
977
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
978
+ # scaled_dot_product_attention expects attention_mask shape to be
979
+ # (batch, heads, source_length, target_length)
980
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
981
+
982
+ if attn.group_norm is not None:
983
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
984
+
985
+ query = attn.to_q(hidden_states)
986
+
987
+ if encoder_hidden_states is None:
988
+ encoder_hidden_states = hidden_states
989
+ elif attn.norm_cross:
990
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
991
+
992
+ key = attn.to_k(encoder_hidden_states)
993
+ value = attn.to_v(encoder_hidden_states)
994
+
995
+
996
+
997
+ attn_heads = attn.heads
998
+
999
+ inner_dim = key.shape[-1]
1000
+ head_dim = inner_dim // attn_heads
1001
+
1002
+ query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
1003
+ key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
1004
+ value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
1005
+
1006
+
1007
+ if self.use_rope:
1008
+ # require the shape of (batch_size x nheads x ntokens x dim)
1009
+ pos_thw = self.position_getter(batch_size, t=frame, h=height, w=width, device=query.device)
1010
+ query = self.rope(query, pos_thw)
1011
+ key = self.rope(key, pos_thw)
1012
+
1013
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1014
+ # TODO: add support for attn.scale when we move to Torch 2.1
1015
+ if self.attention_mode == 'flash':
1016
+ # assert attention_mask is None, 'flash-attn do not support attention_mask'
1017
+ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
1018
+ hidden_states = F.scaled_dot_product_attention(
1019
+ query, key, value, dropout_p=0.0, is_causal=False
1020
+ )
1021
+ elif self.attention_mode == 'xformers':
1022
+ with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
1023
+ hidden_states = F.scaled_dot_product_attention(
1024
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1025
+ )
1026
+
1027
+
1028
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
1029
+ hidden_states = hidden_states.to(query.dtype)
1030
+
1031
+ # linear proj
1032
+ hidden_states = attn.to_out[0](hidden_states)
1033
+ # dropout
1034
+ hidden_states = attn.to_out[1](hidden_states)
1035
+
1036
+ if input_ndim == 4:
1037
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1038
+
1039
+ if attn.residual_connection:
1040
+ hidden_states = hidden_states + residual
1041
+
1042
+ hidden_states = hidden_states / attn.rescale_output_factor
1043
+
1044
+ return hidden_states
1045
+
1046
+ class FeedForward(nn.Module):
1047
+ r"""
1048
+ A feed-forward layer.
1049
+
1050
+ Parameters:
1051
+ dim (`int`): The number of channels in the input.
1052
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
1053
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
1054
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1055
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1056
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
1057
+ """
1058
+
1059
+ def __init__(
1060
+ self,
1061
+ dim: int,
1062
+ dim_out: Optional[int] = None,
1063
+ mult: int = 4,
1064
+ dropout: float = 0.0,
1065
+ activation_fn: str = "geglu",
1066
+ final_dropout: bool = False,
1067
+ ):
1068
+ super().__init__()
1069
+ inner_dim = int(dim * mult)
1070
+ dim_out = dim_out if dim_out is not None else dim
1071
+ linear_cls = nn.Linear
1072
+
1073
+ if activation_fn == "gelu":
1074
+ act_fn = GELU(dim, inner_dim)
1075
+ if activation_fn == "gelu-approximate":
1076
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
1077
+ elif activation_fn == "geglu":
1078
+ act_fn = GEGLU(dim, inner_dim)
1079
+ elif activation_fn == "geglu-approximate":
1080
+ act_fn = ApproximateGELU(dim, inner_dim)
1081
+
1082
+ self.net = nn.ModuleList([])
1083
+ # project in
1084
+ self.net.append(act_fn)
1085
+ # project dropout
1086
+ self.net.append(nn.Dropout(dropout))
1087
+ # project out
1088
+ self.net.append(linear_cls(inner_dim, dim_out))
1089
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
1090
+ if final_dropout:
1091
+ self.net.append(nn.Dropout(dropout))
1092
+
1093
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1094
+ for module in self.net:
1095
+ hidden_states = module(hidden_states)
1096
+ return hidden_states
1097
+
1098
+
1099
+ @maybe_allow_in_graph
1100
+ class BasicTransformerBlock(nn.Module):
1101
+ r"""
1102
+ A basic Transformer block.
1103
+
1104
+ Parameters:
1105
+ dim (`int`): The number of channels in the input and output.
1106
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
1107
+ attention_head_dim (`int`): The number of channels in each head.
1108
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1109
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
1110
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1111
+ num_embeds_ada_norm (:
1112
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
1113
+ attention_bias (:
1114
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
1115
+ only_cross_attention (`bool`, *optional*):
1116
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
1117
+ double_self_attention (`bool`, *optional*):
1118
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
1119
+ upcast_attention (`bool`, *optional*):
1120
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
1121
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
1122
+ Whether to use learnable elementwise affine parameters for normalization.
1123
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
1124
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
1125
+ final_dropout (`bool` *optional*, defaults to False):
1126
+ Whether to apply a final dropout after the last feed-forward layer.
1127
+ positional_embeddings (`str`, *optional*, defaults to `None`):
1128
+ The type of positional embeddings to apply to.
1129
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
1130
+ The maximum number of positional embeddings to apply.
1131
+ """
1132
+
1133
+ def __init__(
1134
+ self,
1135
+ dim: int,
1136
+ num_attention_heads: int,
1137
+ attention_head_dim: int,
1138
+ dropout=0.0,
1139
+ cross_attention_dim: Optional[int] = None,
1140
+ activation_fn: str = "geglu",
1141
+ num_embeds_ada_norm: Optional[int] = None,
1142
+ attention_bias: bool = False,
1143
+ only_cross_attention: bool = False,
1144
+ double_self_attention: bool = False,
1145
+ upcast_attention: bool = False,
1146
+ norm_elementwise_affine: bool = True,
1147
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
1148
+ norm_eps: float = 1e-5,
1149
+ final_dropout: bool = False,
1150
+ positional_embeddings: Optional[str] = None,
1151
+ num_positional_embeddings: Optional[int] = None,
1152
+ sa_attention_mode: str = "flash",
1153
+ ca_attention_mode: str = "xformers",
1154
+ use_rope: bool = False,
1155
+ interpolation_scale_thw: Tuple[int] = (1, 1, 1),
1156
+ block_idx: Optional[int] = None,
1157
+ ):
1158
+ super().__init__()
1159
+ self.only_cross_attention = only_cross_attention
1160
+
1161
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
1162
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
1163
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
1164
+ self.use_layer_norm = norm_type == "layer_norm"
1165
+
1166
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
1167
+ raise ValueError(
1168
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
1169
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
1170
+ )
1171
+
1172
+ if positional_embeddings and (num_positional_embeddings is None):
1173
+ raise ValueError(
1174
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
1175
+ )
1176
+
1177
+ if positional_embeddings == "sinusoidal":
1178
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
1179
+ else:
1180
+ self.pos_embed = None
1181
+
1182
+ # Define 3 blocks. Each block has its own normalization layer.
1183
+ # 1. Self-Attn
1184
+ if self.use_ada_layer_norm:
1185
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
1186
+ elif self.use_ada_layer_norm_zero:
1187
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
1188
+ else:
1189
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1190
+
1191
+ self.attn1 = Attention(
1192
+ query_dim=dim,
1193
+ heads=num_attention_heads,
1194
+ dim_head=attention_head_dim,
1195
+ dropout=dropout,
1196
+ bias=attention_bias,
1197
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
1198
+ upcast_attention=upcast_attention,
1199
+ attention_mode=sa_attention_mode,
1200
+ use_rope=use_rope,
1201
+ interpolation_scale_thw=interpolation_scale_thw,
1202
+ )
1203
+
1204
+ # 2. Cross-Attn
1205
+ if cross_attention_dim is not None or double_self_attention:
1206
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
1207
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
1208
+ # the second cross attention block.
1209
+ self.norm2 = (
1210
+ AdaLayerNorm(dim, num_embeds_ada_norm)
1211
+ if self.use_ada_layer_norm
1212
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1213
+ )
1214
+ self.attn2 = Attention(
1215
+ query_dim=dim,
1216
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
1217
+ heads=num_attention_heads,
1218
+ dim_head=attention_head_dim,
1219
+ dropout=dropout,
1220
+ bias=attention_bias,
1221
+ upcast_attention=upcast_attention,
1222
+ attention_mode=ca_attention_mode, # only xformers support attention_mask
1223
+ use_rope=False, # do not position in cross attention
1224
+ interpolation_scale_thw=interpolation_scale_thw,
1225
+ ) # is self-attn if encoder_hidden_states is none
1226
+ else:
1227
+ self.norm2 = None
1228
+ self.attn2 = None
1229
+
1230
+ # 3. Feed-forward
1231
+
1232
+ if not self.use_ada_layer_norm_single:
1233
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1234
+
1235
+ self.ff = FeedForward(
1236
+ dim,
1237
+ dropout=dropout,
1238
+ activation_fn=activation_fn,
1239
+ final_dropout=final_dropout,
1240
+ )
1241
+
1242
+ # 5. Scale-shift for PixArt-Alpha.
1243
+ if self.use_ada_layer_norm_single:
1244
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
1245
+
1246
+
1247
+ def forward(
1248
+ self,
1249
+ hidden_states: torch.FloatTensor,
1250
+ attention_mask: Optional[torch.FloatTensor] = None,
1251
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1252
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1253
+ timestep: Optional[torch.LongTensor] = None,
1254
+ cross_attention_kwargs: Dict[str, Any] = None,
1255
+ class_labels: Optional[torch.LongTensor] = None,
1256
+ frame: int = None,
1257
+ height: int = None,
1258
+ width: int = None,
1259
+ ) -> torch.FloatTensor:
1260
+ # Notice that normalization is always applied before the real computation in the following blocks.
1261
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
1262
+
1263
+ # 0. Self-Attention
1264
+ batch_size = hidden_states.shape[0]
1265
+
1266
+ if self.use_ada_layer_norm:
1267
+ norm_hidden_states = self.norm1(hidden_states, timestep)
1268
+ elif self.use_ada_layer_norm_zero:
1269
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
1270
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
1271
+ )
1272
+ elif self.use_layer_norm:
1273
+ norm_hidden_states = self.norm1(hidden_states)
1274
+ elif self.use_ada_layer_norm_single:
1275
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
1276
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
1277
+ ).chunk(6, dim=1)
1278
+ norm_hidden_states = self.norm1(hidden_states)
1279
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
1280
+ norm_hidden_states = norm_hidden_states.squeeze(1)
1281
+ else:
1282
+ raise ValueError("Incorrect norm used")
1283
+
1284
+ if self.pos_embed is not None:
1285
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1286
+
1287
+ attn_output = self.attn1(
1288
+ norm_hidden_states,
1289
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
1290
+ attention_mask=attention_mask,
1291
+ frame=frame,
1292
+ height=height,
1293
+ width=width,
1294
+ **cross_attention_kwargs,
1295
+ )
1296
+ if self.use_ada_layer_norm_zero:
1297
+ attn_output = gate_msa.unsqueeze(1) * attn_output
1298
+ elif self.use_ada_layer_norm_single:
1299
+ attn_output = gate_msa * attn_output
1300
+
1301
+ hidden_states = attn_output + hidden_states
1302
+ if hidden_states.ndim == 4:
1303
+ hidden_states = hidden_states.squeeze(1)
1304
+
1305
+ # 1. Cross-Attention
1306
+ if self.attn2 is not None:
1307
+
1308
+ if self.use_ada_layer_norm:
1309
+ norm_hidden_states = self.norm2(hidden_states, timestep)
1310
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
1311
+ norm_hidden_states = self.norm2(hidden_states)
1312
+ elif self.use_ada_layer_norm_single:
1313
+ # For PixArt norm2 isn't applied here:
1314
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
1315
+ norm_hidden_states = hidden_states
1316
+ else:
1317
+ raise ValueError("Incorrect norm")
1318
+
1319
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
1320
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1321
+
1322
+ attn_output = self.attn2(
1323
+ norm_hidden_states,
1324
+ encoder_hidden_states=encoder_hidden_states,
1325
+ attention_mask=encoder_attention_mask,
1326
+ **cross_attention_kwargs,
1327
+ )
1328
+ hidden_states = attn_output + hidden_states
1329
+
1330
+
1331
+ # 2. Feed-forward
1332
+ if not self.use_ada_layer_norm_single:
1333
+ norm_hidden_states = self.norm3(hidden_states)
1334
+
1335
+ if self.use_ada_layer_norm_zero:
1336
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
1337
+
1338
+ if self.use_ada_layer_norm_single:
1339
+ norm_hidden_states = self.norm2(hidden_states)
1340
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
1341
+
1342
+ ff_output = self.ff(norm_hidden_states)
1343
+
1344
+ if self.use_ada_layer_norm_zero:
1345
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
1346
+ elif self.use_ada_layer_norm_single:
1347
+ ff_output = gate_mlp * ff_output
1348
+
1349
+
1350
+ hidden_states = ff_output + hidden_states
1351
+ if hidden_states.ndim == 4:
1352
+ hidden_states = hidden_states.squeeze(1)
1353
+
1354
+ return hidden_states
1355
+
1356
+
1357
+ class AdaLayerNormSingle(nn.Module):
1358
+ r"""
1359
+ Norm layer adaptive layer norm single (adaLN-single).
1360
+
1361
+ As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
1362
+
1363
+ Parameters:
1364
+ embedding_dim (`int`): The size of each embedding vector.
1365
+ use_additional_conditions (`bool`): To use additional conditions for normalization or not.
1366
+ """
1367
+
1368
+ def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
1369
+ super().__init__()
1370
+
1371
+ self.emb = CombinedTimestepSizeEmbeddings(
1372
+ embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
1373
+ )
1374
+
1375
+ self.silu = nn.SiLU()
1376
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
1377
+
1378
+ def forward(
1379
+ self,
1380
+ timestep: torch.Tensor,
1381
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
1382
+ batch_size: int = None,
1383
+ hidden_dtype: Optional[torch.dtype] = None,
1384
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1385
+ # No modulation happening here.
1386
+ embedded_timestep = self.emb(
1387
+ timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None
1388
+ )
1389
+ return self.linear(self.silu(embedded_timestep)), embedded_timestep
1390
+
1391
+
1392
+ @dataclass
1393
+ class Transformer3DModelOutput(BaseOutput):
1394
+ """
1395
+ The output of [`Transformer2DModel`].
1396
+
1397
+ Args:
1398
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
1399
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
1400
+ distributions for the unnoised latent pixels.
1401
+ """
1402
+
1403
+ sample: torch.FloatTensor
1404
+
1405
+
1406
+ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
1407
+ _supports_gradient_checkpointing = True
1408
+
1409
+ """
1410
+ A 2D Transformer model for image-like data.
1411
+
1412
+ Parameters:
1413
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
1414
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
1415
+ in_channels (`int`, *optional*):
1416
+ The number of channels in the input and output (specify if the input is **continuous**).
1417
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
1418
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1419
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
1420
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
1421
+ This is fixed during training since it is used to learn a number of position embeddings.
1422
+ num_vector_embeds (`int`, *optional*):
1423
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
1424
+ Includes the class for the masked latent pixel.
1425
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
1426
+ num_embeds_ada_norm ( `int`, *optional*):
1427
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
1428
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
1429
+ added to the hidden states.
1430
+
1431
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
1432
+ attention_bias (`bool`, *optional*):
1433
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
1434
+ """
1435
+
1436
+ @register_to_config
1437
+ def __init__(
1438
+ self,
1439
+ num_attention_heads: int = 16,
1440
+ attention_head_dim: int = 88,
1441
+ in_channels: Optional[int] = None,
1442
+ out_channels: Optional[int] = None,
1443
+ num_layers: int = 1,
1444
+ dropout: float = 0.0,
1445
+ cross_attention_dim: Optional[int] = None,
1446
+ attention_bias: bool = False,
1447
+ sample_size: Optional[int] = None,
1448
+ sample_size_t: Optional[int] = None,
1449
+ patch_size: Optional[int] = None,
1450
+ patch_size_t: Optional[int] = None,
1451
+ activation_fn: str = "geglu",
1452
+ num_embeds_ada_norm: Optional[int] = None,
1453
+ use_linear_projection: bool = False,
1454
+ only_cross_attention: bool = False,
1455
+ double_self_attention: bool = False,
1456
+ upcast_attention: bool = False,
1457
+ norm_type: str = "ada_norm",
1458
+ norm_elementwise_affine: bool = True,
1459
+ norm_eps: float = 1e-5,
1460
+ caption_channels: int = None,
1461
+ interpolation_scale_h: float = None,
1462
+ interpolation_scale_w: float = None,
1463
+ interpolation_scale_t: float = None,
1464
+ use_additional_conditions: Optional[bool] = None,
1465
+ sa_attention_mode: str = "flash",
1466
+ ca_attention_mode: str = 'xformers',
1467
+ downsampler: str = None,
1468
+ use_rope: bool = False,
1469
+ model_max_length: int = 300,
1470
+ ):
1471
+ super().__init__()
1472
+ self.use_linear_projection = use_linear_projection
1473
+ self.interpolation_scale_t = interpolation_scale_t
1474
+ self.interpolation_scale_h = interpolation_scale_h
1475
+ self.interpolation_scale_w = interpolation_scale_w
1476
+ self.downsampler = downsampler
1477
+ self.caption_channels = caption_channels
1478
+ self.num_attention_heads = num_attention_heads
1479
+ self.attention_head_dim = attention_head_dim
1480
+ inner_dim = num_attention_heads * attention_head_dim
1481
+ self.inner_dim = inner_dim
1482
+ self.in_channels = in_channels
1483
+ self.out_channels = in_channels if out_channels is None else out_channels
1484
+ self.use_rope = use_rope
1485
+ self.model_max_length = model_max_length
1486
+ self.num_layers = num_layers
1487
+ self.config.hidden_size = inner_dim
1488
+
1489
+
1490
+ # 1. Transformer3DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
1491
+ # Define whether input is continuous or discrete depending on configuration
1492
+ assert in_channels is not None and patch_size is not None
1493
+
1494
+ # 2. Initialize the right blocks.
1495
+ # Initialize the output blocks and other projection blocks when necessary.
1496
+
1497
+ assert self.config.sample_size_t is not None, "AllegroTransformer3DModel over patched input must provide sample_size_t"
1498
+ assert self.config.sample_size is not None, "AllegroTransformer3DModel over patched input must provide sample_size"
1499
+ #assert not (self.config.sample_size_t == 1 and self.config.patch_size_t == 2), "Image do not need patchfy in t-dim"
1500
+
1501
+ self.num_frames = self.config.sample_size_t
1502
+ self.config.sample_size = to_2tuple(self.config.sample_size)
1503
+ self.height = self.config.sample_size[0]
1504
+ self.width = self.config.sample_size[1]
1505
+ self.patch_size_t = self.config.patch_size_t
1506
+ self.patch_size = self.config.patch_size
1507
+ interpolation_scale_t = ((self.config.sample_size_t - 1) // 16 + 1) if self.config.sample_size_t % 2 == 1 else self.config.sample_size_t / 16
1508
+ interpolation_scale_t = (
1509
+ self.config.interpolation_scale_t if self.config.interpolation_scale_t is not None else interpolation_scale_t
1510
+ )
1511
+ interpolation_scale = (
1512
+ self.config.interpolation_scale_h if self.config.interpolation_scale_h is not None else self.config.sample_size[0] / 30,
1513
+ self.config.interpolation_scale_w if self.config.interpolation_scale_w is not None else self.config.sample_size[1] / 40,
1514
+ )
1515
+ self.pos_embed = PatchEmbed2D(
1516
+ num_frames=self.config.sample_size_t,
1517
+ height=self.config.sample_size[0],
1518
+ width=self.config.sample_size[1],
1519
+ patch_size_t=self.config.patch_size_t,
1520
+ patch_size=self.config.patch_size,
1521
+ in_channels=self.in_channels,
1522
+ embed_dim=self.inner_dim,
1523
+ interpolation_scale=interpolation_scale,
1524
+ interpolation_scale_t=interpolation_scale_t,
1525
+ use_abs_pos=not self.config.use_rope,
1526
+ )
1527
+ interpolation_scale_thw = (interpolation_scale_t, *interpolation_scale)
1528
+
1529
+ # 3. Define transformers blocks, spatial attention
1530
+ self.transformer_blocks = nn.ModuleList(
1531
+ [
1532
+ BasicTransformerBlock(
1533
+ inner_dim,
1534
+ num_attention_heads,
1535
+ attention_head_dim,
1536
+ dropout=dropout,
1537
+ cross_attention_dim=cross_attention_dim,
1538
+ activation_fn=activation_fn,
1539
+ num_embeds_ada_norm=num_embeds_ada_norm,
1540
+ attention_bias=attention_bias,
1541
+ only_cross_attention=only_cross_attention,
1542
+ double_self_attention=double_self_attention,
1543
+ upcast_attention=upcast_attention,
1544
+ norm_type=norm_type,
1545
+ norm_elementwise_affine=norm_elementwise_affine,
1546
+ norm_eps=norm_eps,
1547
+ sa_attention_mode=sa_attention_mode,
1548
+ ca_attention_mode=ca_attention_mode,
1549
+ use_rope=use_rope,
1550
+ interpolation_scale_thw=interpolation_scale_thw,
1551
+ block_idx=d,
1552
+ )
1553
+ for d in range(num_layers)
1554
+ ]
1555
+ )
1556
+
1557
+ # 4. Define output layers
1558
+
1559
+ if norm_type != "ada_norm_single":
1560
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
1561
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
1562
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
1563
+ elif norm_type == "ada_norm_single":
1564
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
1565
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
1566
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
1567
+
1568
+ # 5. PixArt-Alpha blocks.
1569
+ self.adaln_single = None
1570
+ self.use_additional_conditions = False
1571
+ if norm_type == "ada_norm_single":
1572
+ # self.use_additional_conditions = self.config.sample_size[0] == 128 # False, 128 -> 1024
1573
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
1574
+ # additional conditions until we find better name
1575
+ self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
1576
+
1577
+ self.caption_projection = None
1578
+ if caption_channels is not None:
1579
+ self.caption_projection = PixArtAlphaTextProjection(
1580
+ in_features=caption_channels, hidden_size=inner_dim
1581
+ )
1582
+
1583
+ self.gradient_checkpointing = False
1584
+
1585
+ def _set_gradient_checkpointing(self, module, value=False):
1586
+ self.gradient_checkpointing = value
1587
+
1588
+
1589
+ def forward(
1590
+ self,
1591
+ hidden_states: torch.Tensor,
1592
+ timestep: Optional[torch.LongTensor] = None,
1593
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1594
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
1595
+ class_labels: Optional[torch.LongTensor] = None,
1596
+ cross_attention_kwargs: Dict[str, Any] = None,
1597
+ attention_mask: Optional[torch.Tensor] = None,
1598
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1599
+ return_dict: bool = True,
1600
+ ):
1601
+ """
1602
+ The [`Transformer2DModel`] forward method.
1603
+
1604
+ Args:
1605
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous):
1606
+ Input `hidden_states`.
1607
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
1608
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
1609
+ self-attention.
1610
+ timestep ( `torch.LongTensor`, *optional*):
1611
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
1612
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
1613
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
1614
+ `AdaLayerZeroNorm`.
1615
+ added_cond_kwargs ( `Dict[str, Any]`, *optional*):
1616
+ A kwargs dictionary that if specified is passed along to the `AdaLayerNormSingle`
1617
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
1618
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1619
+ `self.processor` in
1620
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1621
+ attention_mask ( `torch.Tensor`, *optional*):
1622
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
1623
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
1624
+ negative values to the attention scores corresponding to "discard" tokens.
1625
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
1626
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
1627
+
1628
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
1629
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
1630
+
1631
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
1632
+ above. This bias will be added to the cross-attention scores.
1633
+ return_dict (`bool`, *optional*, defaults to `True`):
1634
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
1635
+ tuple.
1636
+
1637
+ Returns:
1638
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
1639
+ `tuple` where the first element is the sample tensor.
1640
+ """
1641
+ batch_size, c, frame, h, w = hidden_states.shape
1642
+
1643
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
1644
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
1645
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
1646
+ # expects mask of shape:
1647
+ # [batch, key_tokens]
1648
+ # adds singleton query_tokens dimension:
1649
+ # [batch, 1, key_tokens]
1650
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
1651
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
1652
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) attention_mask_vid, attention_mask_img = None, None
1653
+ if attention_mask is not None and attention_mask.ndim == 4:
1654
+ # assume that mask is expressed as:
1655
+ # (1 = keep, 0 = discard)
1656
+ # convert mask into a bias that can be added to attention scores:
1657
+ # (keep = +0, discard = -10000.0)
1658
+ # b, frame+use_image_num, h, w -> a video with images
1659
+ # b, 1, h, w -> only images
1660
+ attention_mask = attention_mask.to(self.dtype)
1661
+ attention_mask_vid = attention_mask[:, :frame] # b, frame, h, w
1662
+
1663
+ if attention_mask_vid.numel() > 0:
1664
+ attention_mask_vid = attention_mask_vid.unsqueeze(1) # b 1 t h w
1665
+ attention_mask_vid = F.max_pool3d(attention_mask_vid, kernel_size=(self.patch_size_t, self.patch_size, self.patch_size),
1666
+ stride=(self.patch_size_t, self.patch_size, self.patch_size))
1667
+ attention_mask_vid = rearrange(attention_mask_vid, 'b 1 t h w -> (b 1) 1 (t h w)')
1668
+
1669
+ attention_mask_vid = (1 - attention_mask_vid.bool().to(self.dtype)) * -10000.0 if attention_mask_vid.numel() > 0 else None
1670
+
1671
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
1672
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3:
1673
+ # b, 1+use_image_num, l -> a video with images
1674
+ # b, 1, l -> only images
1675
+ encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
1676
+ encoder_attention_mask_vid = rearrange(encoder_attention_mask, 'b 1 l -> (b 1) 1 l') if encoder_attention_mask.numel() > 0 else None
1677
+
1678
+ # 1. Input
1679
+ frame = frame // self.patch_size_t # patchfy
1680
+ # print('frame', frame)
1681
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
1682
+
1683
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None} if added_cond_kwargs is None else added_cond_kwargs
1684
+ hidden_states, encoder_hidden_states_vid, \
1685
+ timestep_vid, embedded_timestep_vid = self._operate_on_patched_inputs(
1686
+ hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size,
1687
+ )
1688
+
1689
+
1690
+ for _, block in enumerate(self.transformer_blocks):
1691
+ hidden_states = block(
1692
+ hidden_states,
1693
+ attention_mask_vid,
1694
+ encoder_hidden_states_vid,
1695
+ encoder_attention_mask_vid,
1696
+ timestep_vid,
1697
+ cross_attention_kwargs,
1698
+ class_labels,
1699
+ frame=frame,
1700
+ height=height,
1701
+ width=width,
1702
+ )
1703
+
1704
+ # 3. Output
1705
+ output = None
1706
+ if hidden_states is not None:
1707
+ output = self._get_output_for_patched_inputs(
1708
+ hidden_states=hidden_states,
1709
+ timestep=timestep_vid,
1710
+ class_labels=class_labels,
1711
+ embedded_timestep=embedded_timestep_vid,
1712
+ num_frames=frame,
1713
+ height=height,
1714
+ width=width,
1715
+ ) # b c t h w
1716
+
1717
+ if not return_dict:
1718
+ return (output,)
1719
+
1720
+ return Transformer3DModelOutput(sample=output)
1721
+
1722
+ def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size):
1723
+ # batch_size = hidden_states.shape[0]
1724
+ hidden_states_vid = self.pos_embed(hidden_states.to(self.dtype))
1725
+ timestep_vid = None
1726
+ embedded_timestep_vid = None
1727
+ encoder_hidden_states_vid = None
1728
+
1729
+ if self.adaln_single is not None:
1730
+ if self.use_additional_conditions and added_cond_kwargs is None:
1731
+ raise ValueError(
1732
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
1733
+ )
1734
+ timestep, embedded_timestep = self.adaln_single(
1735
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype
1736
+ ) # b 6d, b d
1737
+
1738
+ timestep_vid = timestep
1739
+ embedded_timestep_vid = embedded_timestep
1740
+
1741
+ if self.caption_projection is not None:
1742
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states) # b, 1+use_image_num, l, d or b, 1, l, d
1743
+ encoder_hidden_states_vid = rearrange(encoder_hidden_states[:, :1], 'b 1 l d -> (b 1) l d')
1744
+
1745
+ return hidden_states_vid, encoder_hidden_states_vid, timestep_vid, embedded_timestep_vid
1746
+
1747
+ def _get_output_for_patched_inputs(
1748
+ self, hidden_states, timestep, class_labels, embedded_timestep, num_frames, height=None, width=None
1749
+ ):
1750
+ # import ipdb;ipdb.set_trace()
1751
+ if self.config.norm_type != "ada_norm_single":
1752
+ conditioning = self.transformer_blocks[0].norm1.emb(
1753
+ timestep, class_labels, hidden_dtype=self.dtype
1754
+ )
1755
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
1756
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
1757
+ hidden_states = self.proj_out_2(hidden_states)
1758
+ elif self.config.norm_type == "ada_norm_single":
1759
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
1760
+ hidden_states = self.norm_out(hidden_states)
1761
+ # Modulation
1762
+ hidden_states = hidden_states * (1 + scale) + shift
1763
+ hidden_states = self.proj_out(hidden_states)
1764
+ hidden_states = hidden_states.squeeze(1)
1765
+
1766
+ # unpatchify
1767
+ if self.adaln_single is None:
1768
+ height = width = int(hidden_states.shape[1] ** 0.5)
1769
+ hidden_states = hidden_states.reshape(
1770
+ shape=(-1, num_frames, height, width, self.patch_size_t, self.patch_size, self.patch_size, self.out_channels)
1771
+ )
1772
+ hidden_states = torch.einsum("nthwopqc->nctohpwq", hidden_states)
1773
+ output = hidden_states.reshape(
1774
+ shape=(-1, self.out_channels, num_frames * self.patch_size_t, height * self.patch_size, width * self.patch_size)
1775
+ )
1776
+ return output
vae/config.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AllegroAutoencoderKL3D",
3
+ "_diffusers_version": "0.28.0",
4
+ "_name_or_path": "/cpfs/data/user/larrytsai/Projects/Yi-VG/allegro_pipeline/vae",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "blocks_tempdown_li": [
13
+ true,
14
+ true,
15
+ false,
16
+ false
17
+ ],
18
+ "blocks_tempup_li": [
19
+ false,
20
+ true,
21
+ true,
22
+ false
23
+ ],
24
+ "chunk_len": 24,
25
+ "down_block_num": 4,
26
+ "force_upcast": true,
27
+ "in_channels": 3,
28
+ "latent_channels": 4,
29
+ "layers_per_block": 2,
30
+ "load_mode": "full",
31
+ "norm_num_groups": 32,
32
+ "out_channels": 3,
33
+ "sample_size": 320,
34
+ "scale_factor": 0.13,
35
+ "t_over": 8,
36
+ "tile_overlap": [
37
+ 120,
38
+ 80
39
+ ],
40
+ "up_block_num": 4
41
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47871a698b18f92f15019d361a81cbc8af4676f8eef9a47fd2b95354a39f831a
3
+ size 699904972
vae/vae_allegro.py ADDED
@@ -0,0 +1,978 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ import os
4
+ from typing import Dict, Optional, Tuple, Union
5
+ from einops import rearrange
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
14
+ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
15
+ from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
16
+ from diffusers.models.attention_processor import Attention
17
+ from diffusers.models.resnet import ResnetBlock2D
18
+ from diffusers.models.upsampling import Upsample2D
19
+ from diffusers.models.downsampling import Downsample2D
20
+ from diffusers.models.attention_processor import SpatialNorm
21
+
22
+
23
+ class TemporalConvBlock(nn.Module):
24
+ """
25
+ Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
26
+ https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
27
+ """
28
+
29
+ def __init__(self, in_dim, out_dim=None, dropout=0.0, up_sample=False, down_sample=False, spa_stride=1):
30
+ super().__init__()
31
+ out_dim = out_dim or in_dim
32
+ self.in_dim = in_dim
33
+ self.out_dim = out_dim
34
+ spa_pad = int((spa_stride-1)*0.5)
35
+ temp_pad = 0
36
+ self.temp_pad = temp_pad
37
+
38
+ if down_sample:
39
+ self.conv1 = nn.Sequential(
40
+ nn.GroupNorm(32, in_dim),
41
+ nn.SiLU(),
42
+ nn.Conv3d(in_dim, out_dim, (2, spa_stride, spa_stride), stride=(2,1,1), padding=(0, spa_pad, spa_pad))
43
+ )
44
+ elif up_sample:
45
+ self.conv1 = nn.Sequential(
46
+ nn.GroupNorm(32, in_dim),
47
+ nn.SiLU(),
48
+ nn.Conv3d(in_dim, out_dim*2, (1, spa_stride, spa_stride), padding=(0, spa_pad, spa_pad))
49
+ )
50
+ else:
51
+ self.conv1 = nn.Sequential(
52
+ nn.GroupNorm(32, in_dim),
53
+ nn.SiLU(),
54
+ nn.Conv3d(in_dim, out_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad))
55
+ )
56
+ self.conv2 = nn.Sequential(
57
+ nn.GroupNorm(32, out_dim),
58
+ nn.SiLU(),
59
+ nn.Dropout(dropout),
60
+ nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)),
61
+ )
62
+ self.conv3 = nn.Sequential(
63
+ nn.GroupNorm(32, out_dim),
64
+ nn.SiLU(),
65
+ nn.Dropout(dropout),
66
+ nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)),
67
+ )
68
+ self.conv4 = nn.Sequential(
69
+ nn.GroupNorm(32, out_dim),
70
+ nn.SiLU(),
71
+ nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)),
72
+ )
73
+
74
+ # zero out the last layer params,so the conv block is identity
75
+ nn.init.zeros_(self.conv4[-1].weight)
76
+ nn.init.zeros_(self.conv4[-1].bias)
77
+
78
+ self.down_sample = down_sample
79
+ self.up_sample = up_sample
80
+
81
+
82
+ def forward(self, hidden_states):
83
+ identity = hidden_states
84
+
85
+ if self.down_sample:
86
+ identity = identity[:,:,::2]
87
+ elif self.up_sample:
88
+ hidden_states_new = torch.cat((hidden_states,hidden_states),dim=2)
89
+ hidden_states_new[:, :, 0::2] = hidden_states
90
+ hidden_states_new[:, :, 1::2] = hidden_states
91
+ identity = hidden_states_new
92
+ del hidden_states_new
93
+
94
+ if self.down_sample or self.up_sample:
95
+ hidden_states = self.conv1(hidden_states)
96
+ else:
97
+ hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2)
98
+ hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2)
99
+ hidden_states = self.conv1(hidden_states)
100
+
101
+
102
+ if self.up_sample:
103
+ hidden_states = rearrange(hidden_states, 'b (d c) f h w -> b c (f d) h w', d=2)
104
+
105
+ hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2)
106
+ hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2)
107
+ hidden_states = self.conv2(hidden_states)
108
+ hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2)
109
+ hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2)
110
+ hidden_states = self.conv3(hidden_states)
111
+ hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2)
112
+ hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2)
113
+ hidden_states = self.conv4(hidden_states)
114
+
115
+ hidden_states = identity + hidden_states
116
+
117
+ return hidden_states
118
+
119
+
120
+ class DownEncoderBlock3D(nn.Module):
121
+ def __init__(
122
+ self,
123
+ in_channels: int,
124
+ out_channels: int,
125
+ dropout: float = 0.0,
126
+ num_layers: int = 1,
127
+ resnet_eps: float = 1e-6,
128
+ resnet_time_scale_shift: str = "default",
129
+ resnet_act_fn: str = "swish",
130
+ resnet_groups: int = 32,
131
+ resnet_pre_norm: bool = True,
132
+ output_scale_factor=1.0,
133
+ add_downsample=True,
134
+ add_temp_downsample=False,
135
+ downsample_padding=1,
136
+ ):
137
+ super().__init__()
138
+ resnets = []
139
+ temp_convs = []
140
+
141
+ for i in range(num_layers):
142
+ in_channels = in_channels if i == 0 else out_channels
143
+ resnets.append(
144
+ ResnetBlock2D(
145
+ in_channels=in_channels,
146
+ out_channels=out_channels,
147
+ temb_channels=None,
148
+ eps=resnet_eps,
149
+ groups=resnet_groups,
150
+ dropout=dropout,
151
+ time_embedding_norm=resnet_time_scale_shift,
152
+ non_linearity=resnet_act_fn,
153
+ output_scale_factor=output_scale_factor,
154
+ pre_norm=resnet_pre_norm,
155
+ )
156
+ )
157
+ temp_convs.append(
158
+ TemporalConvBlock(
159
+ out_channels,
160
+ out_channels,
161
+ dropout=0.1,
162
+ )
163
+ )
164
+
165
+ self.resnets = nn.ModuleList(resnets)
166
+ self.temp_convs = nn.ModuleList(temp_convs)
167
+
168
+ if add_temp_downsample:
169
+ self.temp_convs_down = TemporalConvBlock(
170
+ out_channels,
171
+ out_channels,
172
+ dropout=0.1,
173
+ down_sample=True,
174
+ spa_stride=3
175
+ )
176
+ self.add_temp_downsample = add_temp_downsample
177
+
178
+ if add_downsample:
179
+ self.downsamplers = nn.ModuleList(
180
+ [
181
+ Downsample2D(
182
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
183
+ )
184
+ ]
185
+ )
186
+ else:
187
+ self.downsamplers = None
188
+
189
+ def _set_partial_grad(self):
190
+ for temp_conv in self.temp_convs:
191
+ temp_conv.requires_grad_(True)
192
+ if self.downsamplers:
193
+ for down_layer in self.downsamplers:
194
+ down_layer.requires_grad_(True)
195
+
196
+ def forward(self, hidden_states):
197
+ bz = hidden_states.shape[0]
198
+
199
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
200
+ hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
201
+ hidden_states = resnet(hidden_states, temb=None)
202
+ hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
203
+ hidden_states = temp_conv(hidden_states)
204
+ if self.add_temp_downsample:
205
+ hidden_states = self.temp_convs_down(hidden_states)
206
+
207
+ if self.downsamplers is not None:
208
+ hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
209
+ for upsampler in self.downsamplers:
210
+ hidden_states = upsampler(hidden_states)
211
+ hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
212
+ return hidden_states
213
+
214
+
215
+ class UpDecoderBlock3D(nn.Module):
216
+ def __init__(
217
+ self,
218
+ in_channels: int,
219
+ out_channels: int,
220
+ dropout: float = 0.0,
221
+ num_layers: int = 1,
222
+ resnet_eps: float = 1e-6,
223
+ resnet_time_scale_shift: str = "default", # default, spatial
224
+ resnet_act_fn: str = "swish",
225
+ resnet_groups: int = 32,
226
+ resnet_pre_norm: bool = True,
227
+ output_scale_factor=1.0,
228
+ add_upsample=True,
229
+ add_temp_upsample=False,
230
+ temb_channels=None,
231
+ ):
232
+ super().__init__()
233
+ self.add_upsample = add_upsample
234
+
235
+ resnets = []
236
+ temp_convs = []
237
+
238
+ for i in range(num_layers):
239
+ input_channels = in_channels if i == 0 else out_channels
240
+
241
+ resnets.append(
242
+ ResnetBlock2D(
243
+ in_channels=input_channels,
244
+ out_channels=out_channels,
245
+ temb_channels=temb_channels,
246
+ eps=resnet_eps,
247
+ groups=resnet_groups,
248
+ dropout=dropout,
249
+ time_embedding_norm=resnet_time_scale_shift,
250
+ non_linearity=resnet_act_fn,
251
+ output_scale_factor=output_scale_factor,
252
+ pre_norm=resnet_pre_norm,
253
+ )
254
+ )
255
+ temp_convs.append(
256
+ TemporalConvBlock(
257
+ out_channels,
258
+ out_channels,
259
+ dropout=0.1,
260
+ )
261
+ )
262
+
263
+ self.resnets = nn.ModuleList(resnets)
264
+ self.temp_convs = nn.ModuleList(temp_convs)
265
+
266
+ self.add_temp_upsample = add_temp_upsample
267
+ if add_temp_upsample:
268
+ self.temp_conv_up = TemporalConvBlock(
269
+ out_channels,
270
+ out_channels,
271
+ dropout=0.1,
272
+ up_sample=True,
273
+ spa_stride=3
274
+ )
275
+
276
+
277
+ if self.add_upsample:
278
+ # self.upsamplers = nn.ModuleList([PSUpsample2D(out_channels, use_conv=True, use_pixel_shuffle=True, out_channels=out_channels)])
279
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
280
+ else:
281
+ self.upsamplers = None
282
+
283
+ def _set_partial_grad(self):
284
+ for temp_conv in self.temp_convs:
285
+ temp_conv.requires_grad_(True)
286
+ if self.add_upsample:
287
+ self.upsamplers.requires_grad_(True)
288
+
289
+ def forward(self, hidden_states):
290
+ bz = hidden_states.shape[0]
291
+
292
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
293
+ hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
294
+ hidden_states = resnet(hidden_states, temb=None)
295
+ hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
296
+ hidden_states = temp_conv(hidden_states)
297
+ if self.add_temp_upsample:
298
+ hidden_states = self.temp_conv_up(hidden_states)
299
+
300
+ if self.upsamplers is not None:
301
+ hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
302
+ for upsampler in self.upsamplers:
303
+ hidden_states = upsampler(hidden_states)
304
+ hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
305
+ return hidden_states
306
+
307
+
308
+ class UNetMidBlock3DConv(nn.Module):
309
+ def __init__(
310
+ self,
311
+ in_channels: int,
312
+ temb_channels: int,
313
+ dropout: float = 0.0,
314
+ num_layers: int = 1,
315
+ resnet_eps: float = 1e-6,
316
+ resnet_time_scale_shift: str = "default", # default, spatial
317
+ resnet_act_fn: str = "swish",
318
+ resnet_groups: int = 32,
319
+ resnet_pre_norm: bool = True,
320
+ add_attention: bool = True,
321
+ attention_head_dim=1,
322
+ output_scale_factor=1.0,
323
+ ):
324
+ super().__init__()
325
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
326
+ self.add_attention = add_attention
327
+
328
+ # there is always at least one resnet
329
+ resnets = [
330
+ ResnetBlock2D(
331
+ in_channels=in_channels,
332
+ out_channels=in_channels,
333
+ temb_channels=temb_channels,
334
+ eps=resnet_eps,
335
+ groups=resnet_groups,
336
+ dropout=dropout,
337
+ time_embedding_norm=resnet_time_scale_shift,
338
+ non_linearity=resnet_act_fn,
339
+ output_scale_factor=output_scale_factor,
340
+ pre_norm=resnet_pre_norm,
341
+ )
342
+ ]
343
+ temp_convs = [
344
+ TemporalConvBlock(
345
+ in_channels,
346
+ in_channels,
347
+ dropout=0.1,
348
+ )
349
+ ]
350
+ attentions = []
351
+
352
+ if attention_head_dim is None:
353
+ attention_head_dim = in_channels
354
+
355
+ for _ in range(num_layers):
356
+ if self.add_attention:
357
+ attentions.append(
358
+ Attention(
359
+ in_channels,
360
+ heads=in_channels // attention_head_dim,
361
+ dim_head=attention_head_dim,
362
+ rescale_output_factor=output_scale_factor,
363
+ eps=resnet_eps,
364
+ norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
365
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
366
+ residual_connection=True,
367
+ bias=True,
368
+ upcast_softmax=True,
369
+ _from_deprecated_attn_block=True,
370
+ )
371
+ )
372
+ else:
373
+ attentions.append(None)
374
+
375
+ resnets.append(
376
+ ResnetBlock2D(
377
+ in_channels=in_channels,
378
+ out_channels=in_channels,
379
+ temb_channels=temb_channels,
380
+ eps=resnet_eps,
381
+ groups=resnet_groups,
382
+ dropout=dropout,
383
+ time_embedding_norm=resnet_time_scale_shift,
384
+ non_linearity=resnet_act_fn,
385
+ output_scale_factor=output_scale_factor,
386
+ pre_norm=resnet_pre_norm,
387
+ )
388
+ )
389
+
390
+ temp_convs.append(
391
+ TemporalConvBlock(
392
+ in_channels,
393
+ in_channels,
394
+ dropout=0.1,
395
+ )
396
+ )
397
+
398
+ self.resnets = nn.ModuleList(resnets)
399
+ self.temp_convs = nn.ModuleList(temp_convs)
400
+ self.attentions = nn.ModuleList(attentions)
401
+
402
+ def _set_partial_grad(self):
403
+ for temp_conv in self.temp_convs:
404
+ temp_conv.requires_grad_(True)
405
+
406
+ def forward(
407
+ self,
408
+ hidden_states,
409
+ ):
410
+ bz = hidden_states.shape[0]
411
+ hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
412
+
413
+ hidden_states = self.resnets[0](hidden_states, temb=None)
414
+ hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
415
+ hidden_states = self.temp_convs[0](hidden_states)
416
+ hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
417
+
418
+ for attn, resnet, temp_conv in zip(
419
+ self.attentions, self.resnets[1:], self.temp_convs[1:]
420
+ ):
421
+ hidden_states = attn(hidden_states)
422
+ hidden_states = resnet(hidden_states, temb=None)
423
+ hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
424
+ hidden_states = temp_conv(hidden_states)
425
+ return hidden_states
426
+
427
+
428
+ class Encoder3D(nn.Module):
429
+ def __init__(
430
+ self,
431
+ in_channels=3,
432
+ out_channels=3,
433
+ num_blocks=4,
434
+ blocks_temp_li=[False, False, False, False],
435
+ block_out_channels=(64,),
436
+ layers_per_block=2,
437
+ norm_num_groups=32,
438
+ act_fn="silu",
439
+ double_z=True,
440
+ ):
441
+ super().__init__()
442
+ self.layers_per_block = layers_per_block
443
+ self.blocks_temp_li = blocks_temp_li
444
+
445
+ self.conv_in = nn.Conv2d(
446
+ in_channels,
447
+ block_out_channels[0],
448
+ kernel_size=3,
449
+ stride=1,
450
+ padding=1,
451
+ )
452
+
453
+ self.temp_conv_in = nn.Conv3d(
454
+ block_out_channels[0],
455
+ block_out_channels[0],
456
+ (3,1,1),
457
+ padding = (1, 0, 0)
458
+ )
459
+
460
+ self.mid_block = None
461
+ self.down_blocks = nn.ModuleList([])
462
+
463
+ # down
464
+ output_channel = block_out_channels[0]
465
+ for i in range(num_blocks):
466
+ input_channel = output_channel
467
+ output_channel = block_out_channels[i]
468
+ is_final_block = i == len(block_out_channels) - 1
469
+
470
+ down_block = DownEncoderBlock3D(
471
+ num_layers=self.layers_per_block,
472
+ in_channels=input_channel,
473
+ out_channels=output_channel,
474
+ add_downsample=not is_final_block,
475
+ add_temp_downsample=blocks_temp_li[i],
476
+ resnet_eps=1e-6,
477
+ downsample_padding=0,
478
+ resnet_act_fn=act_fn,
479
+ resnet_groups=norm_num_groups,
480
+ )
481
+ self.down_blocks.append(down_block)
482
+
483
+ # mid
484
+ self.mid_block = UNetMidBlock3DConv(
485
+ in_channels=block_out_channels[-1],
486
+ resnet_eps=1e-6,
487
+ resnet_act_fn=act_fn,
488
+ output_scale_factor=1,
489
+ resnet_time_scale_shift="default",
490
+ attention_head_dim=block_out_channels[-1],
491
+ resnet_groups=norm_num_groups,
492
+ temb_channels=None,
493
+ )
494
+
495
+ # out
496
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
497
+ self.conv_act = nn.SiLU()
498
+
499
+ conv_out_channels = 2 * out_channels if double_z else out_channels
500
+
501
+ self.temp_conv_out = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3,1,1), padding = (1, 0, 0))
502
+
503
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
504
+
505
+ nn.init.zeros_(self.temp_conv_in.weight)
506
+ nn.init.zeros_(self.temp_conv_in.bias)
507
+ nn.init.zeros_(self.temp_conv_out.weight)
508
+ nn.init.zeros_(self.temp_conv_out.bias)
509
+
510
+ self.gradient_checkpointing = False
511
+
512
+ def forward(self, x):
513
+ '''
514
+ x: [b, c, (tb f), h, w]
515
+ '''
516
+ bz = x.shape[0]
517
+ sample = rearrange(x, 'b c n h w -> (b n) c h w')
518
+ sample = self.conv_in(sample)
519
+
520
+ sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz)
521
+ temp_sample = sample
522
+ sample = self.temp_conv_in(sample)
523
+ sample = sample+temp_sample
524
+ # down
525
+ for b_id, down_block in enumerate(self.down_blocks):
526
+ sample = down_block(sample)
527
+ # middle
528
+ sample = self.mid_block(sample)
529
+
530
+ # post-process
531
+ sample = rearrange(sample, 'b c n h w -> (b n) c h w')
532
+ sample = self.conv_norm_out(sample)
533
+ sample = self.conv_act(sample)
534
+ sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz)
535
+
536
+ temp_sample = sample
537
+ sample = self.temp_conv_out(sample)
538
+ sample = sample+temp_sample
539
+ sample = rearrange(sample, 'b c n h w -> (b n) c h w')
540
+
541
+ sample = self.conv_out(sample)
542
+ sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz)
543
+ return sample
544
+
545
+ class Decoder3D(nn.Module):
546
+ def __init__(
547
+ self,
548
+ in_channels=4,
549
+ out_channels=3,
550
+ num_blocks=4,
551
+ blocks_temp_li=[False, False, False, False],
552
+ block_out_channels=(64,),
553
+ layers_per_block=2,
554
+ norm_num_groups=32,
555
+ act_fn="silu",
556
+ norm_type="group", # group, spatial
557
+ ):
558
+ super().__init__()
559
+ self.layers_per_block = layers_per_block
560
+ self.blocks_temp_li = blocks_temp_li
561
+
562
+ self.conv_in = nn.Conv2d(
563
+ in_channels,
564
+ block_out_channels[-1],
565
+ kernel_size=3,
566
+ stride=1,
567
+ padding=1,
568
+ )
569
+
570
+ self.temp_conv_in = nn.Conv3d(
571
+ block_out_channels[-1],
572
+ block_out_channels[-1],
573
+ (3,1,1),
574
+ padding = (1, 0, 0)
575
+ )
576
+
577
+ self.mid_block = None
578
+ self.up_blocks = nn.ModuleList([])
579
+
580
+ temb_channels = in_channels if norm_type == "spatial" else None
581
+
582
+ # mid
583
+ self.mid_block = UNetMidBlock3DConv(
584
+ in_channels=block_out_channels[-1],
585
+ resnet_eps=1e-6,
586
+ resnet_act_fn=act_fn,
587
+ output_scale_factor=1,
588
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
589
+ attention_head_dim=block_out_channels[-1],
590
+ resnet_groups=norm_num_groups,
591
+ temb_channels=temb_channels,
592
+ )
593
+
594
+ # up
595
+ reversed_block_out_channels = list(reversed(block_out_channels))
596
+ output_channel = reversed_block_out_channels[0]
597
+ for i in range(num_blocks):
598
+ prev_output_channel = output_channel
599
+ output_channel = reversed_block_out_channels[i]
600
+
601
+ is_final_block = i == len(block_out_channels) - 1
602
+
603
+ up_block = UpDecoderBlock3D(
604
+ num_layers=self.layers_per_block + 1,
605
+ in_channels=prev_output_channel,
606
+ out_channels=output_channel,
607
+ add_upsample=not is_final_block,
608
+ add_temp_upsample=blocks_temp_li[i],
609
+ resnet_eps=1e-6,
610
+ resnet_act_fn=act_fn,
611
+ resnet_groups=norm_num_groups,
612
+ temb_channels=temb_channels,
613
+ resnet_time_scale_shift=norm_type,
614
+ )
615
+ self.up_blocks.append(up_block)
616
+ prev_output_channel = output_channel
617
+
618
+ # out
619
+ if norm_type == "spatial":
620
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
621
+ else:
622
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
623
+ self.conv_act = nn.SiLU()
624
+
625
+ self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3,1,1), padding = (1, 0, 0))
626
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
627
+
628
+ nn.init.zeros_(self.temp_conv_in.weight)
629
+ nn.init.zeros_(self.temp_conv_in.bias)
630
+ nn.init.zeros_(self.temp_conv_out.weight)
631
+ nn.init.zeros_(self.temp_conv_out.bias)
632
+
633
+ self.gradient_checkpointing = False
634
+
635
+ def forward(self, z):
636
+ bz = z.shape[0]
637
+ sample = rearrange(z, 'b c n h w -> (b n) c h w')
638
+ sample = self.conv_in(sample)
639
+
640
+ sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz)
641
+ temp_sample = sample
642
+ sample = self.temp_conv_in(sample)
643
+ sample = sample+temp_sample
644
+
645
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
646
+ # middle
647
+ sample = self.mid_block(sample)
648
+ sample = sample.to(upscale_dtype)
649
+
650
+ # up
651
+ for b_id, up_block in enumerate(self.up_blocks):
652
+ sample = up_block(sample)
653
+
654
+ # post-process
655
+ sample = rearrange(sample, 'b c n h w -> (b n) c h w')
656
+ sample = self.conv_norm_out(sample)
657
+ sample = self.conv_act(sample)
658
+
659
+ sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz)
660
+ temp_sample = sample
661
+ sample = self.temp_conv_out(sample)
662
+ sample = sample+temp_sample
663
+ sample = rearrange(sample, 'b c n h w -> (b n) c h w')
664
+
665
+ sample = self.conv_out(sample)
666
+ sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz)
667
+ return sample
668
+
669
+
670
+
671
+ class AllegroAutoencoderKL3D(ModelMixin, ConfigMixin):
672
+ r"""
673
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
674
+
675
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
676
+ for all models (such as downloading or saving).
677
+
678
+ Parameters:
679
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
680
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
681
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
682
+ Tuple of downsample block types.
683
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
684
+ Tuple of upsample block types.
685
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
686
+ Tuple of block output channels.
687
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
688
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
689
+ sample_size (`int`, *optional*, defaults to `256`): Spatial Tiling Size.
690
+ tile_overlap (`tuple`, *optional*, defaults to `(120, 80`): Spatial overlapping size while tiling (height, width)
691
+ chunk_len (`int`, *optional*, defaults to `24`): Temporal Tiling Size.
692
+ t_over (`int`, *optional*, defaults to `8`): Temporal overlapping size while tiling
693
+ scaling_factor (`float`, *optional*, defaults to 0.13235):
694
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
695
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
696
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
697
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
698
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
699
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
700
+ force_upcast (`bool`, *optional*, default to `True`):
701
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
702
+ can be fine-tuned / trained to a lower range without loosing too much precision in which case
703
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
704
+ blocks_tempdown_li (`List`, *optional*, defaults to `[True, True, False, False]`): Each item indicates whether each TemporalBlock in the Encoder performs temporal downsampling.
705
+ blocks_tempup_li (`List`, *optional*, defaults to `[False, True, True, False]`): Each item indicates whether each TemporalBlock in the Decoder performs temporal upsampling.
706
+ load_mode (`str`, *optional*, defaults to `full`): Load mode for the model. Can be one of `full`, `encoder_only`, `decoder_only`. which corresponds to loading the full model state dicts, only the encoder state dicts, or only the decoder state dicts.
707
+ """
708
+
709
+ _supports_gradient_checkpointing = True
710
+
711
+ @register_to_config
712
+ def __init__(
713
+ self,
714
+ in_channels: int = 3,
715
+ out_channels: int = 3,
716
+ down_block_num: int = 4,
717
+ up_block_num: int = 4,
718
+ block_out_channels: Tuple[int] = (128,256,512,512),
719
+ layers_per_block: int = 2,
720
+ act_fn: str = "silu",
721
+ latent_channels: int = 4,
722
+ norm_num_groups: int = 32,
723
+ sample_size: int = 320,
724
+ tile_overlap: tuple = (120, 80),
725
+ force_upcast: bool = True,
726
+ chunk_len: int = 24,
727
+ t_over: int = 8,
728
+ scale_factor: float = 0.13235,
729
+ blocks_tempdown_li=[True, True, False, False],
730
+ blocks_tempup_li=[False, True, True, False],
731
+ load_mode = 'full',
732
+ ):
733
+ super().__init__()
734
+
735
+ self.blocks_tempdown_li = blocks_tempdown_li
736
+ self.blocks_tempup_li = blocks_tempup_li
737
+ # pass init params to Encoder
738
+ self.load_mode = load_mode
739
+ if load_mode in ['full', 'encoder_only']:
740
+ self.encoder = Encoder3D(
741
+ in_channels=in_channels,
742
+ out_channels=latent_channels,
743
+ num_blocks=down_block_num,
744
+ blocks_temp_li=blocks_tempdown_li,
745
+ block_out_channels=block_out_channels,
746
+ layers_per_block=layers_per_block,
747
+ act_fn=act_fn,
748
+ norm_num_groups=norm_num_groups,
749
+ double_z=True,
750
+ )
751
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
752
+
753
+ if load_mode in ['full', 'decoder_only']:
754
+ # pass init params to Decoder
755
+ self.decoder = Decoder3D(
756
+ in_channels=latent_channels,
757
+ out_channels=out_channels,
758
+ num_blocks=up_block_num,
759
+ blocks_temp_li=blocks_tempup_li,
760
+ block_out_channels=block_out_channels,
761
+ layers_per_block=layers_per_block,
762
+ norm_num_groups=norm_num_groups,
763
+ act_fn=act_fn,
764
+ )
765
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
766
+
767
+
768
+ # only relevant if vae tiling is enabled
769
+ sample_size = (
770
+ sample_size[0]
771
+ if isinstance(sample_size, (list, tuple))
772
+ else sample_size
773
+ )
774
+ self.tile_overlap = tile_overlap
775
+ self.vae_scale_factor=[4, 8, 8]
776
+ self.scale_factor = scale_factor
777
+ self.sample_size = sample_size
778
+ self.chunk_len = chunk_len
779
+ self.t_over = t_over
780
+
781
+ self.latent_chunk_len = self.chunk_len//4
782
+ self.latent_t_over = self.t_over//4
783
+ self.kernel = (self.chunk_len, self.sample_size, self.sample_size) #(24, 256, 256)
784
+ self.stride = (self.chunk_len - self.t_over, self.sample_size-self.tile_overlap[0], self.sample_size-self.tile_overlap[1]) # (16, 112, 192)
785
+
786
+
787
+ def encode(self, input_imgs: torch.Tensor, return_dict: bool = True, local_batch_size=1) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
788
+ KERNEL = self.kernel
789
+ STRIDE = self.stride
790
+ LOCAL_BS = local_batch_size
791
+ OUT_C = 8
792
+
793
+ B, C, N, H, W = input_imgs.shape
794
+
795
+
796
+ out_n = math.floor((N - KERNEL[0]) / STRIDE[0]) + 1
797
+ out_h = math.floor((H - KERNEL[1]) / STRIDE[1]) + 1
798
+ out_w = math.floor((W - KERNEL[2]) / STRIDE[2]) + 1
799
+
800
+ ## cut video into overlapped small cubes and batch forward
801
+ num = 0
802
+
803
+ out_latent = torch.zeros((out_n*out_h*out_w, OUT_C, KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8), device=input_imgs.device, dtype=input_imgs.dtype)
804
+ vae_batch_input = torch.zeros((LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_imgs.device, dtype=input_imgs.dtype)
805
+
806
+ for i in range(out_n):
807
+ for j in range(out_h):
808
+ for k in range(out_w):
809
+ n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0]
810
+ h_start, h_end = j * STRIDE[1], j * STRIDE[1] + KERNEL[1]
811
+ w_start, w_end = k * STRIDE[2], k * STRIDE[2] + KERNEL[2]
812
+ video_cube = input_imgs[:, :, n_start:n_end, h_start:h_end, w_start:w_end]
813
+ vae_batch_input[num%LOCAL_BS] = video_cube
814
+
815
+ if num%LOCAL_BS == LOCAL_BS-1 or num == out_n*out_h*out_w-1:
816
+ latent = self.encoder(vae_batch_input)
817
+
818
+ if num == out_n*out_h*out_w-1 and num%LOCAL_BS != LOCAL_BS-1:
819
+ out_latent[num-num%LOCAL_BS:] = latent[:num%LOCAL_BS+1]
820
+ else:
821
+ out_latent[num-LOCAL_BS+1:num+1] = latent
822
+ vae_batch_input = torch.zeros((LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_imgs.device, dtype=input_imgs.dtype)
823
+ num+=1
824
+
825
+ ## flatten the batched out latent to videos and supress the overlapped parts
826
+ B, C, N, H, W = input_imgs.shape
827
+
828
+ out_video_cube = torch.zeros((B, OUT_C, N//4, H//8, W//8), device=input_imgs.device, dtype=input_imgs.dtype)
829
+ OUT_KERNEL = KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8
830
+ OUT_STRIDE = STRIDE[0]//4, STRIDE[1]//8, STRIDE[2]//8
831
+ OVERLAP = OUT_KERNEL[0]-OUT_STRIDE[0], OUT_KERNEL[1]-OUT_STRIDE[1], OUT_KERNEL[2]-OUT_STRIDE[2]
832
+
833
+ for i in range(out_n):
834
+ n_start, n_end = i * OUT_STRIDE[0], i * OUT_STRIDE[0] + OUT_KERNEL[0]
835
+ for j in range(out_h):
836
+ h_start, h_end = j * OUT_STRIDE[1], j * OUT_STRIDE[1] + OUT_KERNEL[1]
837
+ for k in range(out_w):
838
+ w_start, w_end = k * OUT_STRIDE[2], k * OUT_STRIDE[2] + OUT_KERNEL[2]
839
+ latent_mean_blend = prepare_for_blend((i, out_n, OVERLAP[0]), (j, out_h, OVERLAP[1]), (k, out_w, OVERLAP[2]), out_latent[i*out_h*out_w+j*out_w+k].unsqueeze(0))
840
+ out_video_cube[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean_blend
841
+
842
+ ## final conv
843
+ out_video_cube = rearrange(out_video_cube, 'b c n h w -> (b n) c h w')
844
+ out_video_cube = self.quant_conv(out_video_cube)
845
+ out_video_cube = rearrange(out_video_cube, '(b n) c h w -> b c n h w', b=B)
846
+
847
+ posterior = DiagonalGaussianDistribution(out_video_cube)
848
+
849
+ if not return_dict:
850
+ return (posterior,)
851
+
852
+ return AutoencoderKLOutput(latent_dist=posterior)
853
+
854
+
855
+ def decode(self, input_latents: torch.Tensor, return_dict: bool = True, local_batch_size=1) -> Union[DecoderOutput, torch.Tensor]:
856
+ KERNEL = self.kernel
857
+ STRIDE = self.stride
858
+
859
+ LOCAL_BS = local_batch_size
860
+ OUT_C = 3
861
+ IN_KERNEL = KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8
862
+ IN_STRIDE = STRIDE[0]//4, STRIDE[1]//8, STRIDE[2]//8
863
+
864
+ B, C, N, H, W = input_latents.shape
865
+
866
+ ## post quant conv (a mapping)
867
+ input_latents = rearrange(input_latents, 'b c n h w -> (b n) c h w')
868
+ input_latents = self.post_quant_conv(input_latents)
869
+ input_latents = rearrange(input_latents, '(b n) c h w -> b c n h w', b=B)
870
+
871
+ ## out tensor shape
872
+ out_n = math.floor((N - IN_KERNEL[0]) / IN_STRIDE[0]) + 1
873
+ out_h = math.floor((H - IN_KERNEL[1]) / IN_STRIDE[1]) + 1
874
+ out_w = math.floor((W - IN_KERNEL[2]) / IN_STRIDE[2]) + 1
875
+
876
+ ## cut latent into overlapped small cubes and batch forward
877
+ num = 0
878
+ decoded_cube = torch.zeros((out_n*out_h*out_w, OUT_C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype)
879
+ vae_batch_input = torch.zeros((LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype)
880
+ for i in range(out_n):
881
+ for j in range(out_h):
882
+ for k in range(out_w):
883
+ n_start, n_end = i * IN_STRIDE[0], i * IN_STRIDE[0] + IN_KERNEL[0]
884
+ h_start, h_end = j * IN_STRIDE[1], j * IN_STRIDE[1] + IN_KERNEL[1]
885
+ w_start, w_end = k * IN_STRIDE[2], k * IN_STRIDE[2] + IN_KERNEL[2]
886
+ latent_cube = input_latents[:, :, n_start:n_end, h_start:h_end, w_start:w_end]
887
+ vae_batch_input[num%LOCAL_BS] = latent_cube
888
+ if num%LOCAL_BS == LOCAL_BS-1 or num == out_n*out_h*out_w-1:
889
+
890
+ latent = self.decoder(vae_batch_input)
891
+
892
+ if num == out_n*out_h*out_w-1 and num%LOCAL_BS != LOCAL_BS-1:
893
+ decoded_cube[num-num%LOCAL_BS:] = latent[:num%LOCAL_BS+1]
894
+ else:
895
+ decoded_cube[num-LOCAL_BS+1:num+1] = latent
896
+ vae_batch_input = torch.zeros((LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype)
897
+ num+=1
898
+ B, C, N, H, W = input_latents.shape
899
+
900
+ out_video = torch.zeros((B, OUT_C, N*4, H*8, W*8), device=input_latents.device, dtype=input_latents.dtype)
901
+ OVERLAP = KERNEL[0]-STRIDE[0], KERNEL[1]-STRIDE[1], KERNEL[2]-STRIDE[2]
902
+ for i in range(out_n):
903
+ n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0]
904
+ for j in range(out_h):
905
+ h_start, h_end = j * STRIDE[1], j * STRIDE[1] + KERNEL[1]
906
+ for k in range(out_w):
907
+ w_start, w_end = k * STRIDE[2], k * STRIDE[2] + KERNEL[2]
908
+ out_video_blend = prepare_for_blend((i, out_n, OVERLAP[0]), (j, out_h, OVERLAP[1]), (k, out_w, OVERLAP[2]), decoded_cube[i*out_h*out_w+j*out_w+k].unsqueeze(0))
909
+ out_video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend
910
+
911
+ out_video = rearrange(out_video, 'b c t h w -> b t c h w').contiguous()
912
+
913
+ decoded = out_video
914
+ if not return_dict:
915
+ return (decoded,)
916
+
917
+ return DecoderOutput(sample=decoded)
918
+
919
+ def forward(
920
+ self,
921
+ sample: torch.Tensor,
922
+ sample_posterior: bool = False,
923
+ return_dict: bool = True,
924
+ generator: Optional[torch.Generator] = None,
925
+ encoder_local_batch_size: int = 2,
926
+ decoder_local_batch_size: int = 2,
927
+ ) -> Union[DecoderOutput, torch.Tensor]:
928
+ r"""
929
+ Args:
930
+ sample (`torch.Tensor`): Input sample.
931
+ sample_posterior (`bool`, *optional*, defaults to `False`):
932
+ Whether to sample from the posterior.
933
+ return_dict (`bool`, *optional*, defaults to `True`):
934
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
935
+ generator (`torch.Generator`, *optional*):
936
+ PyTorch random number generator.
937
+ encoder_local_batch_size (`int`, *optional*, defaults to 2):
938
+ Local batch size for the encoder's batch inference.
939
+ decoder_local_batch_size (`int`, *optional*, defaults to 2):
940
+ Local batch size for the decoder's batch inference.
941
+ """
942
+ x = sample
943
+ posterior = self.encode(x, local_batch_size=encoder_local_batch_size).latent_dist
944
+ if sample_posterior:
945
+ z = posterior.sample(generator=generator)
946
+ else:
947
+ z = posterior.mode()
948
+ dec = self.decode(z, local_batch_size=decoder_local_batch_size).sample
949
+
950
+ if not return_dict:
951
+ return (dec,)
952
+
953
+ return DecoderOutput(sample=dec)
954
+
955
+ @classmethod
956
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
957
+ kwargs["torch_type"] = torch.float32
958
+ return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
959
+
960
+
961
+ def prepare_for_blend(n_param, h_param, w_param, x):
962
+ n, n_max, overlap_n = n_param
963
+ h, h_max, overlap_h = h_param
964
+ w, w_max, overlap_w = w_param
965
+ if overlap_n > 0:
966
+ if n > 0: # the head overlap part decays from 0 to 1
967
+ x[:,:,0:overlap_n,:,:] = x[:,:,0:overlap_n,:,:] * (torch.arange(0, overlap_n).float().to(x.device) / overlap_n).reshape(overlap_n,1,1)
968
+ if n < n_max-1: # the tail overlap part decays from 1 to 0
969
+ x[:,:,-overlap_n:,:,:] = x[:,:,-overlap_n:,:,:] * (1 - torch.arange(0, overlap_n).float().to(x.device) / overlap_n).reshape(overlap_n,1,1)
970
+ if h > 0:
971
+ x[:,:,:,0:overlap_h,:] = x[:,:,:,0:overlap_h,:] * (torch.arange(0, overlap_h).float().to(x.device) / overlap_h).reshape(overlap_h,1)
972
+ if h < h_max-1:
973
+ x[:,:,:,-overlap_h:,:] = x[:,:,:,-overlap_h:,:] * (1 - torch.arange(0, overlap_h).float().to(x.device) / overlap_h).reshape(overlap_h,1)
974
+ if w > 0:
975
+ x[:,:,:,:,0:overlap_w] = x[:,:,:,:,0:overlap_w] * (torch.arange(0, overlap_w).float().to(x.device) / overlap_w)
976
+ if w < w_max-1:
977
+ x[:,:,:,:,-overlap_w:] = x[:,:,:,:,-overlap_w:] * (1 - torch.arange(0, overlap_w).float().to(x.device) / overlap_w)
978
+ return x