import math
import random

import numpy as np
import torch
from torch import distributed as dist
from torch.optim.lr_scheduler import LambdaLR
from icecream import ic


def print_rank_0(message):
    """If distributed is initialized, print only on rank 0."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)


ARGS = None


def set_args(args):
    global ARGS
    ARGS = args


def get_args():
    return ARGS


TOKENIZER = None


def set_tokenizer(tokenizer):
    global TOKENIZER
    TOKENIZER = tokenizer


def get_tokenizer():
    return TOKENIZER


class worker_init:
    """Seeds each DataLoader worker differently per epoch and per rank."""

    def __init__(self, epoch_id):
        self.epoch_id = epoch_id

    def _worker_init_fn(self, worker_id):
        # Combine worker id, epoch and rank so every worker gets a distinct seed.
        random.seed(int(worker_id + self.epoch_id * 1e4 + dist.get_rank() * 1e8))


def batchify(batch):
    """Collate function: merges a list of per-sample dicts into a single batch dict."""
    # Gather the (possibly missing) video tensors for each sample.
    video = [data["video"] for data in batch]
    if all(v is None for v in video):
        video = None
    else:
        # Concatenate all available clips along the first (clip) dimension.
        video = torch.cat([v for v in video if v is not None], dim=0)
    num_videos_per_sample = torch.LongTensor(
        [data["video"].size(0) if data["video"] is not None else 0 for data in batch])
    # This collate function handles video-only batches, so image counts are zero.
    num_images_per_sample = torch.LongTensor([0 for _ in batch])

    # Stack the tokenized text and its masks into [batch, seq_len] tensors.
    text = torch.stack([torch.LongTensor(data["text"]["input_ids"]) for data in batch], dim=0)
    non_padding_mask = torch.stack([torch.LongTensor(data["text"]["non_padding_mask"]) for data in batch], dim=0)
    non_media_mask = torch.stack([torch.LongTensor(data["text"]["non_media_mask"]) for data in batch], dim=0)
    prompt_mask = torch.stack([torch.LongTensor(data["text"]["prompt_mask"]) for data in batch], dim=0)
    videopaths = [data["videopath"] for data in batch]
    captions = [data["caption"] for data in batch]

    output_batch = {
        "pixel_values": None,
        "video_pixel_values": video,
        "input_ids": text.long(),
        # Labels start as a copy of the inputs; loss masking is applied downstream.
        "labels": text.long().clone(),
        "num_images": num_images_per_sample.long(),
        "num_videos": num_videos_per_sample.long(),
        "non_padding_mask": non_padding_mask.long(),
        "non_media_mask": non_media_mask.long(),
        "prompt_mask": prompt_mask.long(),
        "videopaths": videopaths,
        "captions": captions,
    }

    return output_batch
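

# Illustrative sketch (not part of the original module): one way `worker_init` and
# `batchify` are typically combined in a DataLoader. The batch size and worker
# count are hypothetical placeholders; the only assumptions are that `dataset[i]`
# returns a dict with the keys batchify expects and that torch.distributed has
# been initialized (worker_init calls dist.get_rank()).
def _example_build_dataloader(dataset, epoch):
    from torch.utils.data import DataLoader

    return DataLoader(
        dataset,
        batch_size=8,       # hypothetical value
        num_workers=4,      # hypothetical value
        collate_fn=batchify,
        # Recreate per epoch so every (worker, epoch, rank) gets a distinct seed.
        worker_init_fn=worker_init(epoch)._worker_init_fn,
    )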


def get_param_groups(modules,
                     no_weight_decay_cond,
                     scale_lr_cond,
                     lr_mult):
    """Creates param groups based on the weight decay condition (regularized vs.
    non-regularized) and the learning rate scale condition (args.lr vs. lr_mult * args.lr).

    scale_lr_cond is used during finetuning, where the head of the network requires a
    scaled version of the base learning rate.
    """
    wd_no_scale_lr = []
    wd_scale_lr = []
    no_wd_no_scale_lr = []
    no_wd_scale_lr = []
    for module in modules:
        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue

            if no_weight_decay_cond is not None:
                no_wd = no_weight_decay_cond(name, param)
            else:
                # By default, biases and 1-D parameters (e.g. LayerNorm) get no weight decay.
                no_wd = name.endswith(".bias") or len(param.shape) == 1

            if scale_lr_cond is not None:
                scale_lr = scale_lr_cond(name, param)
            else:
                scale_lr = False

            if not no_wd and not scale_lr:
                wd_no_scale_lr.append(param)
            elif not no_wd and scale_lr:
                wd_scale_lr.append(param)
            elif no_wd and not scale_lr:
                no_wd_no_scale_lr.append(param)
            else:
                no_wd_scale_lr.append(param)

    param_groups = []
    if len(wd_no_scale_lr):
        param_groups.append(
            {'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0})
    if len(wd_scale_lr):
        param_groups.append(
            {'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult})
    if len(no_wd_no_scale_lr):
        param_groups.append({'params': no_wd_no_scale_lr,
                             'wd_mult': 0.0, 'lr_mult': 1.0})
    if len(no_wd_scale_lr):
        param_groups.append(
            {'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult})

    return param_groups
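

# Illustrative sketch (not part of the original module): get_param_groups only
# attaches 'wd_mult' / 'lr_mult' hints; how they are consumed depends on the
# training loop. Below is one possible way to turn them into concrete AdamW
# groups. The base_lr, weight_decay, lr_mult values and the "head" name match
# are hypothetical.
def _example_build_optimizer(model, base_lr=1e-4, weight_decay=0.05, lr_mult=0.1):
    groups = get_param_groups(
        [model],
        no_weight_decay_cond=None,   # default: biases and 1-D params get no decay
        scale_lr_cond=lambda name, p: "head" in name,  # hypothetical condition
        lr_mult=lr_mult,
    )
    # Translate the multiplier hints into explicit per-group hyperparameters.
    for g in groups:
        g["lr"] = base_lr * g.pop("lr_mult")
        g["weight_decay"] = weight_decay * g.pop("wd_mult")
    return torch.optim.AdamW(groups, lr=base_lr, weight_decay=weight_decay)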


def get_cosine_schedule_with_warmup(
    optimizer, lr, min_lr, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer and min_lr, after a warmup period during which it increases linearly from min_lr
    to the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        lr (`float`):
            The base learning rate set in the optimizer.
        min_lr (`float`):
            The minimum learning rate reached at the end of the cosine decay.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to min_lr
            following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    # Fraction of the base lr that is decayed; the remaining
    # (1 - delta_min_lr) = min_lr / lr is the floor of the schedule.
    delta_min_lr = (lr - min_lr) / lr

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            # Linear warmup from min_lr / lr up to 1.0.
            return (1 - delta_min_lr) + delta_min_lr * float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / \
            float(max(1, num_training_steps - num_warmup_steps))
        # Cosine decay from 1.0 down to min_lr / lr.
        return (1 - delta_min_lr) + delta_min_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
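

# Illustrative sketch (not part of the original module): typical usage of the
# scheduler above. All numbers are hypothetical; the only assumption is a
# standard torch optimizer whose param groups use the base learning rate.
def _example_build_scheduler(optimizer, lr=1e-4, min_lr=1e-6,
                             warmup_steps=500, total_steps=10000):
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, lr=lr, min_lr=min_lr,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )
    # The training loop would then call, after each optimizer.step():
    #     scheduler.step()
    return scheduler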