import os import random import argparse import json import torch from PIL import Image from torchvision import transforms import torchvision.transforms.functional as F from glob import glob def parse_args_paired_training(input_args=None): """ Parses command-line arguments used for configuring an paired session (pix2pix-Turbo). This function sets up an argument parser to handle various training options. Returns: argparse.Namespace: The parsed command-line arguments. """ parser = argparse.ArgumentParser() # args for the loss function parser.add_argument("--gan_disc_type", default="vagan_clip") parser.add_argument("--gan_loss_type", default="multilevel_sigmoid_s") parser.add_argument("--lambda_gan", default=0.5, type=float) parser.add_argument("--lambda_lpips", default=5, type=float) parser.add_argument("--lambda_l2", default=1.0, type=float) parser.add_argument("--lambda_clipsim", default=5.0, type=float) # dataset options parser.add_argument("--dataset_folder", required=True, type=str) parser.add_argument("--train_image_prep", default="resized_crop_512", type=str) parser.add_argument("--test_image_prep", default="resized_crop_512", type=str) # validation eval args parser.add_argument("--eval_freq", default=100, type=int) parser.add_argument("--track_val_fid", default=False, action="store_true") parser.add_argument("--num_samples_eval", type=int, default=100, help="Number of samples to use for all evaluation") parser.add_argument("--viz_freq", type=int, default=100, help="Frequency of visualizing the outputs.") parser.add_argument("--tracker_project_name", type=str, default="train_pix2pix_turbo", help="The name of the wandb project to log to.") # details about the model architecture parser.add_argument("--pretrained_model_name_or_path") parser.add_argument("--revision", type=str, default=None,) parser.add_argument("--variant", type=str, default=None,) parser.add_argument("--tokenizer_name", type=str, default=None) parser.add_argument("--lora_rank_unet", default=8, type=int) parser.add_argument("--lora_rank_vae", default=4, type=int) # training details parser.add_argument("--output_dir", required=True) parser.add_argument("--cache_dir", default=None,) parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument("--resolution", type=int, default=512,) parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.") parser.add_argument("--num_training_epochs", type=int, default=10) parser.add_argument("--max_train_steps", type=int, default=10_000,) parser.add_argument("--checkpointing_steps", type=int, default=500,) parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.",) parser.add_argument("--gradient_checkpointing", action="store_true",) parser.add_argument("--learning_rate", type=float, default=5e-6) parser.add_argument("--lr_scheduler", type=str, default="constant", help=( 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' ' "constant", "constant_with_warmup"]' ), ) parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.") parser.add_argument("--lr_num_cycles", type=int, default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") parser.add_argument("--dataloader_num_workers", type=int, default=0,) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--allow_tf32", action="store_true", help=( "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) parser.add_argument("--report_to", type=str, default="wandb", help=( 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],) parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.") parser.add_argument("--set_grads_to_none", action="store_true",) if input_args is not None: args = parser.parse_args(input_args) else: args = parser.parse_args() return args def parse_args_unpaired_training(): """ Parses command-line arguments used for configuring an unpaired session (CycleGAN-Turbo). This function sets up an argument parser to handle various training options. Returns: argparse.Namespace: The parsed command-line arguments. """ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") # fixed random seed parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") # args for the loss function parser.add_argument("--gan_disc_type", default="vagan_clip") parser.add_argument("--gan_loss_type", default="multilevel_sigmoid") parser.add_argument("--lambda_gan", default=0.5, type=float) parser.add_argument("--lambda_idt", default=1, type=float) parser.add_argument("--lambda_cycle", default=1, type=float) parser.add_argument("--lambda_cycle_lpips", default=10.0, type=float) parser.add_argument("--lambda_idt_lpips", default=1.0, type=float) # args for dataset and dataloader options parser.add_argument("--dataset_folder", required=True, type=str) parser.add_argument("--train_img_prep", required=True) parser.add_argument("--val_img_prep", required=True) parser.add_argument("--dataloader_num_workers", type=int, default=0) parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.") parser.add_argument("--max_train_epochs", type=int, default=100) parser.add_argument("--max_train_steps", type=int, default=None) # args for the model parser.add_argument("--pretrained_model_name_or_path", default="stabilityai/sd-turbo") parser.add_argument("--revision", default=None, type=str) parser.add_argument("--variant", default=None, type=str) parser.add_argument("--lora_rank_unet", default=128, type=int) parser.add_argument("--lora_rank_vae", default=4, type=int) # args for validation and logging parser.add_argument("--viz_freq", type=int, default=20) parser.add_argument("--output_dir", type=str, required=True) parser.add_argument("--report_to", type=str, default="wandb") parser.add_argument("--tracker_project_name", type=str, required=True) parser.add_argument("--validation_steps", type=int, default=500,) parser.add_argument("--validation_num_images", type=int, default=-1, help="Number of images to use for validation. -1 to use all images.") parser.add_argument("--checkpointing_steps", type=int, default=500) # args for the optimization options parser.add_argument("--learning_rate", type=float, default=5e-6,) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") parser.add_argument("--max_grad_norm", default=10.0, type=float, help="Max gradient norm.") parser.add_argument("--lr_scheduler", type=str, default="constant", help=( 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' ' "constant", "constant_with_warmup"]' ), ) parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.") parser.add_argument("--lr_num_cycles", type=int, default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.",) parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1) # memory saving options parser.add_argument("--allow_tf32", action="store_true", help=( "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) parser.add_argument("--gradient_checkpointing", action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.") parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.") args = parser.parse_args() return args def build_transform(image_prep): """ Constructs a transformation pipeline based on the specified image preparation method. Parameters: - image_prep (str): A string describing the desired image preparation Returns: - torchvision.transforms.Compose: A composable sequence of transformations to be applied to images. """ if image_prep == "resized_crop_512": T = transforms.Compose([ transforms.Resize(512, interpolation=transforms.InterpolationMode.LANCZOS), transforms.CenterCrop(512), ]) elif image_prep == "resize_286_randomcrop_256x256_hflip": T = transforms.Compose([ transforms.Resize((286, 286), interpolation=Image.LANCZOS), transforms.RandomCrop((256, 256)), transforms.RandomHorizontalFlip(), ]) elif image_prep in ["resize_256", "resize_256x256"]: T = transforms.Compose([ transforms.Resize((256, 256), interpolation=Image.LANCZOS) ]) elif image_prep in ["resize_512", "resize_512x512"]: T = transforms.Compose([ transforms.Resize((512, 512), interpolation=Image.LANCZOS) ]) elif image_prep == "no_resize": T = transforms.Lambda(lambda x: x) return T class PairedDataset(torch.utils.data.Dataset): def __init__(self, dataset_folder, split, image_prep, tokenizer): """ Itialize the paired dataset object for loading and transforming paired data samples from specified dataset folders. This constructor sets up the paths to input and output folders based on the specified 'split', loads the captions (or prompts) for the input images, and prepares the transformations and tokenizer to be applied on the data. Parameters: - dataset_folder (str): The root folder containing the dataset, expected to include sub-folders for different splits (e.g., 'train_A', 'train_B'). - split (str): The dataset split to use ('train' or 'test'), used to select the appropriate sub-folders and caption files within the dataset folder. - image_prep (str): The image preprocessing transformation to apply to each image. - tokenizer: The tokenizer used for tokenizing the captions (or prompts). """ super().__init__() if split == "train": self.input_folder = os.path.join(dataset_folder, "train_A") self.output_folder = os.path.join(dataset_folder, "train_B") captions = os.path.join(dataset_folder, "train_prompts.json") elif split == "test": self.input_folder = os.path.join(dataset_folder, "test_A") self.output_folder = os.path.join(dataset_folder, "test_B") captions = os.path.join(dataset_folder, "test_prompts.json") with open(captions, "r") as f: self.captions = json.load(f) self.img_names = list(self.captions.keys()) self.T = build_transform(image_prep) self.tokenizer = tokenizer def __len__(self): """ Returns: int: The total number of items in the dataset. """ return len(self.captions) def __getitem__(self, idx): """ Retrieves a dataset item given its index. Each item consists of an input image, its corresponding output image, the captions associated with the input image, and the tokenized form of this caption. This method performs the necessary preprocessing on both the input and output images, including scaling and normalization, as well as tokenizing the caption using a provided tokenizer. Parameters: - idx (int): The index of the item to retrieve. Returns: dict: A dictionary containing the following key-value pairs: - "output_pixel_values": a tensor of the preprocessed output image with pixel values scaled to [-1, 1]. - "conditioning_pixel_values": a tensor of the preprocessed input image with pixel values scaled to [0, 1]. - "caption": the text caption. - "input_ids": a tensor of the tokenized caption. Note: The actual preprocessing steps (scaling and normalization) for images are defined externally and passed to this class through the `image_prep` parameter during initialization. The tokenization process relies on the `tokenizer` also provided at initialization, which should be compatible with the models intended to be used with this dataset. """ img_name = self.img_names[idx] input_img = Image.open(os.path.join(self.input_folder, img_name)) output_img = Image.open(os.path.join(self.output_folder, img_name)) caption = self.captions[img_name] # input images scaled to 0,1 img_t = self.T(input_img) img_t = F.to_tensor(img_t) # output images scaled to -1,1 output_t = self.T(output_img) output_t = F.to_tensor(output_t) output_t = F.normalize(output_t, mean=[0.5], std=[0.5]) input_ids = self.tokenizer( caption, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" ).input_ids return { "output_pixel_values": output_t, "conditioning_pixel_values": img_t, "caption": caption, "input_ids": input_ids, } class UnpairedDataset(torch.utils.data.Dataset): def __init__(self, dataset_folder, split, image_prep, tokenizer): """ A dataset class for loading unpaired data samples from two distinct domains (source and target), typically used in unsupervised learning tasks like image-to-image translation. The class supports loading images from specified dataset folders, applying predefined image preprocessing transformations, and utilizing fixed textual prompts (captions) for each domain, tokenized using a provided tokenizer. Parameters: - dataset_folder (str): Base directory of the dataset containing subdirectories (train_A, train_B, test_A, test_B) - split (str): Indicates the dataset split to use. Expected values are 'train' or 'test'. - image_prep (str): he image preprocessing transformation to apply to each image. - tokenizer: The tokenizer used for tokenizing the captions (or prompts). """ super().__init__() if split == "train": self.source_folder = os.path.join(dataset_folder, "train_A") self.target_folder = os.path.join(dataset_folder, "train_B") elif split == "test": self.source_folder = os.path.join(dataset_folder, "test_A") self.target_folder = os.path.join(dataset_folder, "test_B") self.tokenizer = tokenizer with open(os.path.join(dataset_folder, "fixed_prompt_a.txt"), "r") as f: self.fixed_caption_src = f.read().strip() self.input_ids_src = self.tokenizer( self.fixed_caption_src, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" ).input_ids with open(os.path.join(dataset_folder, "fixed_prompt_b.txt"), "r") as f: self.fixed_caption_tgt = f.read().strip() self.input_ids_tgt = self.tokenizer( self.fixed_caption_tgt, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" ).input_ids # find all images in the source and target folders with all IMG extensions self.l_imgs_src = [] for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif"]: self.l_imgs_src.extend(glob(os.path.join(self.source_folder, ext))) self.l_imgs_tgt = [] for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif"]: self.l_imgs_tgt.extend(glob(os.path.join(self.target_folder, ext))) self.T = build_transform(image_prep) def __len__(self): """ Returns: int: The total number of items in the dataset. """ return len(self.l_imgs_src) + len(self.l_imgs_tgt) def __getitem__(self, index): """ Fetches a pair of unaligned images from the source and target domains along with their corresponding tokenized captions. For the source domain, if the requested index is within the range of available images, the specific image at that index is chosen. If the index exceeds the number of source images, a random source image is selected. For the target domain, an image is always randomly selected, irrespective of the index, to maintain the unpaired nature of the dataset. Both images are preprocessed according to the specified image transformation `T`, and normalized. The fixed captions for both domains are included along with their tokenized forms. Parameters: - index (int): The index of the source image to retrieve. Returns: dict: A dictionary containing processed data for a single training example, with the following keys: - "pixel_values_src": The processed source image - "pixel_values_tgt": The processed target image - "caption_src": The fixed caption of the source domain. - "caption_tgt": The fixed caption of the target domain. - "input_ids_src": The source domain's fixed caption tokenized. - "input_ids_tgt": The target domain's fixed caption tokenized. """ if index < len(self.l_imgs_src): img_path_src = self.l_imgs_src[index] else: img_path_src = random.choice(self.l_imgs_src) img_path_tgt = random.choice(self.l_imgs_tgt) img_pil_src = Image.open(img_path_src).convert("RGB") img_pil_tgt = Image.open(img_path_tgt).convert("RGB") img_t_src = F.to_tensor(self.T(img_pil_src)) img_t_tgt = F.to_tensor(self.T(img_pil_tgt)) img_t_src = F.normalize(img_t_src, mean=[0.5], std=[0.5]) img_t_tgt = F.normalize(img_t_tgt, mean=[0.5], std=[0.5]) return { "pixel_values_src": img_t_src, "pixel_values_tgt": img_t_tgt, "caption_src": self.fixed_caption_src, "caption_tgt": self.fixed_caption_tgt, "input_ids_src": self.input_ids_src, "input_ids_tgt": self.input_ids_tgt, }