import os

import requests
from diffusers import DDPMScheduler
from tqdm import tqdm


def make_1step_sched(device="cuda"):
    """Build a DDPM noise scheduler configured for a single inference step.

    The scheduler config is loaded with ``DDPMScheduler.from_pretrained`` from
    the ``scheduler`` subfolder of ``stabilityai/sd-turbo`` (this may require
    network access or a local Hugging Face cache — confirm for your setup).

    Args:
        device: Device for the scheduler's timesteps and ``alphas_cumprod``.
            Defaults to ``"cuda"``, matching the previously hard-coded value.

    Returns:
        The configured ``DDPMScheduler`` with exactly one timestep set.
    """
    noise_scheduler_1step = DDPMScheduler.from_pretrained("stabilityai/sd-turbo", subfolder="scheduler")
    noise_scheduler_1step.set_timesteps(1, device=device)
    # set_timesteps' ``device`` argument targets the timesteps; alphas_cumprod is
    # moved explicitly so it sits on the same device (presumably because it is
    # indexed with on-device timesteps later — confirm at the call site).
    noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.to(device)
    return noise_scheduler_1step


def my_vae_encoder_fwd(self, sample):
    """Encoder forward pass that also records the input of every down block.

    Meant to be bound as the ``forward`` of a VAE encoder module (presumably
    a diffusers ``Encoder`` — confirm at the patch site). ``self`` must provide
    ``conv_in``, ``down_blocks``, ``mid_block``, ``conv_norm_out``,
    ``conv_act`` and ``conv_out``.

    Side effect:
        Sets ``self.current_down_blocks`` to a list whose i-th entry is the
        activation that was fed into ``self.down_blocks[i]`` (presumably read
        back later as skip connections by the decoder — confirm).

    Args:
        sample: Input passed to ``self.conv_in``.

    Returns:
        The output of ``self.conv_out``.
    """
    sample = self.conv_in(sample)
    l_blocks = []
    # down: remember each block's input before running the block
    for down_block in self.down_blocks:
        l_blocks.append(sample)
        sample = down_block(sample)
    # middle
    sample = self.mid_block(sample)
    # post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    self.current_down_blocks = l_blocks
    return sample


def my_vae_decoder_fwd(self, sample, latent_embeds=None):
    """Decoder forward pass with optional additive skip connections.

    Meant to be bound as the ``forward`` of a VAE decoder module (presumably a
    diffusers ``Decoder`` — confirm at the patch site).

    When ``self.ignore_skip`` is false, before up block ``idx`` this adds
    ``skip_conv_{idx+1}(reversed(self.incoming_skip_acts)[idx] * self.gamma)``
    to the running sample. This requires ``self.incoming_skip_acts``,
    ``self.gamma`` and ``self.skip_conv_1`` .. ``self.skip_conv_4`` to be set
    by the caller, with at least as many of each as there are up blocks.

    Args:
        sample: Input passed to ``self.conv_in``.
        latent_embeds: Optional conditioning forwarded to the mid block, every
            up block and, when not ``None``, to ``conv_norm_out``.

    Returns:
        The output of ``self.conv_out``.
    """
    sample = self.conv_in(sample)
    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
    # middle
    sample = self.mid_block(sample, latent_embeds)
    # match the dtype of the up blocks' parameters before upsampling
    sample = sample.to(upscale_dtype)
    if not self.ignore_skip:
        skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
        # The skip activations are consumed in reverse order; reverse once
        # instead of re-slicing the sequence on every iteration.
        skip_acts = self.incoming_skip_acts[::-1]
        # up
        for idx, up_block in enumerate(self.up_blocks):
            skip_in = skip_convs[idx](skip_acts[idx] * self.gamma)
            # add skip
            sample = sample + skip_in
            sample = up_block(sample, latent_embeds)
    else:
        for up_block in self.up_blocks:
            sample = up_block(sample, latent_embeds)
    # post-process
    if latent_embeds is None:
        sample = self.conv_norm_out(sample)
    else:
        sample = self.conv_norm_out(sample, latent_embeds)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    return sample


def download_url(url, outf, timeout=60):
    """Download ``url`` to ``outf`` unless ``outf`` already exists.

    The body is streamed into ``outf + ".part"`` and only renamed to ``outf``
    once the download completed and (when the server sent a Content-Length)
    the byte count matches. A failed or truncated download therefore never
    leaves a file at ``outf``. Otherwise a later call would see it, skip the
    download, and keep the corrupt checkpoint forever.

    Args:
        url: URL to fetch.
        outf: Destination path.
        timeout: Seconds passed to ``requests.get`` as the connect/read
            timeout so a stalled connection cannot hang forever.

    Raises:
        requests.HTTPError: If the server answers with an HTTP error status
            (previously the error page was saved as the checkpoint).
        requests.RequestException / OSError: Network or filesystem failures;
            the partial ``.part`` file is removed before re-raising.
    """
    if os.path.exists(outf):
        print(f"Skipping download, {outf} already exists")
        return

    print(f"Downloading checkpoint to {outf}")
    tmp_path = f"{outf}.part"
    block_size = 1024  # 1 Kibibyte
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            received = 0
            with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar, \
                    open(tmp_path, 'wb') as file:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    file.write(data)
                    received += len(data)
        if total_size_in_bytes != 0 and received != total_size_in_bytes:
            print("ERROR, something went wrong")
            return  # the finally clause discards the truncated temp file
        # Same directory as outf, so the rename is atomic where the OS supports it.
        os.replace(tmp_path, outf)
    finally:
        # On success the temp file was already renamed; on any failure (size
        # mismatch, network error, KeyboardInterrupt) drop the partial data.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"Downloaded successfully to {outf}")