Hritik committed on
Commit 1040e55
1 Parent(s): 0187095

add pipeline video

Files changed (46)
  1. pipeline_video/__init__.py +0 -0
  2. pipeline_video/__pycache__/utils.cpython-310.pyc +0 -0
  3. pipeline_video/__pycache__/utils.cpython-39.pyc +0 -0
  4. pipeline_video/data_utils/__init__.py +29 -0
  5. pipeline_video/data_utils/__pycache__/__init__.cpython-310.pyc +0 -0
  6. pipeline_video/data_utils/__pycache__/__init__.cpython-39.pyc +0 -0
  7. pipeline_video/data_utils/__pycache__/randaugment.cpython-310.pyc +0 -0
  8. pipeline_video/data_utils/__pycache__/randaugment.cpython-39.pyc +0 -0
  9. pipeline_video/data_utils/__pycache__/registry.cpython-310.pyc +0 -0
  10. pipeline_video/data_utils/__pycache__/registry.cpython-39.pyc +0 -0
  11. pipeline_video/data_utils/__pycache__/xgpt3_dataset.cpython-310.pyc +0 -0
  12. pipeline_video/data_utils/__pycache__/xgpt3_dataset.cpython-39.pyc +0 -0
  13. pipeline_video/data_utils/processors/__init__.py +9 -0
  14. pipeline_video/data_utils/processors/__pycache__/__init__.cpython-310.pyc +0 -0
  15. pipeline_video/data_utils/processors/__pycache__/__init__.cpython-39.pyc +0 -0
  16. pipeline_video/data_utils/processors/__pycache__/builder.cpython-310.pyc +0 -0
  17. pipeline_video/data_utils/processors/__pycache__/builder.cpython-39.pyc +0 -0
  18. pipeline_video/data_utils/processors/__pycache__/caption_processor.cpython-310.pyc +0 -0
  19. pipeline_video/data_utils/processors/__pycache__/caption_processor.cpython-39.pyc +0 -0
  20. pipeline_video/data_utils/processors/__pycache__/default_processor.cpython-310.pyc +0 -0
  21. pipeline_video/data_utils/processors/__pycache__/default_processor.cpython-39.pyc +0 -0
  22. pipeline_video/data_utils/processors/builder.py +12 -0
  23. pipeline_video/data_utils/processors/caption_processor.py +53 -0
  24. pipeline_video/data_utils/processors/default_processor.py +42 -0
  25. pipeline_video/data_utils/randaugment.py +345 -0
  26. pipeline_video/data_utils/registry.py +422 -0
  27. pipeline_video/data_utils/xgpt3_dataset.py +204 -0
  28. pipeline_video/entailment_inference.py +122 -0
  29. pipeline_video/mplug_owl_video/__init__.py +77 -0
  30. pipeline_video/mplug_owl_video/__pycache__/__init__.cpython-310.pyc +0 -0
  31. pipeline_video/mplug_owl_video/__pycache__/__init__.cpython-39.pyc +0 -0
  32. pipeline_video/mplug_owl_video/__pycache__/configuration_mplug_owl.cpython-310.pyc +0 -0
  33. pipeline_video/mplug_owl_video/__pycache__/configuration_mplug_owl.cpython-39.pyc +0 -0
  34. pipeline_video/mplug_owl_video/__pycache__/modeling_mplug_owl.cpython-310.pyc +0 -0
  35. pipeline_video/mplug_owl_video/__pycache__/modeling_mplug_owl.cpython-39.pyc +0 -0
  36. pipeline_video/mplug_owl_video/__pycache__/processing_mplug_owl.cpython-310.pyc +0 -0
  37. pipeline_video/mplug_owl_video/__pycache__/processing_mplug_owl.cpython-39.pyc +0 -0
  38. pipeline_video/mplug_owl_video/__pycache__/tokenization_mplug_owl.cpython-310.pyc +0 -0
  39. pipeline_video/mplug_owl_video/__pycache__/tokenization_mplug_owl.cpython-39.pyc +0 -0
  40. pipeline_video/mplug_owl_video/configuration_mplug_owl.py +296 -0
  41. pipeline_video/mplug_owl_video/modeling_mplug_owl.py +1938 -0
  42. pipeline_video/mplug_owl_video/processing_mplug_owl.py +246 -0
  43. pipeline_video/mplug_owl_video/tokenization_mplug_owl.py +62 -0
  44. pipeline_video/nle_inference.py +126 -0
  45. pipeline_video/train.py +263 -0
  46. pipeline_video/utils.py +160 -0
pipeline_video/__init__.py ADDED
File without changes
pipeline_video/__pycache__/utils.cpython-310.pyc ADDED
Binary file (7.93 kB)

pipeline_video/__pycache__/utils.cpython-39.pyc ADDED
Binary file (6.35 kB)
 
pipeline_video/data_utils/__init__.py ADDED
@@ -0,0 +1,29 @@
+ # from .processors.builder import build_processors
+ from .xgpt3_dataset import MultiModalDataset
+ from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
+
+ def train_valid_test_datasets_provider(data_path, config, tokenizer, seq_length=1024, loss_objective='sequential'):
+     """Build train and valid datasets."""
+     print('> building train and validation datasets for mPLUG-Owl ...')
+     train_ds, valid_ds = build_train_valid_test_datasets(
+         input_file=data_path,
+         tokenizer=tokenizer,
+         max_length=seq_length,
+         config=config, loss_objective=loss_objective)
+     print("> finished creating mPLUG-Owl datasets ...")
+
+     return train_ds, valid_ds
+
+
+ def build_train_valid_test_datasets(input_file, tokenizer, max_length=80, config=None, loss_objective='sequential'):
+
+     # train_processors = build_processors(config['train_processors'])
+     # valid_processors = build_processors(config['valid_processors'])
+
+     image_processor = MplugOwlImageProcessor.from_pretrained(config['pretrained_ckpt'])
+     processor = MplugOwlProcessor(image_processor, tokenizer)
+
+     assert len(input_file) == 2  # If you have more than 2 files, modify the code here or merge them into train and dev
+     train_ds = MultiModalDataset(input_file[0], tokenizer, processor, max_length, loss_objective=loss_objective)
+     valid_ds = MultiModalDataset(input_file[1], tokenizer, processor, max_length, loss_objective=loss_objective)
+     return (train_ds, valid_ds)
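
A rough usage sketch of the provider above (the CSV paths are placeholders; only the 'pretrained_ckpt' key of the config is actually read by this code):

    # Hypothetical wiring of the dataset provider; 'train.csv' and 'dev.csv' are placeholder paths.
    from transformers.models.llama.tokenization_llama import LlamaTokenizer
    from data_utils import train_valid_test_datasets_provider

    config = {'pretrained_ckpt': 'MAGAer13/mplug-owl-llama-7b'}
    tokenizer = LlamaTokenizer.from_pretrained(config['pretrained_ckpt'])
    train_ds, valid_ds = train_valid_test_datasets_provider(
        ['train.csv', 'dev.csv'], config, tokenizer, seq_length=256)
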
pipeline_video/data_utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.21 kB)

pipeline_video/data_utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.21 kB)

pipeline_video/data_utils/__pycache__/randaugment.cpython-310.pyc ADDED
Binary file (10.4 kB)

pipeline_video/data_utils/__pycache__/randaugment.cpython-39.pyc ADDED
Binary file (10.6 kB)

pipeline_video/data_utils/__pycache__/registry.cpython-310.pyc ADDED
Binary file (13 kB)

pipeline_video/data_utils/__pycache__/registry.cpython-39.pyc ADDED
Binary file (13 kB)

pipeline_video/data_utils/__pycache__/xgpt3_dataset.cpython-310.pyc ADDED
Binary file (6.59 kB)

pipeline_video/data_utils/__pycache__/xgpt3_dataset.cpython-39.pyc ADDED
Binary file (5.45 kB)
 
pipeline_video/data_utils/processors/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Copyright (c) Alibaba. All rights reserved.
+ from .builder import PROCESSORS, build_processors
+ from .default_processor import DefaultProcessor
+ from .caption_processor import CaptionProcessor
+
+ __all__ = [
+     'PROCESSORS', 'build_processors',
+     'DefaultProcessor', 'CaptionProcessor'
+ ]
pipeline_video/data_utils/processors/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (409 Bytes)

pipeline_video/data_utils/processors/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (407 Bytes)

pipeline_video/data_utils/processors/__pycache__/builder.cpython-310.pyc ADDED
Binary file (538 Bytes)

pipeline_video/data_utils/processors/__pycache__/builder.cpython-39.pyc ADDED
Binary file (536 Bytes)

pipeline_video/data_utils/processors/__pycache__/caption_processor.cpython-310.pyc ADDED
Binary file (1.71 kB)

pipeline_video/data_utils/processors/__pycache__/caption_processor.cpython-39.pyc ADDED
Binary file (1.7 kB)

pipeline_video/data_utils/processors/__pycache__/default_processor.cpython-310.pyc ADDED
Binary file (1.37 kB)

pipeline_video/data_utils/processors/__pycache__/default_processor.cpython-39.pyc ADDED
Binary file (1.36 kB)
 
pipeline_video/data_utils/processors/builder.py ADDED
@@ -0,0 +1,12 @@
+ import os
+ import numpy as np
+
+ from data_utils.registry import Registry, build_from_cfg
+
+ PROCESSORS = Registry('processors')
+
+ def build_processors(processors_cfg):
+     processors = dict()
+     for task, processor in processors_cfg.items():
+         processors[task] = build_from_cfg(processor, PROCESSORS)
+     return processors
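
A minimal sketch of the config shape build_processors consumes (the 'type' key convention comes from build_from_cfg in data_utils/registry.py; the concrete values below are illustrative, not part of this commit):

    # Illustrative processors config; keys and values are placeholders.
    processors_cfg = {
        'train': dict(type='CaptionProcessor', image_size=224, randaug=True),
        'valid': dict(type='DefaultProcessor', image_size=224),
    }
    processors = build_processors(processors_cfg)
    # -> {'train': CaptionProcessor(...), 'valid': DefaultProcessor(...)}
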
pipeline_video/data_utils/processors/caption_processor.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ from torchvision import transforms
+ from PIL import Image
+ import random
+
+ from data_utils.randaugment import RandomAugment
+ from .builder import PROCESSORS
+
+
+ @PROCESSORS.register_module()
+ class CaptionProcessor:
+     def __init__(self, image_size=224, min_scale=0.5, randaug=False):
+         self.image_size = image_size
+         self.min_scale = min_scale
+
+         if randaug:
+             self.image_transform = transforms.Compose([
+                 transforms.RandomResizedCrop(image_size, scale=(min_scale, 1.0), interpolation=Image.BICUBIC),
+                 transforms.RandomHorizontalFlip(),
+                 RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
+                                                       'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
+                 transforms.ToTensor(),
+                 transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+             ])
+         else:
+             self.image_transform = transforms.Compose([
+                 transforms.RandomResizedCrop(image_size, scale=(min_scale, 1.0), interpolation=Image.BICUBIC),
+                 transforms.RandomHorizontalFlip(),
+                 transforms.ToTensor(),
+                 transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+             ])
+         self.text_transform = None
+
+     def __call__(self, image, text):
+         assert image or text
+
+         if image:
+             image_input = self.image_transform(image)
+         else:
+             image_input = None
+
+         if text:
+             if isinstance(text["prompt"], list):
+                 prompt = random.choice(text["prompt"])
+             else:
+                 prompt = text["prompt"]
+             text_input = dict(
+                 prompt=prompt,
+                 completion=text["text"],
+             )
+         else:
+             text_input = None
+         return image_input, text_input
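
A short usage sketch of CaptionProcessor (the image path and caption strings are placeholders):

    # Hypothetical call; 'frame.jpg' is a placeholder image path.
    from PIL import Image
    proc = CaptionProcessor(image_size=224, randaug=True)
    img = Image.open('frame.jpg').convert('RGB')
    image_input, text_input = proc(img, {'prompt': ['Describe the clip.'], 'text': 'A dog runs on the beach.'})
    # image_input: normalized 3x224x224 tensor; text_input: {'prompt': ..., 'completion': 'A dog runs on the beach.'}
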
pipeline_video/data_utils/processors/default_processor.py ADDED
@@ -0,0 +1,42 @@
+ import torch
+ from torchvision import transforms
+ from PIL import Image
+ import random
+
+ from data_utils.randaugment import RandomAugment
+ from .builder import PROCESSORS
+
+
+ @PROCESSORS.register_module()
+ class DefaultProcessor:
+     def __init__(self, image_size=224):
+         self.image_size = image_size
+
+         self.image_transform = transforms.Compose([
+             transforms.Resize((image_size, image_size), interpolation=Image.BICUBIC),
+             transforms.ToTensor(),
+             transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+         ])
+
+         self.text_transform = None
+
+     def __call__(self, image, text):
+         assert image or text
+
+         if image:
+             image_input = self.image_transform(image)
+         else:
+             image_input = None
+
+         if text:
+             if isinstance(text["prompt"], list):
+                 prompt = random.choice(text["prompt"])
+             else:
+                 prompt = text["prompt"]
+             text_input = dict(
+                 prompt=prompt,
+                 completion=text["text"],
+             )
+         else:
+             text_input = None
+         return image_input, text_input
pipeline_video/data_utils/randaugment.py ADDED
@@ -0,0 +1,345 @@
1
+ import cv2
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+
6
+ ## aug functions
7
+ def identity_func(img):
8
+ return img
9
+
10
+
11
+ def autocontrast_func(img, cutoff=0):
12
+ '''
13
+ same output as PIL.ImageOps.autocontrast
14
+ '''
15
+ n_bins = 256
16
+
17
+ def tune_channel(ch):
18
+ n = ch.size
19
+ cut = cutoff * n // 100
20
+ if cut == 0:
21
+ high, low = ch.max(), ch.min()
22
+ else:
23
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
24
+ low = np.argwhere(np.cumsum(hist) > cut)
25
+ low = 0 if low.shape[0] == 0 else low[0]
26
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
27
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
28
+ if high <= low:
29
+ table = np.arange(n_bins)
30
+ else:
31
+ scale = (n_bins - 1) / (high - low)
32
+ offset = -low * scale
33
+ table = np.arange(n_bins) * scale + offset
34
+ table[table < 0] = 0
35
+ table[table > n_bins - 1] = n_bins - 1
36
+ table = table.clip(0, 255).astype(np.uint8)
37
+ return table[ch]
38
+
39
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
40
+ out = cv2.merge(channels)
41
+ return out
42
+
43
+
44
+ def equalize_func(img):
45
+ '''
46
+ same output as PIL.ImageOps.equalize
47
+ PIL's implementation is different from cv2.equalize
48
+ '''
49
+ n_bins = 256
50
+
51
+ def tune_channel(ch):
52
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
53
+ non_zero_hist = hist[hist != 0].reshape(-1)
54
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
55
+ if step == 0: return ch
56
+ n = np.empty_like(hist)
57
+ n[0] = step // 2
58
+ n[1:] = hist[:-1]
59
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
60
+ return table[ch]
61
+
62
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
63
+ out = cv2.merge(channels)
64
+ return out
65
+
66
+
67
+ def rotate_func(img, degree, fill=(0, 0, 0)):
68
+ '''
69
+ like PIL, rotate by degree, not radians
70
+ '''
71
+ H, W = img.shape[0], img.shape[1]
72
+ center = W / 2, H / 2
73
+ M = cv2.getRotationMatrix2D(center, degree, 1)
74
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
75
+ return out
76
+
77
+
78
+ def solarize_func(img, thresh=128):
79
+ '''
80
+ same output as PIL.ImageOps.posterize
81
+ '''
82
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
83
+ table = table.clip(0, 255).astype(np.uint8)
84
+ out = table[img]
85
+ return out
86
+
87
+
88
+ def color_func(img, factor):
89
+ '''
90
+ same output as PIL.ImageEnhance.Color
91
+ '''
92
+ ## implementation according to PIL definition, quite slow
93
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
94
+ # out = blend(degenerate, img, factor)
95
+ # M = (
96
+ # np.eye(3) * factor
97
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
98
+ # )[np.newaxis, np.newaxis, :]
99
+ M = (
100
+ np.float32([
101
+ [0.886, -0.114, -0.114],
102
+ [-0.587, 0.413, -0.587],
103
+ [-0.299, -0.299, 0.701]]) * factor
104
+ + np.float32([[0.114], [0.587], [0.299]])
105
+ )
106
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
107
+ return out
108
+
109
+
110
+ def contrast_func(img, factor):
111
+ """
112
+ same output as PIL.ImageEnhance.Contrast
113
+ """
114
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
115
+ table = np.array([(
116
+ el - mean) * factor + mean
117
+ for el in range(256)
118
+ ]).clip(0, 255).astype(np.uint8)
119
+ out = table[img]
120
+ return out
121
+
122
+
123
+ def brightness_func(img, factor):
124
+ '''
125
+ same output as PIL.ImageEnhance.Brightness
126
+ '''
127
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
128
+ out = table[img]
129
+ return out
130
+
131
+
132
+ def sharpness_func(img, factor):
133
+ '''
134
+ The differences between this result and PIL are all on the 4 boundaries; the center
136
+ areas are the same
136
+ '''
137
+ kernel = np.ones((3, 3), dtype=np.float32)
138
+ kernel[1][1] = 5
139
+ kernel /= 13
140
+ degenerate = cv2.filter2D(img, -1, kernel)
141
+ if factor == 0.0:
142
+ out = degenerate
143
+ elif factor == 1.0:
144
+ out = img
145
+ else:
146
+ out = img.astype(np.float32)
147
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
148
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
149
+ out = out.astype(np.uint8)
150
+ return out
151
+
152
+
153
+ def shear_x_func(img, factor, fill=(0, 0, 0)):
154
+ H, W = img.shape[0], img.shape[1]
155
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
156
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
157
+ return out
158
+
159
+
160
+ def translate_x_func(img, offset, fill=(0, 0, 0)):
161
+ '''
162
+ same output as PIL.Image.transform
163
+ '''
164
+ H, W = img.shape[0], img.shape[1]
165
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
166
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
167
+ return out
168
+
169
+
170
+ def translate_y_func(img, offset, fill=(0, 0, 0)):
171
+ '''
172
+ same output as PIL.Image.transform
173
+ '''
174
+ H, W = img.shape[0], img.shape[1]
175
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
176
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
177
+ return out
178
+
179
+
180
+ def posterize_func(img, bits):
181
+ '''
182
+ same output as PIL.ImageOps.posterize
183
+ '''
184
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
185
+ return out
186
+
187
+
188
+ def shear_y_func(img, factor, fill=(0, 0, 0)):
189
+ H, W = img.shape[0], img.shape[1]
190
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
191
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
192
+ return out
193
+
194
+
195
+ def cutout_func(img, pad_size, replace=(0, 0, 0)):
196
+ replace = np.array(replace, dtype=np.uint8)
197
+ H, W = img.shape[0], img.shape[1]
198
+ rh, rw = np.random.random(2)
199
+ pad_size = pad_size // 2
200
+ ch, cw = int(rh * H), int(rw * W)
201
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
202
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
203
+ out = img.copy()
204
+ out[x1:x2, y1:y2, :] = replace
205
+ return out
206
+
207
+
208
+ ### level to args
209
+ def enhance_level_to_args(MAX_LEVEL):
210
+ def level_to_args(level):
211
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
212
+ return level_to_args
213
+
214
+
215
+ def shear_level_to_args(MAX_LEVEL, replace_value):
216
+ def level_to_args(level):
217
+ level = (level / MAX_LEVEL) * 0.3
218
+ if np.random.random() > 0.5: level = -level
219
+ return (level, replace_value)
220
+
221
+ return level_to_args
222
+
223
+
224
+ def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
225
+ def level_to_args(level):
226
+ level = (level / MAX_LEVEL) * float(translate_const)
227
+ if np.random.random() > 0.5: level = -level
228
+ return (level, replace_value)
229
+
230
+ return level_to_args
231
+
232
+
233
+ def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
234
+ def level_to_args(level):
235
+ level = int((level / MAX_LEVEL) * cutout_const)
236
+ return (level, replace_value)
237
+
238
+ return level_to_args
239
+
240
+
241
+ def solarize_level_to_args(MAX_LEVEL):
242
+ def level_to_args(level):
243
+ level = int((level / MAX_LEVEL) * 256)
244
+ return (level, )
245
+ return level_to_args
246
+
247
+
248
+ def none_level_to_args(level):
249
+ return ()
250
+
251
+
252
+ def posterize_level_to_args(MAX_LEVEL):
253
+ def level_to_args(level):
254
+ level = int((level / MAX_LEVEL) * 4)
255
+ return (level, )
256
+ return level_to_args
257
+
258
+
259
+ def rotate_level_to_args(MAX_LEVEL, replace_value):
260
+ def level_to_args(level):
261
+ level = (level / MAX_LEVEL) * 30
262
+ if np.random.random() < 0.5:
263
+ level = -level
264
+ return (level, replace_value)
265
+
266
+ return level_to_args
267
+
268
+
269
+ func_dict = {
270
+ 'Identity': identity_func,
271
+ 'AutoContrast': autocontrast_func,
272
+ 'Equalize': equalize_func,
273
+ 'Rotate': rotate_func,
274
+ 'Solarize': solarize_func,
275
+ 'Color': color_func,
276
+ 'Contrast': contrast_func,
277
+ 'Brightness': brightness_func,
278
+ 'Sharpness': sharpness_func,
279
+ 'ShearX': shear_x_func,
280
+ 'TranslateX': translate_x_func,
281
+ 'TranslateY': translate_y_func,
282
+ 'Posterize': posterize_func,
283
+ 'ShearY': shear_y_func,
284
+ }
285
+
286
+ translate_const = 10
287
+ MAX_LEVEL = 10
288
+ replace_value = (128, 128, 128)
289
+ arg_dict = {
290
+ 'Identity': none_level_to_args,
291
+ 'AutoContrast': none_level_to_args,
292
+ 'Equalize': none_level_to_args,
293
+ 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
294
+ 'Solarize': solarize_level_to_args(MAX_LEVEL),
295
+ 'Color': enhance_level_to_args(MAX_LEVEL),
296
+ 'Contrast': enhance_level_to_args(MAX_LEVEL),
297
+ 'Brightness': enhance_level_to_args(MAX_LEVEL),
298
+ 'Sharpness': enhance_level_to_args(MAX_LEVEL),
299
+ 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
300
+ 'TranslateX': translate_level_to_args(
301
+ translate_const, MAX_LEVEL, replace_value
302
+ ),
303
+ 'TranslateY': translate_level_to_args(
304
+ translate_const, MAX_LEVEL, replace_value
305
+ ),
306
+ 'Posterize': posterize_level_to_args(MAX_LEVEL),
307
+ 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
308
+ }
309
+
310
+
311
+ class RandomAugment(object):
312
+
313
+ def __init__(self, N=2, M=10, isPIL=False, returnPIL=False, augs=[]):
314
+ self.N = N
315
+ self.M = M
316
+ self.isPIL = isPIL
317
+ self.returnPIL = returnPIL
318
+ if augs:
319
+ self.augs = augs
320
+ else:
321
+ self.augs = list(arg_dict.keys())
322
+
323
+ def get_random_ops(self):
324
+ sampled_ops = np.random.choice(self.augs, self.N)
325
+ return [(op, 0.5, self.M) for op in sampled_ops]
326
+
327
+ def __call__(self, img):
328
+ if self.isPIL:
329
+ img = np.array(img)
330
+ ops = self.get_random_ops()
331
+ for name, prob, level in ops:
332
+ if np.random.random() > prob:
333
+ continue
334
+ args = arg_dict[name](level)
335
+ img = func_dict[name](img, *args)
336
+ if self.returnPIL:
337
+ img = img.astype('uint8')
338
+ img = Image.fromarray(img)
339
+ return img
340
+
341
+
342
+ if __name__ == '__main__':
343
+ a = RandomAugment()
344
+ img = np.random.randn(32, 32, 3)
345
+ a(img)
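
A brief sketch of applying RandomAugment on its own, mirroring how CaptionProcessor above composes it into a transform pipeline (the input image here is synthetic):

    import numpy as np
    from PIL import Image
    aug = RandomAugment(2, 7, isPIL=True, returnPIL=True,
                        augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness'])
    img = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
    augmented = aug(img)  # returns a PIL.Image because returnPIL=True
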
pipeline_video/data_utils/registry.py ADDED
@@ -0,0 +1,422 @@
1
+ # Copyright (c) Alibaba. All rights reserved.
2
+ import inspect
3
+ import warnings
4
+ import functools
5
+ from functools import partial
6
+ from typing import Any, Dict, Optional
7
+ from collections import abc
8
+ from inspect import getfullargspec
9
+
10
+
11
+ def is_seq_of(seq, expected_type, seq_type=None):
12
+ """Check whether it is a sequence of some type.
13
+ Args:
14
+ seq (Sequence): The sequence to be checked.
15
+ expected_type (type): Expected type of sequence items.
16
+ seq_type (type, optional): Expected sequence type.
17
+ Returns:
18
+ bool: Whether the sequence is valid.
19
+ """
20
+ if seq_type is None:
21
+ exp_seq_type = abc.Sequence
22
+ else:
23
+ assert isinstance(seq_type, type)
24
+ exp_seq_type = seq_type
25
+ if not isinstance(seq, exp_seq_type):
26
+ return False
27
+ for item in seq:
28
+ if not isinstance(item, expected_type):
29
+ return False
30
+ return True
31
+
32
+
33
+ def deprecated_api_warning(name_dict, cls_name=None):
34
+ """A decorator to check if some arguments are deprecate and try to replace
35
+ deprecate src_arg_name to dst_arg_name.
36
+ Args:
37
+ name_dict(dict):
38
+ key (str): Deprecate argument names.
39
+ val (str): Expected argument names.
40
+ Returns:
41
+ func: New function.
42
+ """
43
+
44
+ def api_warning_wrapper(old_func):
45
+
46
+ @functools.wraps(old_func)
47
+ def new_func(*args, **kwargs):
48
+ # get the arg spec of the decorated method
49
+ args_info = getfullargspec(old_func)
50
+ # get name of the function
51
+ func_name = old_func.__name__
52
+ if cls_name is not None:
53
+ func_name = f'{cls_name}.{func_name}'
54
+ if args:
55
+ arg_names = args_info.args[:len(args)]
56
+ for src_arg_name, dst_arg_name in name_dict.items():
57
+ if src_arg_name in arg_names:
58
+ warnings.warn(
59
+ f'"{src_arg_name}" is deprecated in '
60
+ f'`{func_name}`, please use "{dst_arg_name}" '
61
+ 'instead', DeprecationWarning)
62
+ arg_names[arg_names.index(src_arg_name)] = dst_arg_name
63
+ if kwargs:
64
+ for src_arg_name, dst_arg_name in name_dict.items():
65
+ if src_arg_name in kwargs:
66
+
67
+ assert dst_arg_name not in kwargs, (
68
+ f'The expected behavior is to replace '
69
+ f'the deprecated key `{src_arg_name}` to '
70
+ f'new key `{dst_arg_name}`, but got them '
71
+ f'in the arguments at the same time, which '
72
+ f'is confusing. `{src_arg_name} will be '
73
+ f'deprecated in the future, please '
74
+ f'use `{dst_arg_name}` instead.')
75
+
76
+ warnings.warn(
77
+ f'"{src_arg_name}" is deprecated in '
78
+ f'`{func_name}`, please use "{dst_arg_name}" '
79
+ 'instead', DeprecationWarning)
80
+ kwargs[dst_arg_name] = kwargs.pop(src_arg_name)
81
+
82
+ # apply converted arguments to the decorated method
83
+ output = old_func(*args, **kwargs)
84
+ return output
85
+
86
+ return new_func
87
+
88
+ return api_warning_wrapper
89
+
90
+
91
+
92
+ def build_from_cfg(cfg: Dict,
93
+ registry: 'Registry',
94
+ default_args: Optional[Dict] = None) -> Any:
95
+ """Build a module from config dict when it is a class configuration, or
96
+ call a function from config dict when it is a function configuration.
97
+
98
+ Example:
99
+ >>> MODELS = Registry('models')
100
+ >>> @MODELS.register_module()
101
+ >>> class ResNet:
102
+ >>> pass
103
+ >>> resnet = build_from_cfg(dict(type='Resnet'), MODELS)
104
+ >>> # Returns an instantiated object
105
+ >>> @MODELS.register_module()
106
+ >>> def resnet50():
107
+ >>> pass
108
+ >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
109
+ >>> # Return a result of the calling function
110
+
111
+ Args:
112
+ cfg (dict): Config dict. It should at least contain the key "type".
113
+ registry (:obj:`Registry`): The registry to search the type from.
114
+ default_args (dict, optional): Default initialization arguments.
115
+
116
+ Returns:
117
+ object: The constructed object.
118
+ """
119
+ if not isinstance(cfg, dict):
120
+ raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
121
+ if 'type' not in cfg:
122
+ if default_args is None or 'type' not in default_args:
123
+ raise KeyError(
124
+ '`cfg` or `default_args` must contain the key "type", '
125
+ f'but got {cfg}\n{default_args}')
126
+ if not isinstance(registry, Registry):
127
+ raise TypeError('registry must be an mmcv.Registry object, '
128
+ f'but got {type(registry)}')
129
+ if not (isinstance(default_args, dict) or default_args is None):
130
+ raise TypeError('default_args must be a dict or None, '
131
+ f'but got {type(default_args)}')
132
+
133
+ args = cfg.copy()
134
+
135
+ if default_args is not None:
136
+ for name, value in default_args.items():
137
+ args.setdefault(name, value)
138
+
139
+ obj_type = args.pop('type')
140
+ if isinstance(obj_type, str):
141
+ obj_cls = registry.get(obj_type)
142
+ if obj_cls is None:
143
+ raise KeyError(
144
+ f'{obj_type} is not in the {registry.name} registry')
145
+ elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
146
+ obj_cls = obj_type
147
+ else:
148
+ raise TypeError(
149
+ f'type must be a str or valid type, but got {type(obj_type)}')
150
+ try:
151
+ return obj_cls(**args)
152
+ except Exception as e:
153
+ # Normal TypeError does not print class name.
154
+ raise type(e)(f'{obj_cls.__name__}: {e}')
155
+
156
+
157
+ class Registry:
158
+ """A registry to map strings to classes or functions.
159
+
160
+ Registered object could be built from registry. Meanwhile, registered
161
+ functions could be called from registry.
162
+
163
+ Example:
164
+ >>> MODELS = Registry('models')
165
+ >>> @MODELS.register_module()
166
+ >>> class ResNet:
167
+ >>> pass
168
+ >>> resnet = MODELS.build(dict(type='ResNet'))
169
+ >>> @MODELS.register_module()
170
+ >>> def resnet50():
171
+ >>> pass
172
+ >>> resnet = MODELS.build(dict(type='resnet50'))
173
+
174
+ Please refer to
175
+ https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
176
+ advanced usage.
177
+
178
+ Args:
179
+ name (str): Registry name.
180
+ build_func(func, optional): Build function to construct instance from
181
+ Registry, func:`build_from_cfg` is used if neither ``parent`` or
182
+ ``build_func`` is specified. If ``parent`` is specified and
183
+ ``build_func`` is not given, ``build_func`` will be inherited
184
+ from ``parent``. Default: None.
185
+ parent (Registry, optional): Parent registry. The class registered in
186
+ children registry could be built from parent. Default: None.
187
+ scope (str, optional): The scope of registry. It is the key to search
188
+ for children registry. If not specified, scope will be the name of
189
+ the package where class is defined, e.g. mmdet, mmcls, mmseg.
190
+ Default: None.
191
+ """
192
+
193
+ def __init__(self, name, build_func=None, parent=None, scope=None):
194
+ self._name = name
195
+ self._module_dict = dict()
196
+ self._children = dict()
197
+ self._scope = self.infer_scope() if scope is None else scope
198
+
199
+ # self.build_func will be set with the following priority:
200
+ # 1. build_func
201
+ # 2. parent.build_func
202
+ # 3. build_from_cfg
203
+ if build_func is None:
204
+ if parent is not None:
205
+ self.build_func = parent.build_func
206
+ else:
207
+ self.build_func = build_from_cfg
208
+ else:
209
+ self.build_func = build_func
210
+ if parent is not None:
211
+ assert isinstance(parent, Registry)
212
+ parent._add_children(self)
213
+ self.parent = parent
214
+ else:
215
+ self.parent = None
216
+
217
+ def __len__(self):
218
+ return len(self._module_dict)
219
+
220
+ def __contains__(self, key):
221
+ return self.get(key) is not None
222
+
223
+ def __repr__(self):
224
+ format_str = self.__class__.__name__ + \
225
+ f'(name={self._name}, ' \
226
+ f'items={self._module_dict})'
227
+ return format_str
228
+
229
+ @staticmethod
230
+ def infer_scope():
231
+ """Infer the scope of registry.
232
+
233
+ The name of the package where registry is defined will be returned.
234
+
235
+ Example:
236
+ >>> # in mmdet/models/backbone/resnet.py
237
+ >>> MODELS = Registry('models')
238
+ >>> @MODELS.register_module()
239
+ >>> class ResNet:
240
+ >>> pass
241
+ The scope of ``ResNet`` will be ``mmdet``.
242
+
243
+ Returns:
244
+ str: The inferred scope name.
245
+ """
246
+ # We access the caller using inspect.currentframe() instead of
247
+ # inspect.stack() for performance reasons. See details in PR #1844
248
+ frame = inspect.currentframe()
249
+ # get the frame where `infer_scope()` is called
250
+ infer_scope_caller = frame.f_back.f_back
251
+ filename = inspect.getmodule(infer_scope_caller).__name__
252
+ split_filename = filename.split('.')
253
+ return split_filename[0]
254
+
255
+ @staticmethod
256
+ def split_scope_key(key):
257
+ """Split scope and key.
258
+
259
+ The first scope will be split from key.
260
+
261
+ Examples:
262
+ >>> Registry.split_scope_key('mmdet.ResNet')
263
+ 'mmdet', 'ResNet'
264
+ >>> Registry.split_scope_key('ResNet')
265
+ None, 'ResNet'
266
+
267
+ Return:
268
+ tuple[str | None, str]: The former element is the first scope of
269
+ the key, which can be ``None``. The latter is the remaining key.
270
+ """
271
+ split_index = key.find('.')
272
+ if split_index != -1:
273
+ return key[:split_index], key[split_index + 1:]
274
+ else:
275
+ return None, key
276
+
277
+ @property
278
+ def name(self):
279
+ return self._name
280
+
281
+ @property
282
+ def scope(self):
283
+ return self._scope
284
+
285
+ @property
286
+ def module_dict(self):
287
+ return self._module_dict
288
+
289
+ @property
290
+ def children(self):
291
+ return self._children
292
+
293
+ def get(self, key):
294
+ """Get the registry record.
295
+
296
+ Args:
297
+ key (str): The class name in string format.
298
+
299
+ Returns:
300
+ class: The corresponding class.
301
+ """
302
+ scope, real_key = self.split_scope_key(key)
303
+ if scope is None or scope == self._scope:
304
+ # get from self
305
+ if real_key in self._module_dict:
306
+ return self._module_dict[real_key]
307
+ else:
308
+ # get from self._children
309
+ if scope in self._children:
310
+ return self._children[scope].get(real_key)
311
+ else:
312
+ # goto root
313
+ parent = self.parent
314
+ while parent.parent is not None:
315
+ parent = parent.parent
316
+ return parent.get(key)
317
+
318
+ def build(self, *args, **kwargs):
319
+ return self.build_func(*args, **kwargs, registry=self)
320
+
321
+ def _add_children(self, registry):
322
+ """Add children for a registry.
323
+
324
+ The ``registry`` will be added as children based on its scope.
325
+ The parent registry could build objects from children registry.
326
+
327
+ Example:
328
+ >>> models = Registry('models')
329
+ >>> mmdet_models = Registry('models', parent=models)
330
+ >>> @mmdet_models.register_module()
331
+ >>> class ResNet:
332
+ >>> pass
333
+ >>> resnet = models.build(dict(type='mmdet.ResNet'))
334
+ """
335
+
336
+ assert isinstance(registry, Registry)
337
+ assert registry.scope is not None
338
+ assert registry.scope not in self.children, \
339
+ f'scope {registry.scope} exists in {self.name} registry'
340
+ self.children[registry.scope] = registry
341
+
342
+ @deprecated_api_warning(name_dict=dict(module_class='module'))
343
+ def _register_module(self, module, module_name=None, force=False):
344
+ if not inspect.isclass(module) and not inspect.isfunction(module):
345
+ raise TypeError('module must be a class or a function, '
346
+ f'but got {type(module)}')
347
+
348
+ if module_name is None:
349
+ module_name = module.__name__
350
+ if isinstance(module_name, str):
351
+ module_name = [module_name]
352
+ for name in module_name:
353
+ if not force and name in self._module_dict:
354
+ raise KeyError(f'{name} is already registered '
355
+ f'in {self.name}')
356
+ self._module_dict[name] = module
357
+
358
+ def deprecated_register_module(self, cls=None, force=False):
359
+ warnings.warn(
360
+ 'The old API of register_module(module, force=False) '
361
+ 'is deprecated and will be removed, please use the new API '
362
+ 'register_module(name=None, force=False, module=None) instead.',
363
+ DeprecationWarning)
364
+ if cls is None:
365
+ return partial(self.deprecated_register_module, force=force)
366
+ self._register_module(cls, force=force)
367
+ return cls
368
+
369
+ def register_module(self, name=None, force=False, module=None):
370
+ """Register a module.
371
+
372
+ A record will be added to `self._module_dict`, whose key is the class
373
+ name or the specified name, and value is the class itself.
374
+ It can be used as a decorator or a normal function.
375
+
376
+ Example:
377
+ >>> backbones = Registry('backbone')
378
+ >>> @backbones.register_module()
379
+ >>> class ResNet:
380
+ >>> pass
381
+
382
+ >>> backbones = Registry('backbone')
383
+ >>> @backbones.register_module(name='mnet')
384
+ >>> class MobileNet:
385
+ >>> pass
386
+
387
+ >>> backbones = Registry('backbone')
388
+ >>> class ResNet:
389
+ >>> pass
390
+ >>> backbones.register_module(ResNet)
391
+
392
+ Args:
393
+ name (str | None): The module name to be registered. If not
394
+ specified, the class name will be used.
395
+ force (bool, optional): Whether to override an existing class with
396
+ the same name. Default: False.
397
+ module (type): Module class or function to be registered.
398
+ """
399
+ if not isinstance(force, bool):
400
+ raise TypeError(f'force must be a boolean, but got {type(force)}')
401
+ # NOTE: This is a workaround to be compatible with the old api,
402
+ # while it may introduce unexpected bugs.
403
+ if isinstance(name, type):
404
+ return self.deprecated_register_module(name, force=force)
405
+
406
+ # raise the error ahead of time
407
+ if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
408
+ raise TypeError(
409
+ 'name must be either of None, an instance of str or a sequence'
410
+ f' of str, but got {type(name)}')
411
+
412
+ # use it as a normal method: x.register_module(module=SomeClass)
413
+ if module is not None:
414
+ self._register_module(module=module, module_name=name, force=force)
415
+ return module
416
+
417
+ # use it as a decorator: @x.register_module()
418
+ def _register(module):
419
+ self._register_module(module=module, module_name=name, force=force)
420
+ return module
421
+
422
+ return _register
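
A small, self-contained sketch of the registry pattern used by the processors above (the registry name and DummyProcessor class are made up for illustration):

    TOYS = Registry('toys')

    @TOYS.register_module()
    class DummyProcessor:
        def __init__(self, image_size=224):
            self.image_size = image_size

    proc = build_from_cfg(dict(type='DummyProcessor', image_size=336), TOYS)
    assert proc.image_size == 336
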
pipeline_video/data_utils/xgpt3_dataset.py ADDED
@@ -0,0 +1,204 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import random
5
+ import re
6
+ import time
7
+ import traceback
8
+ import warnings
9
+ from io import BytesIO
10
+ import pandas as pd
11
+ import h5py
12
+ import numpy as np
13
+ import torch
14
+ from icecream import ic
15
+ from PIL import Image, ImageFile
16
+ from torch.utils.data import Dataset, Subset
17
+
18
+ from utils import get_args
19
+
20
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
21
+ ImageFile.MAX_IMAGE_PIXELS = None
22
+ Image.MAX_IMAGE_PIXELS = None
23
+
24
+ logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
25
+ datefmt='%m/%d/%Y %H:%M:%S',
26
+ level=logging.INFO)
27
+ warnings.filterwarnings("ignore")
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ def load_jsonl(filename):
32
+ with open(filename, "r", encoding="utf-8") as f:
33
+ return [json.loads(l.strip("\n")) for l in f.readlines()]
34
+
35
+
36
+ class MultiModalDataset(Dataset):
37
+ """MultiModal dataset"""
38
+
39
+ def __init__(self, input_file, tokenizer, processor,
40
+ max_length=2048,
41
+ media_tokens=['<image>', '<|video|>'], loss_objective = 'sequential'):
42
+
43
+ args = get_args()
44
+
45
+ self.loss_objective = loss_objective
46
+ if 'sequential' in self.loss_objective:
47
+ self.dataset = pd.read_csv(input_file)
48
+ self.dataset = self.dataset.dropna()
49
+ else:
50
+ raise NotImplementedError('dataset loader not implemented for other loss objectives')
51
+
53
+ self.tokenizer = tokenizer
54
+ self.max_length = max_length
55
+ self.processor = processor
56
+ self.media_tokens = {k: -int(i+1) for i, k in enumerate(media_tokens)}
57
+ self.media_lengths = {'<image>': 1+64,'<|video|>': 1+64}
58
+ print("num_media_token: ", self.media_lengths)
59
+ print(len(self.dataset))
60
+ self.bucket = {}
61
+
62
+ def __len__(self):
63
+ return len(self.dataset)
64
+
65
+ def __getitem__(self, index):
66
+
67
+ data = self.dataset.iloc[index]
68
+ videopath = data['videopath']
69
+ caption = data['caption']
70
+ video_input = self.processor(videos=[videopath], num_frames=32, return_tensors='pt') # video_pixel_values
71
+ text_input = self._extract_text_token_from_conversation(caption, self.max_length, index)
72
+ item = {'video': video_input, 'text': text_input, 'videopath': videopath, 'caption': caption}
73
+ return item
74
+
75
+ def _extract_text_token_from_conversation(self, data, max_length, index):
76
+ # output enc_chunk
77
+ enc_chunk = []
78
+
79
+ if self.tokenizer.bos_token_id > 0:
80
+ prompt_chunk = [self.tokenizer.bos_token_id]
81
+ else:
82
+ prompt_chunk = []
83
+
84
+ # conversation = data["completion"]
85
+ conversation = data
86
+
87
+ # For Text only data
88
+ if all([media_token not in conversation for media_token in self.media_tokens.keys()]):
89
+ pattern = '|'.join(map(re.escape, ['AI: ', '\nHuman: ']))
90
+ chunk_strs = re.split(f'({pattern})', conversation)
91
+ prompt_length = -1
92
+ stop_flag = False
93
+ for idx, chunk_str in enumerate(chunk_strs):
94
+ if idx == 0:
95
+ enc_chunk = prompt_chunk + \
96
+ self.tokenizer(chunk_str, add_special_tokens=False)[
97
+ 'input_ids']
98
+ enc_length = len(enc_chunk)
99
+ label_chunk = [0] * enc_length
100
+ else:
101
+ if chunk_strs[idx-1] == 'AI: ':
102
+ curr_chunk = self.tokenizer(
103
+ chunk_str, add_special_tokens=False)['input_ids']
104
+ if enc_length + len(curr_chunk) >= max_length:
105
+ curr_chunk = curr_chunk[:max_length-enc_length]
106
+ stop_flag = True
107
+ curr_chunk += [self.tokenizer.eos_token_id]
108
+ enc_length += len(curr_chunk)
109
+ enc_chunk += curr_chunk
110
+ label_chunk += [1] * len(curr_chunk)
111
+ else:
112
+ curr_chunk = self.tokenizer(
113
+ chunk_str, add_special_tokens=False)['input_ids']
114
+ if enc_length + len(curr_chunk) >= max_length + 1:
115
+ curr_chunk = curr_chunk[:max_length+1-enc_length]
116
+ stop_flag = True
117
+ enc_length += len(curr_chunk)
118
+ enc_chunk += curr_chunk
119
+ label_chunk += [0] * len(curr_chunk)
120
+ if stop_flag:
121
+ break
122
+
123
+ # For Image-Text Data
124
+ else:
125
+ enc_length = 0
126
+ prompt_length = -2
127
+ pattern = '|'.join(
128
+ map(re.escape, list(self.media_tokens.keys()) + ['AI: ', '\nHuman: ']))
129
+ chunk_strs = re.split(f'({pattern})', conversation)
130
+ chunk_strs = [x for x in chunk_strs if len(x) > 0]
131
+ for idx, chunk_str in enumerate(chunk_strs):
132
+ if enc_length >= max_length + 1:
133
+ break
134
+
135
+ if idx == 0:
136
+ enc_chunk = prompt_chunk + \
137
+ self.tokenizer(chunk_str, add_special_tokens=False)[
138
+ 'input_ids']
139
+ enc_length = len(enc_chunk)
140
+ label_chunk = [0] * enc_length
141
+ else:
142
+ if chunk_str in self.media_tokens:
143
+ # [CLS] + 256 + [EOS]
144
+ if enc_length + self.media_lengths[chunk_str] > max_length + 1:
145
+ break
146
+ else:
147
+ enc_chunk += [self.media_tokens[chunk_str]
148
+ ] * self.media_lengths[chunk_str]
149
+ enc_length += self.media_lengths[chunk_str]
150
+ label_chunk += [0] * self.media_lengths[chunk_str]
151
+ else:
152
+
153
+ if chunk_strs[idx-1] == 'AI: ':
154
+ curr_chunk = self.tokenizer(
155
+ chunk_str, add_special_tokens=False)['input_ids']
156
+ if enc_length + len(curr_chunk) >= max_length:
157
+ curr_chunk = curr_chunk[:max_length-enc_length]
158
+ curr_chunk += [self.tokenizer.eos_token_id]
159
+ enc_length += len(curr_chunk)
160
+ enc_chunk += curr_chunk
161
+ label_chunk += [1] * len(curr_chunk)
162
+ else:
163
+ curr_chunk = self.tokenizer(
164
+ chunk_str, add_special_tokens=False)['input_ids']
165
+ if enc_length + len(curr_chunk) >= max_length + 1:
166
+ curr_chunk = curr_chunk[:max_length +
167
+ 1-enc_length]
168
+ enc_length += len(curr_chunk)
169
+ enc_chunk += curr_chunk
170
+ label_chunk += [0] * len(curr_chunk)
171
+
172
+ if enc_length < max_length + 1:
173
+ padding_chunk = [self.tokenizer.pad_token_id] * \
174
+ (max_length + 1 - enc_length)
175
+ padding_length = len(padding_chunk)
176
+ label_chunk += [0] * (max_length + 1 - enc_length)
177
+ enc_chunk = enc_chunk + padding_chunk
178
+ else:
179
+ padding_length = 0
180
+
181
+ assert enc_length + padding_length == max_length + \
182
+ 1, (index, prompt_length, enc_length,
183
+ padding_length, max_length + 1)
184
+ assert len(label_chunk) == max_length + \
185
+ 1, (len(label_chunk), max_length + 1)
186
+ non_padding_mask = [1 if i < enc_length -
187
+ 1 else 0 for i in range(max_length)]
188
+
189
+ enc_chunk = torch.tensor(enc_chunk).long()
190
+ non_padding_mask = torch.tensor(non_padding_mask).long()
191
+ prompt_mask = torch.tensor(label_chunk)[1:].long()
192
+ prompt_length = torch.tensor([prompt_length]).long()
193
+
194
+ # Create loss mask
195
+ if all([media_token not in conversation for media_token in self.media_tokens.keys()]):
196
+ non_media_mask = torch.ones_like(non_padding_mask).long()
197
+ else:
198
+ tmp_enc_chunk = enc_chunk.clone()
199
+ tmp_enc_chunk[tmp_enc_chunk >= 0] = 1
200
+ tmp_enc_chunk[tmp_enc_chunk < 0] = 0
201
+ non_media_mask = torch.tensor(tmp_enc_chunk).long()
202
+ non_media_mask = non_media_mask[1:].long()
203
+ return {'input_ids': enc_chunk, "prompt_length": prompt_length, 'seq_length': enc_length,
204
+ "non_padding_mask": non_padding_mask, 'non_media_mask': non_media_mask, 'prompt_mask': prompt_mask}
pipeline_video/entailment_inference.py ADDED
@@ -0,0 +1,122 @@
1
+ import os
2
+ import csv
3
+ import json
4
+ import torch
5
+ import argparse
6
+ import pandas as pd
7
+ import torch.nn as nn
8
+ from tqdm import tqdm
9
+ from collections import defaultdict
10
+ from transformers.models.llama.tokenization_llama import LlamaTokenizer
11
+ from torch.utils.data import DataLoader
12
+ from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration
13
+ from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
14
+ from peft import LoraConfig, get_peft_model
15
+ from data_utils.xgpt3_dataset import MultiModalDataset
16
+ from utils import batchify
17
+
18
+ parser = argparse.ArgumentParser()
19
+
20
+ parser.add_argument('--input_csv', type = str, required = True, help = 'input csv file')
21
+ parser.add_argument('--output_csv', type = str, help = 'output csv with scores')
22
+ parser.add_argument('--pretrained_ckpt', type = str, required = True, help = 'pretrained ckpt')
23
+ parser.add_argument('--trained_ckpt', type = str, help = 'trained ckpt')
24
+ parser.add_argument('--lora_r', type = int, default = 32)
25
+ parser.add_argument('--use_lora', action = 'store_true', help = 'lora model')
26
+ parser.add_argument('--all-params', action = 'store_true', help = 'use all params of the model')
27
+ parser.add_argument('--batch_size', type = int, default = 32)
28
+
29
+ args = parser.parse_args()
30
+ softmax = nn.Softmax(dim=2)
31
+
32
+ def get_entail(logits, input_ids, tokenizer):
33
+ logits = softmax(logits)
34
+ token_id_yes = tokenizer.encode('Yes', add_special_tokens = False)[0]
35
+ token_id_no = tokenizer.encode('No', add_special_tokens = False)[0]
36
+ entailment = []
37
+ for j in range(len(logits)):
38
+ for i in range(len(input_ids[j])):
39
+ if input_ids[j][i] == tokenizer.pad_token_id: # pad token if the answer is not present
40
+ i = i - 1
41
+ break
42
+ elif i == len(input_ids[j]) - 1:
43
+ break
44
+ score = logits[j][i][token_id_yes] / (logits[j][i][token_id_yes] + logits[j][i][token_id_no])
45
+ entailment.append(score)
46
+ entailment = torch.stack(entailment)
47
+ return entailment
48
+
49
+ def get_scores(model, tokenizer, dataloader):
50
+
51
+ with torch.no_grad():
52
+ for index, inputs in tqdm(enumerate(dataloader)):
53
+ for k, v in inputs.items():
54
+ if torch.is_tensor(v):
55
+ if v.dtype == torch.float:
56
+ inputs[k] = v.bfloat16()
57
+ inputs[k] = inputs[k].to(model.device)
58
+ outputs = model(pixel_values = inputs['pixel_values'], video_pixel_values = inputs['video_pixel_values'], labels = None, \
59
+ num_images = inputs['num_images'], num_videos = inputs['num_videos'], input_ids = inputs['input_ids'], non_padding_mask = inputs['non_padding_mask'], \
60
+ non_media_mask = inputs['non_media_mask'], prompt_mask = inputs['prompt_mask'])
61
+ logits = outputs['logits']
62
+ entail_scores = get_entail(logits, inputs['input_ids'], tokenizer)
63
+ for m in range(len(entail_scores)):
64
+ with open(args.output_csv, 'a') as f:
65
+ writer = csv.writer(f)
66
+ writer.writerow([inputs['videopaths'][m], inputs['captions'][m], entail_scores[m].item()])
67
+ print(f"Batch {index} Done")
68
+
69
+ def main():
70
+
71
+ pretrained_ckpt = args.pretrained_ckpt
72
+
73
+ # Processors
74
+ tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
75
+ image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
76
+ processor = MplugOwlProcessor(image_processor, tokenizer)
77
+
78
+ valid_data = MultiModalDataset(args.input_csv, tokenizer, processor, max_length = 256, loss_objective = 'sequential')
79
+ dataloader = DataLoader(valid_data, batch_size=args.batch_size, pin_memory=True, collate_fn=batchify)
80
+
81
+ # Instantiate model
82
+ model = MplugOwlForConditionalGeneration.from_pretrained(
83
+ pretrained_ckpt,
84
+ torch_dtype=torch.bfloat16,
85
+ device_map={'':0}
86
+ )
87
+
88
+ if args.use_lora:
89
+ for name, param in model.named_parameters():
90
+ param.requires_grad = False
91
+ if args.all_params:
92
+ peft_config = LoraConfig(
93
+ target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)',
94
+ inference_mode=True,
95
+ r=args.lora_r,
96
+ lora_alpha=16,
97
+ lora_dropout=0.05
98
+ )
99
+ else:
100
+ peft_config = LoraConfig(
101
+ target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj)',
102
+ inference_mode=True,
103
+ r=args.lora_r,
104
+ lora_alpha=16,
105
+ lora_dropout=0.05
106
+ )
107
+
108
+ model = get_peft_model(model, peft_config)
109
+ model.print_trainable_parameters()
110
+
111
+ with open(args.trained_ckpt, 'rb') as f:
112
+ ckpt = torch.load(f, map_location = torch.device(f"cuda:0"))
113
+ model.load_state_dict(ckpt)
114
+ model = model.to(torch.bfloat16)
115
+ print('Model Loaded')
116
+
117
+ model.eval()
118
+
119
+ get_scores(model, tokenizer, dataloader)
120
+
121
+ if __name__ == "__main__":
122
+ main()
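
A typical invocation of this script might look like the following, with all paths and checkpoint names as placeholders: python pipeline_video/entailment_inference.py --input_csv data/eval.csv --output_csv results/entail_scores.csv --pretrained_ckpt MAGAer13/mplug-owl-llama-7b --trained_ckpt checkpoints/owl_lora.pt --use_lora --lora_r 32 --batch_size 32. The script appends one row per video to the output CSV: video path, caption, and the Yes/No entailment score.
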
pipeline_video/mplug_owl_video/__init__.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
17
+
18
+
19
+ _import_structure = {
20
+ "configuration_mplug_owl": ["MPLUG_OWL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MplugOwlConfig"],
21
+ "processing_mplug_owl": ["MplugOwlImageProcessor", "MplugOwlProcessor"],
22
+ "tokenization_mplug_owl": ["MplugOwlTokenizer"],
23
+ }
24
+
25
+ try:
26
+ if not is_tokenizers_available():
27
+ raise OptionalDependencyNotAvailable()
28
+ except OptionalDependencyNotAvailable:
29
+ pass
30
+
31
+
32
+ try:
33
+ if not is_torch_available():
34
+ raise OptionalDependencyNotAvailable()
35
+ except OptionalDependencyNotAvailable:
36
+ pass
37
+ else:
38
+ _import_structure["modeling_mplug_owl"] = [
39
+ "MPLUG_OWL_PRETRAINED_MODEL_ARCHIVE_LIST",
40
+ "MplugOwlForConditionalGeneration",
41
+ "MplugOwlModel",
42
+ ]
43
+
44
+
45
+ if TYPE_CHECKING:
46
+ from .configuration_mplug_owl import MPLUG_OWL_PRETRAINED_CONFIG_ARCHIVE_MAP, MplugOwlConfig
47
+ from .tokenization_mplug_owl import MplugOwlTokenizer
48
+
49
+ try:
50
+ if not is_tokenizers_available():
51
+ raise OptionalDependencyNotAvailable()
52
+ except OptionalDependencyNotAvailable:
53
+ pass
54
+
55
+ try:
56
+ if not is_torch_available():
57
+ raise OptionalDependencyNotAvailable()
58
+ except OptionalDependencyNotAvailable:
59
+ pass
60
+ else:
61
+ from .modeling_mplug_owl import (
62
+ MPLUG_OWL_PRETRAINED_MODEL_ARCHIVE_LIST,
63
+ MplugOwlForConditionalGeneration,
64
+ MplugOwlModel,
65
+ MplugOwlPreTrainedModel,
66
+ )
67
+
68
+
69
+ else:
70
+ import sys
71
+
72
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
73
+
74
+ from .configuration_mplug_owl import *
75
+ from .modeling_mplug_owl import *
76
+ from .processing_mplug_owl import *
77
+ from .tokenization_mplug_owl import *
pipeline_video/mplug_owl_video/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.29 kB)

pipeline_video/mplug_owl_video/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.26 kB)

pipeline_video/mplug_owl_video/__pycache__/configuration_mplug_owl.cpython-310.pyc ADDED
Binary file (10.5 kB)

pipeline_video/mplug_owl_video/__pycache__/configuration_mplug_owl.cpython-39.pyc ADDED
Binary file (10.5 kB)

pipeline_video/mplug_owl_video/__pycache__/modeling_mplug_owl.cpython-310.pyc ADDED
Binary file (57.1 kB)

pipeline_video/mplug_owl_video/__pycache__/modeling_mplug_owl.cpython-39.pyc ADDED
Binary file (57 kB)

pipeline_video/mplug_owl_video/__pycache__/processing_mplug_owl.cpython-310.pyc ADDED
Binary file (7.32 kB)

pipeline_video/mplug_owl_video/__pycache__/processing_mplug_owl.cpython-39.pyc ADDED
Binary file (7.36 kB)

pipeline_video/mplug_owl_video/__pycache__/tokenization_mplug_owl.cpython-310.pyc ADDED
Binary file (1.31 kB)

pipeline_video/mplug_owl_video/__pycache__/tokenization_mplug_owl.cpython-39.pyc ADDED
Binary file (1.28 kB)
 
pipeline_video/mplug_owl_video/configuration_mplug_owl.py ADDED
@@ -0,0 +1,296 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 x-plug and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ MplugOwl model configuration """
16
+ import copy
17
+ import os
18
+ from typing import Union
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
22
+ from transformers.utils import logging
23
+ from transformers.models.auto import CONFIG_MAPPING
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+ MPLUG_OWL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
29
+ "MAGAer13/mplug-owl-llama-7b": "https://huggingface.co/MAGAer13/mplug-owl-llama-7b/resolve/main/config.json",
30
+ # See all MplugOwl models at https://huggingface.co/models?filter=mplug_owl
31
+ }
32
+
33
+
34
+ class MplugOwlVisionConfig(PretrainedConfig):
35
+ r"""
36
+ This is the configuration class to store the configuration of a [`MplugOwlVisionModel`]. It is used to instantiate a
37
+ mPLUG-Owl vision encoder according to the specified arguments, defining the model architecture. Instantiating a
38
+ configuration defaults will yield a similar configuration to that of the mPLUG-Owl
39
+ [x-plug/x_plug-llama-7b](https://huggingface.co/x-plug/x_plug-llama-7b) architecture.
40
+
41
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
42
+ documentation from [`PretrainedConfig`] for more information.
43
+
44
+ Args:
45
+ hidden_size (`int`, *optional*, defaults to 1024):
46
+ Dimensionality of the encoder layers and the pooler layer.
47
+ intermediate_size (`int`, *optional*, defaults to 4096):
48
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
49
+ num_hidden_layers (`int`, *optional*, defaults to 24):
50
+ Number of hidden layers in the Transformer encoder.
51
+ num_attention_heads (`int`, *optional*, defaults to 16):
52
+ Number of attention heads for each attention layer in the Transformer encoder.
53
+ image_size (`int`, *optional*, defaults to 224):
54
+ The size (resolution) of each image.
55
+ patch_size (`int`, *optional*, defaults to 14):
56
+ The size (resolution) of each patch.
57
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
58
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
59
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
60
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
61
+ The epsilon used by the layer normalization layers.
62
+ attention_dropout (`float`, *optional*, defaults to 0.0):
63
+ The dropout ratio for the attention probabilities.
64
+ initializer_range (`float`, *optional*, defaults to 0.02):
65
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
66
+ initializer_factor (`float`, *optional*, defaults to 1):
67
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
68
+ testing).
69
+
70
+
71
+ """
72
+
73
+ model_type = "mplug_owl_vision_model"
74
+
75
+ def __init__(
76
+ self,
77
+ hidden_size=1024,
78
+ intermediate_size=4096,
79
+ projection_dim=768,
80
+ num_hidden_layers=24,
81
+ num_attention_heads=16,
82
+ num_channels=3,
83
+ image_size=224,
84
+ patch_size=14,
85
+ hidden_act="quick_gelu",
86
+ layer_norm_eps=1e-6,
87
+ attention_dropout=0.0,
88
+ initializer_range=0.02,
89
+ initializer_factor=1.0,
90
+ use_flash_attn=False,
91
+ **kwargs,
92
+ ):
93
+ super().__init__(**kwargs)
94
+ self.hidden_size = hidden_size
95
+ self.intermediate_size = intermediate_size
96
+ self.projection_dim = projection_dim
97
+ self.num_hidden_layers = num_hidden_layers
98
+ self.num_attention_heads = num_attention_heads
99
+ self.num_channels = num_channels
100
+ self.patch_size = patch_size
101
+ self.image_size = image_size
102
+ self.initializer_range = initializer_range
103
+ self.initializer_factor = initializer_factor
104
+ self.attention_dropout = attention_dropout
105
+ self.layer_norm_eps = layer_norm_eps
106
+ self.hidden_act = hidden_act
107
+ self.use_flash_attn = use_flash_attn
108
+
109
+ @classmethod
110
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
111
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
112
+
113
+ # get the vision config dict if we are loading from MplugOwlConfig
114
+ if config_dict.get("model_type") == "mplug-owl":
115
+ config_dict = config_dict["vision_config"]
116
+
117
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
118
+ logger.warning(
119
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
120
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
121
+ )
122
+
123
+ return cls.from_dict(config_dict, **kwargs)
124
+
125
+
126
+ class MplugOwlVisualAbstractorConfig(PretrainedConfig):
127
+ model_type = "mplug_owl_visual_abstract"
128
+
129
+ def __init__(
130
+ self,
131
+ hidden_size=1024, #
132
+ num_hidden_layers=6, #
133
+ num_attention_heads=16, #
134
+ intermediate_size=4096, #
135
+ attention_probs_dropout_prob=0.1, #
136
+ initializer_range=0.02,
137
+ layer_norm_eps=1e-6, #
138
+ encoder_hidden_size=1024, #
139
+ **kwargs,
140
+ ):
141
+ super().__init__(**kwargs)
142
+ self.hidden_size = hidden_size
143
+ self.num_hidden_layers = num_hidden_layers
144
+ self.num_attention_heads = num_attention_heads
145
+ self.intermediate_size = intermediate_size
146
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
147
+ self.initializer_range = initializer_range
148
+ self.layer_norm_eps = layer_norm_eps
149
+ self.encoder_hidden_size = encoder_hidden_size
150
+
151
+ @classmethod
152
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
153
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
154
+
155
+ # get the visual_abstractor config dict if we are loading from MplugOwlConfig
156
+ if config_dict.get("model_type") == "mplug-owl":
157
+ config_dict = config_dict["abstractor_config"]
158
+
159
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
160
+ logger.warning(
161
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
162
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
163
+ )
164
+
165
+ return cls.from_dict(config_dict, **kwargs)
166
+
167
+
168
+ class MplugOwlConfig(PretrainedConfig):
169
+ r"""
170
+ [`MplugOwlConfig`] is the configuration class to store the configuration of a [`MplugOwlForConditionalGeneration`]. It is
171
+ used to instantiate a mPLUG-Owl model according to the specified arguments, defining the vision model, Q-Former model
172
+ and language model configs. Instantiating a configuration with the defaults will yield a similar configuration to
173
+ that of the mPLUG-Owl [x-plug/x_plug-llama-7b](https://huggingface.co/x-plug/x_plug-llama-7b) architecture.
174
+
175
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
176
+ documentation from [`PretrainedConfig`] for more information.
177
+
178
+ Args:
179
+ vision_config (`dict`, *optional*):
180
+ Dictionary of configuration options used to initialize [`MplugOwlVisionConfig`].
181
+ visual_abstractor_config (`dict`, *optional*):
182
+ Dictionary of configuration options used to initialize [`MplugOwlVisualAbstractorConfig`].
183
+ text_config (`dict`, *optional*):
184
+ Dictionary of configuration options used to initialize any [`PretrainedConfig`].
185
+ num_query_tokens (`int`, *optional*, defaults to 32):
186
+ The number of query tokens passed through the Transformer.
187
+
188
+ kwargs (*optional*):
189
+ Dictionary of keyword arguments.
190
+
191
+ Example:
192
+
193
+ ```python
194
+ >>> from transformers import (
195
+ ... MplugOwlVisionConfig,
196
+ ... MplugOwlVisualAbstractorConfig,
197
+ ... OPTConfig,
198
+ ... MplugOwlConfig,
199
+ ... MplugOwlForConditionalGeneration,
200
+ ... )
201
+
202
+ >>> # Initializing a MplugOwlConfig with x-plug/x_plug-llama-7b style configuration
203
+ >>> configuration = MplugOwlConfig()
204
+
205
+ >>> # Initializing a MplugOwlForConditionalGeneration (with random weights) from the x-plug/x_plug-llama-7b style configuration
206
+ >>> model = MplugOwlForConditionalGeneration(configuration)
207
+
208
+ >>> # Accessing the model configuration
209
+ >>> configuration = model.config
210
+
211
+ >>> # We can also initialize a MplugOwlConfig from a MplugOwlVisionConfig, MplugOwlVisualAbstractorConfig and any PretrainedConfig
212
+
213
+ >>> # Initializing mPLUG-Owl vision, mPLUG-Owl Q-Former and language model configurations
214
+ >>> vision_config = MplugOwlVisionConfig()
215
+ >>> visual_abstractor_config = MplugOwlVisualAbstractorConfig()
216
+ >>> text_config = OPTConfig()
217
+
218
+ >>> config = MplugOwlConfig.from_vision_visual_abstractor_text_configs(vision_config, visual_abstractor_config, text_config)
219
+ ```"""
220
+ model_type = "mplug-owl"
221
+ is_composition = True
222
+
223
+ def __init__(
224
+ self, vision_config=None, visual_abstractor_config=None, text_config=None, num_query_tokens=64, **kwargs
225
+ ):
226
+ super().__init__(**kwargs)
227
+ if vision_config is None:
228
+ vision_config = MplugOwlVisionConfig().to_dict()
229
+ logger.info("vision_config is None. Initializing MplugOwlVisionConfig with default values.")
230
+
231
+ if visual_abstractor_config is None:
232
+ visual_abstractor_config = {}
233
+ logger.info("visual_abstractor_config is None. Initializing MplugOwlVisualAbstractorConfig with default values.")
234
+
235
+ if text_config is None:
236
+ # we use LLAMA 7b by default
237
+ from transformers.models.llama.configuration_llama import LlamaConfig
238
+
239
+ text_config = LlamaConfig(pad_token_id=2).to_dict()
240
+ logger.info("text_config is None. Initializing the text config with LlamaConfig default values.")
241
+
242
+ self.vision_config = MplugOwlVisionConfig(**vision_config)
243
+ self.visual_abstractor_config = MplugOwlVisualAbstractorConfig(**visual_abstractor_config)
244
+ # self.visual_abstractor_config.layer_norm_eps = 1e-6
245
+ text_model_type = text_config["model_type"] if "model_type" in text_config else "llama"
246
+ self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
247
+
248
+ self.tie_word_embeddings = self.text_config.tie_word_embeddings
249
+ self.is_encoder_decoder = self.text_config.is_encoder_decoder
250
+
251
+ self.num_query_tokens = num_query_tokens
252
+ # self.visual_abstractor_config.encoder_hidden_size = self.vision_config.hidden_size
253
+ self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
254
+ self.initializer_factor = 1.0
255
+ self.initializer_range = 0.02
256
+
257
+ for attr in dir(self.text_config):
258
+ if not hasattr(self, attr):
259
+ setattr(self, attr, getattr(self.text_config, attr))
260
+
261
+ @classmethod
262
+ def from_vision_visual_abstractor_text_configs(
263
+ cls,
264
+ vision_config: MplugOwlVisionConfig,
265
+ visual_abstractor_config: MplugOwlVisualAbstractorConfig,
266
+ text_config: PretrainedConfig,
267
+ **kwargs,
268
+ ):
269
+ r"""
270
+ Instantiate a [`MplugOwlConfig`] (or a derived class) from an mPLUG-Owl vision model, visual abstractor and language model
271
+ configurations.
272
+
273
+ Returns:
274
+ [`MplugOwlConfig`]: An instance of a configuration object
275
+ """
276
+
277
+ return cls(
278
+ vision_config=vision_config.to_dict(),
279
+ visual_abstractor_config=visual_abstractor_config.to_dict(),
280
+ text_config=text_config.to_dict(),
281
+ **kwargs,
282
+ )
283
+
284
+ def to_dict(self):
285
+ """
286
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
287
+
288
+ Returns:
289
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
290
+ """
291
+ output = copy.deepcopy(self.__dict__)
292
+ output["vision_config"] = self.vision_config.to_dict()
293
+ output["visual_abstractor_config"] = self.visual_abstractor_config.to_dict()
294
+ output["text_config"] = self.text_config.to_dict()
295
+ output["model_type"] = self.__class__.model_type
296
+ return output
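
The three configuration classes above are meant to be composed: `MplugOwlConfig` nests a `MplugOwlVisionConfig`, a `MplugOwlVisualAbstractorConfig` and any causal-LM text config. The snippet below is a minimal sketch of that composition, assuming `pipeline_video` is on the Python path so that `mplug_owl_video.configuration_mplug_owl` is importable, and that `transformers` provides `LlamaConfig`; anything not defined in the file above is illustrative only.

```python
# Minimal sketch: composing the configuration classes defined in configuration_mplug_owl.py.
from transformers import LlamaConfig

from mplug_owl_video.configuration_mplug_owl import (
    MplugOwlConfig,
    MplugOwlVisionConfig,
    MplugOwlVisualAbstractorConfig,
)

vision_config = MplugOwlVisionConfig()                # ViT-L/14-style defaults: hidden_size=1024, 24 layers, patch_size=14
abstractor_config = MplugOwlVisualAbstractorConfig()  # 6-layer cross-attention visual abstractor
text_config = LlamaConfig()                           # any causal-LM config; LLaMA is the default used by MplugOwlConfig

config = MplugOwlConfig.from_vision_visual_abstractor_text_configs(
    vision_config, abstractor_config, text_config, num_query_tokens=64
)

# to_dict() re-serializes the nested configs, so the composite config round-trips cleanly.
config_dict = config.to_dict()
print(config_dict["model_type"], config_dict["vision_config"]["hidden_size"])  # -> mplug-owl 1024
```

Passing the config objects through `from_vision_visual_abstractor_text_configs` is equivalent to handing their `to_dict()` outputs directly to the `MplugOwlConfig` constructor.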
pipeline_video/mplug_owl_video/modeling_mplug_owl.py ADDED
@@ -0,0 +1,1938 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 x-plug and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch MplugOwl model. """
16
+
17
+ import logging
18
+ import math
19
+ from typing import Any, Optional, Tuple, Union
20
+
21
+ try:
22
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_func
23
+
24
+ flash_attn_func = flash_attn_unpadded_func
25
+ except ImportError:
26
+ flash_attn_func = None
27
+ print("flash-attn is not installed; falling back to the standard attention implementation.")
28
+ import math
29
+ from dataclasses import dataclass
30
+ from typing import Any, Optional, Tuple, Union
31
+
32
+ import torch
33
+ import torch.utils.checkpoint
34
+ from torch import nn
35
+ import einops
36
+
37
+ from transformers.modeling_outputs import (
38
+ BaseModelOutput,
39
+ BaseModelOutputWithPooling,
40
+ BaseModelOutputWithPastAndCrossAttentions
41
+ )
42
+ from transformers.modeling_utils import PreTrainedModel
43
+ from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
44
+ from transformers.utils import (
45
+ ModelOutput,
46
+ add_start_docstrings,
47
+ add_start_docstrings_to_model_forward,
48
+ logging,
49
+ replace_return_docstrings,
50
+ )
51
+ from transformers.models.auto import AutoModelForCausalLM
52
+ from .configuration_mplug_owl import MplugOwlConfig, MplugOwlVisionConfig, MplugOwlVisualAbstractorConfig
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+ _CHECKPOINT_FOR_DOC = "MAGAer13/mplug-owl-llama-7b"
58
+ _CONFIG_FOR_DOC = "MplugOwlConfig"
59
+
60
+
61
+ MPLUG_OWL_PRETRAINED_MODEL_ARCHIVE_LIST = [
62
+ "MAGAer13/mplug-owl-llama-7b",
63
+ # See all MplugOwl models at https://huggingface.co/models?filter=mplug_owl
64
+ ]
65
+
66
+
67
+ @dataclass
68
+ class MplugOwlForConditionalGenerationModelOutput(ModelOutput):
69
+ """
70
+ Class defining the outputs of [`MplugOwlForConditionalGeneration`].
71
+
72
+ Args:
73
+ loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
74
+ Language modeling loss from the language model.
75
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
76
+ Prediction scores of the language modeling head of the language model.
77
+ vision_outputs (`BaseModelOutputWithPooling`):
78
+ Outputs of the vision encoder.
79
+
80
+ language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
81
+ Outputs of the language model.
82
+ """
83
+
84
+ loss: Optional[Tuple[torch.FloatTensor]] = None
85
+ logits: Optional[Tuple[torch.FloatTensor]] = None
86
+ vision_outputs: Optional[torch.FloatTensor] = None
87
+ language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None
88
+
89
+ def to_tuple(self) -> Tuple[Any]:
90
+ return tuple(
91
+ self[k] if k not in ["vision_outputs", "language_model_outputs"] else getattr(self, k).to_tuple()
92
+ for k in self.keys()
93
+ )
94
+
95
+
96
+ def get_ltor_masks_and_position_ids_from_embeddings(data):
97
+ """Build masks and position id for left to right model."""
98
+
99
+ # Extract batch size and sequence length.
100
+ micro_batch_size, seq_length = data.size()[:2]
101
+
102
+ # Attention mask (lower triangular).
103
+ att_mask_batch = 1
104
+ attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view(
105
+ att_mask_batch, 1, seq_length, seq_length
106
+ )
107
+
108
+ # Loss mask.
109
+ loss_mask = torch.ones(data.size()[:2], dtype=torch.float, device=data.device)
110
+
111
+ # Position ids.
112
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
113
+ position_ids = position_ids.unsqueeze(0).expand_as(data[..., 0])
114
+
115
+ # Convert attention mask to binary:
116
+ attention_mask = attention_mask < 0.5
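+ # After the comparison the mask is boolean, with True marking positions that must NOT be attended to (the upper triangle).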
117
+
118
+ return attention_mask, loss_mask, position_ids
119
+
120
+
121
+ class MplugOwlVisionEmbeddings(nn.Module):
122
+ def __init__(self, config: MplugOwlVisionConfig):
123
+ super().__init__()
124
+ self.config = config
125
+ self.hidden_size = config.hidden_size
126
+ self.image_size = config.image_size
127
+ self.patch_size = config.patch_size
128
+
129
+ self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
130
+
131
+ self.patch_embed = nn.Conv2d(
132
+ in_channels=3,
133
+ out_channels=self.hidden_size,
134
+ kernel_size=self.patch_size,
135
+ stride=self.patch_size,
136
+ bias=False,
137
+ )
138
+
139
+ self.num_patches = (self.image_size // self.patch_size) ** 2
140
+
141
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, self.hidden_size))
142
+
143
+ self.pre_layernorm = LayerNormFp32(self.hidden_size, eps=config.layer_norm_eps)
144
+
145
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
146
+ # [B, C, T, H, W] or [B, C, H, W]
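+ # Video frames are folded into the batch dimension so every frame shares the same Conv2d patch embedding,
+ # CLS token and position embeddings; the result is reshaped back to [B, T, N, D] before returning.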
147
+ batch_size = pixel_values.size(0)
148
+ T = pixel_values.size(2) if pixel_values.dim() > 4 else 1
149
+ if T > 1:
150
+ pixel_values = einops.rearrange(pixel_values, 'b c t h w -> (b t) c h w')
151
+ image_embeds = self.patch_embed(pixel_values)
152
+ image_embeds = image_embeds.flatten(2).transpose(1, 2)
153
+
154
+ class_embeds = self.cls_token.expand(batch_size * T, 1, -1).to(image_embeds.dtype)
155
+ embeddings = torch.cat([class_embeds, image_embeds], dim=1)
156
+ embeddings = embeddings + self.position_embedding[:, : embeddings.size(1)].to(image_embeds.dtype)
157
+ embeddings = self.pre_layernorm(embeddings)
158
+ embeddings = einops.rearrange(embeddings, '(b t) n d -> b t n d', b=batch_size)
159
+ return embeddings
160
+
161
+
162
+ class LayerNormFp32(nn.LayerNorm):
163
+ """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
164
+
165
+ def __init__(self, *args, **kwargs):
166
+ super().__init__(*args, **kwargs)
167
+
168
+ def forward(self, x: torch.Tensor):
169
+ output = torch.nn.functional.layer_norm(
170
+ x.float(),
171
+ self.normalized_shape,
172
+ self.weight.float() if self.weight is not None else None,
173
+ self.bias.float() if self.bias is not None else None,
174
+ self.eps,
175
+ )
176
+ return output.type_as(x)
177
+
178
+
179
+ class QuickGELU(nn.Module):
180
+ def forward(self, x: torch.Tensor):
181
+ return x * torch.sigmoid(1.702 * x)
182
+
183
+
184
+ class MplugOwlVisionLocalTemporal(nn.Module):
185
+ def __init__(self, config):
186
+ super(MplugOwlVisionLocalTemporal, self).__init__()
187
+
188
+ self.image_size = config.image_size
189
+ self.patch_size = config.patch_size
190
+ self.num_patches = 1 + (self.image_size // self.patch_size) ** 2
191
+ self.hidden_size = config.hidden_size
192
+ d_bottleneck = self.hidden_size // 2
193
+
194
+ self.ln = LayerNormFp32(self.hidden_size)
195
+ self.down_proj = nn.Conv3d(self.hidden_size, d_bottleneck, kernel_size=1, stride=1, padding=0)
196
+ self.conv = nn.Conv3d(d_bottleneck, d_bottleneck, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), groups=d_bottleneck)
197
+ self.up_proj = nn.Conv3d(d_bottleneck, self.hidden_size, kernel_size=1, stride=1, padding=0)
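+ # The up-projection is zero-initialized so the temporal branch is a no-op at the start of training;
+ # the encoder layer adds its output residually on top of the per-frame features.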
198
+
199
+ nn.init.constant_(self.up_proj.weight, 0)
200
+ nn.init.constant_(self.up_proj.bias, 0)
201
+
202
+ self.activation_func = QuickGELU()
203
+
204
+ def forward(self, x):
205
+ # [b, t, s, c]
206
+ T = x.size(1)
207
+ H = int((self.num_patches - 1)**0.5)
208
+ cls_token, x = x[:, :, 0:1], x[:, :, 1:]
209
+ x = self.ln(x)
210
+ x = einops.rearrange(x, 'b t (h w) c -> b c t h w', h=H)
211
+ x = self.down_proj(x)
212
+ _device = x.device
213
+ self = self.to('cpu') # hack: cpu offloading since bfloat16 on gpu gives error with conv_depthwise3d
214
+ x = x.to('cpu')
215
+ x = self.conv(x)
216
+ self = self.to(_device)
217
+ x = x.to(_device)
218
+ x = self.activation_func(x)
219
+ x = self.up_proj(x)
220
+ x = einops.rearrange(x, 'b c t h w -> b t (h w) c')
221
+ x = torch.cat([cls_token, x], dim = 2)
222
+ return x
223
+
224
+
225
+ class MplugOwlVisionAttention(nn.Module):
226
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
227
+
228
+ def __init__(self, config):
229
+ super().__init__()
230
+ self.config = config
231
+ self.hidden_size = config.hidden_size
232
+ self.num_heads = config.num_attention_heads
233
+ self.head_dim = self.hidden_size // self.num_heads
234
+ if self.head_dim * self.num_heads != self.hidden_size:
235
+ raise ValueError(
236
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
237
+ f" {self.num_heads})."
238
+ )
239
+ self.scale = self.head_dim**-0.5
240
+ self.dropout = nn.Dropout(config.attention_dropout)
241
+
242
+ self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size)
243
+ self.dense = nn.Linear(self.hidden_size, self.hidden_size)
244
+
245
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
246
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
247
+
248
+ def forward(
249
+ self,
250
+ hidden_states: torch.Tensor,
251
+ head_mask: Optional[torch.Tensor] = None,
252
+ output_attentions: Optional[bool] = False,
253
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
254
+ """Input shape: Batch x Time x Channel"""
255
+
256
+ bsz, seq_len, embed_dim = hidden_states.size()
257
+
258
+ mixed_qkv = self.query_key_value(hidden_states)
259
+
260
+ mixed_qkv = mixed_qkv.reshape(bsz, seq_len, self.num_heads, 3, embed_dim // self.num_heads).permute(
261
+ 3, 0, 2, 1, 4
262
+ ) # [3, b, np, sq, hn]
263
+ query_states, key_states, value_states = (
264
+ mixed_qkv[0],
265
+ mixed_qkv[1],
266
+ mixed_qkv[2],
267
+ )
268
+ # if self.config.use_flash_attn and flash_attn_func is not None:
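+ # NOTE: the flash-attention branch below is hard-disabled, so the standard softmax attention path further down is always used.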
269
+ if False:
270
+ # [b*sq, np, hn]
271
+ query_states = query_states.permute(0, 2, 1, 3).contiguous()
272
+ query_states = query_states.view(query_states.size(0) * query_states.size(1), query_states.size(2), -1)
273
+
274
+ key_states = key_states.permute(0, 2, 1, 3).contiguous()
275
+ key_states = key_states.view(key_states.size(0) * key_states.size(1), key_states.size(2), -1)
276
+
277
+ value_states = value_states.permute(0, 2, 1, 3).contiguous()
278
+ value_states = value_states.view(value_states.size(0) * value_states.size(1), value_states.size(2), -1)
279
+
280
+ cu_seqlens = torch.arange(
281
+ 0, (bsz + 1) * seq_len, step=seq_len, dtype=torch.int32, device=query_states.device
282
+ )
283
+
284
+ context_layer = flash_attn_func(
285
+ query_states,
286
+ key_states,
287
+ value_states,
288
+ cu_seqlens,
289
+ cu_seqlens,
290
+ seq_len,
291
+ seq_len,
292
+ self.dropout if self.training else 0.0,
293
+ softmax_scale=self.scale,
294
+ causal=False,
295
+ return_attn_probs=False,
296
+ )
297
+ # [b*sq, np, hn] => [b, sq, np, hn]
298
+ context_layer = context_layer.view(bsz, seq_len, context_layer.size(1), context_layer.size(2))
299
+ else:
300
+ # Take the dot product between "query" and "key" to get the raw attention scores.
301
+ attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
302
+
303
+ attention_scores = attention_scores * self.scale
304
+
305
+ # Normalize the attention scores to probabilities.
306
+ attention_probs = torch.softmax(attention_scores, dim=-1)
307
+
308
+ # This is actually dropping out entire tokens to attend to, which might
309
+ # seem a bit unusual, but is taken from the original Transformer paper.
310
+ attention_probs = self.dropout(attention_probs)
311
+
312
+ # Mask heads if we want to
313
+ if head_mask is not None:
314
+ attention_probs = attention_probs * head_mask
315
+
316
+ context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
317
+
318
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
319
+ context_layer = context_layer.reshape(new_context_layer_shape)
320
+
321
+ output = self.dense(context_layer)
322
+
323
+ outputs = (output, attention_probs) if output_attentions else (output, None)
324
+
325
+ return outputs
326
+
327
+
328
+ class MplugOwlMLP(nn.Module):
329
+ def __init__(self, config):
330
+ super().__init__()
331
+ self.config = config
332
+ self.activation_fn = QuickGELU()
333
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
334
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
335
+
336
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
337
+ hidden_states = self.fc1(hidden_states)
338
+ hidden_states = self.activation_fn(hidden_states)
339
+ hidden_states = self.fc2(hidden_states)
340
+ return hidden_states
341
+
342
+
343
+ class MplugOwlVisionEncoderLayer(nn.Module):
344
+ def __init__(self, config: MplugOwlVisionConfig):
345
+ super().__init__()
346
+ self.hidden_size = config.hidden_size
347
+ self.temporal = MplugOwlVisionLocalTemporal(config)
348
+ self.self_attn = MplugOwlVisionAttention(config)
349
+ self.input_layernorm = LayerNormFp32(self.hidden_size, eps=config.layer_norm_eps)
350
+ self.mlp = MplugOwlMLP(config)
351
+ self.post_attention_layernorm = LayerNormFp32(self.hidden_size, eps=config.layer_norm_eps)
352
+
353
+ def forward(
354
+ self,
355
+ hidden_states: torch.Tensor,
356
+ attention_mask: torch.Tensor,
357
+ output_attentions: Optional[bool] = False,
358
+ ) -> Tuple[torch.FloatTensor]:
359
+ """
360
+ Args:
361
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, time, seq_len, embed_dim)`
362
+ attention_mask (`torch.FloatTensor`): attention mask of size
363
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
364
+ `(config.encoder_attention_heads,)`.
365
+ output_attentions (`bool`, *optional*):
366
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
367
+ returned tensors for more detail.
368
+ """
369
+ B, T = hidden_states.size(0), hidden_states.size(1)
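+ # For video input (T > 1) a depth-wise temporal convolution block is added residually before the spatial attention,
+ # then frames are folded into the batch dimension for the per-frame attention and MLP.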
370
+ if T > 1:
371
+ hidden_states = hidden_states + self.temporal(hidden_states)
372
+ hidden_states = einops.rearrange(hidden_states, 'b t n d -> (b t) n d')
373
+
374
+ residual = hidden_states
375
+
376
+ hidden_states = self.input_layernorm(hidden_states)
377
+ hidden_states, attn_weights = self.self_attn(
378
+ hidden_states=hidden_states,
379
+ head_mask=attention_mask,
380
+ output_attentions=output_attentions,
381
+ )
382
+ hidden_states = hidden_states + residual
383
+ residual = hidden_states
384
+ hidden_states = self.post_attention_layernorm(hidden_states)
385
+ hidden_states = self.mlp(hidden_states)
386
+
387
+ hidden_states = hidden_states + residual
388
+ hidden_states = einops.rearrange(hidden_states, '(b t) n d -> b t n d', b=B)
389
+
390
+ outputs = (hidden_states,)
391
+
392
+ if output_attentions:
393
+ outputs += (attn_weights,)
394
+
395
+ return outputs
396
+
397
+
398
+ class MplugOwlPreTrainedModel(PreTrainedModel):
399
+ """
400
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
401
+ models.
402
+ """
403
+
404
+ config_class = MplugOwlConfig
405
+ base_model_prefix = "mplug_owl"
406
+ supports_gradient_checkpointing = True
407
+ _keys_to_ignore_on_load_missing = [
408
+ r"position_ids",
409
+ r"language_model.encoder.embed_tokens.weight",
410
+ r"language_model.decoder.embed_tokens.weight",
411
+ r"language_model.lm_head.weight",
412
+ ]
413
+ _no_split_modules = [
414
+ "MplugOwlVisionEncoderLayer",
415
+ "LlamaDecoderLayer",
416
+ "MplugOwlVisualAbstractorLayer",
417
+ "LlamaForCausalLM",
418
+ "Parameter",
419
+ ]
420
+ _keep_in_fp32_modules = ["wo"]
421
+
422
+ def _init_weights(self, module):
423
+ """Initialize the weights"""
424
+ factor = self.config.initializer_range
425
+ if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
426
+ module.weight.data.normal_(mean=0.0, std=factor)
427
+ if hasattr(module, "bias") and module.bias is not None:
428
+ module.bias.data.zero_()
429
+
430
+ if isinstance(module, MplugOwlVisionEmbeddings):
431
+ if hasattr(self.config, "vision_config"):
432
+ factor = self.config.vision_config.initializer_range
433
+ nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
434
+ nn.init.trunc_normal_(module.cls_token, mean=0.0, std=factor)
435
+
436
+ elif isinstance(module, nn.LayerNorm):
437
+ module.bias.data.zero_()
438
+ module.weight.data.fill_(1.0)
439
+ elif isinstance(module, nn.Linear) and module.bias is not None:
440
+ module.bias.data.zero_()
441
+ elif isinstance(module, nn.Parameter):
442
+ raise ValueError
443
+ nn.init.trunc_normal_(module.data, mean=0.0, std=factor)
444
+
445
+ def _set_gradient_checkpointing(self, module, value=False):
446
+ if isinstance(module, MplugOwlVisionEncoder):
447
+ module.gradient_checkpointing = value
448
+
449
+
450
+ MPLUG_OWL_START_DOCSTRING = r"""
451
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
452
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
453
+ etc.)
454
+
455
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
456
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
457
+ and behavior.
458
+
459
+ Parameters:
460
+ config ([`MplugOwlConfig`]): Model configuration class with all the parameters of the model.
461
+ Initializing with a config file does not load the weights associated with the model, only the
462
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
463
+ """
464
+
465
+ MPLUG_OWL_VISION_INPUTS_DOCSTRING = r"""
466
+ Args:
467
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
468
+ Pixel values. Pixel values can be obtained using [`MplugOwlProcessor`]. See [`MplugOwlProcessor.__call__`] for
469
+ details.
470
+ output_attentions (`bool`, *optional*):
471
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
472
+ tensors for more detail.
473
+ output_hidden_states (`bool`, *optional*):
474
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
475
+ more detail.
476
+ return_dict (`bool`, *optional*):
477
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
478
+ """
479
+
480
+ MPLUG_OWL_TEXT_INPUTS_DOCSTRING = r"""
481
+ Args:
482
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
483
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
484
+ it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
485
+ [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
486
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
487
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
488
+ - 1 for tokens that are **not masked**,
489
+ - 0 for tokens that are **masked**.
490
+ [What are attention masks?](../glossary#attention-mask)
491
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
492
+ Indices of decoder input sequence tokens in the vocabulary.
493
+
494
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
495
+ [`PreTrainedTokenizer.__call__`] for details.
496
+
497
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
498
+
499
+ T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
500
+ is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
501
+
502
+ To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
503
+ Training](./t5#training).
504
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
505
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
506
+ be used by default.
507
+ output_attentions (`bool`, *optional*):
508
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
509
+ tensors for more detail.
510
+ output_hidden_states (`bool`, *optional*):
511
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
512
+ more detail.
513
+ return_dict (`bool`, *optional*):
514
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
515
+ """
516
+
517
+ MPLUG_OWL_INPUTS_DOCSTRING = r"""
518
+ Args:
519
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
520
+ Pixel values. Pixel values can be obtained using [`MplugOwlProcessor`]. See [`MplugOwlProcessor.__call__`] for
521
+ details.
522
+
523
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
524
+ Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
525
+ provided to serve as text prompt, which the language model can continue.
526
+
527
+ Indices can be obtained using [`MplugOwlProcessor`]. See [`MplugOwlProcessor.__call__`] for details.
528
+
529
+ [What are input IDs?](../glossary#input-ids)
530
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
531
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
532
+
533
+ - 1 for tokens that are **not masked**,
534
+ - 0 for tokens that are **masked**.
535
+
536
+ [What are attention masks?](../glossary#attention-mask)
537
+
538
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
539
+ Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
540
+ encoder-decoder language model (like T5) is used.
541
+
542
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
543
+ [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
544
+
545
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
546
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
547
+ be used by default.
548
+
549
+ Only relevant in case an encoder-decoder language model (like T5) is used.
550
+
551
+ output_attentions (`bool`, *optional*):
552
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
553
+ tensors for more detail.
554
+ output_hidden_states (`bool`, *optional*):
555
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
556
+ more detail.
557
+ return_dict (`bool`, *optional*):
558
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
559
+ """
560
+
561
+
562
+ class MplugOwlVisionEncoder(nn.Module):
563
+ """
564
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
565
+ [`MplugOwlVisionEncoderLayer`].
566
+
567
+ Args:
568
+ config (`MplugOwlVisionConfig`):
569
+ The corresponding vision configuration for the `MplugOwlEncoder`.
570
+ """
571
+
572
+ def __init__(self, config: MplugOwlVisionConfig):
573
+ super().__init__()
574
+ self.config = config
575
+ self.layers = nn.ModuleList([MplugOwlVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
576
+ self.gradient_checkpointing = False
577
+
578
+ def forward(
579
+ self,
580
+ inputs_embeds,
581
+ attention_mask: Optional[torch.Tensor] = None,
582
+ output_attentions: Optional[bool] = None,
583
+ output_hidden_states: Optional[bool] = None,
584
+ return_dict: Optional[bool] = None,
585
+ ) -> Union[Tuple, BaseModelOutput]:
586
+ r"""
587
+ Args:
588
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
589
+ Embedded representation of the inputs. Should be float, not int tokens.
590
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
591
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
592
+
593
+ - 1 for tokens that are **not masked**,
594
+ - 0 for tokens that are **masked**.
595
+
596
+ [What are attention masks?](../glossary#attention-mask)
597
+ output_attentions (`bool`, *optional*):
598
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
599
+ returned tensors for more detail.
600
+ output_hidden_states (`bool`, *optional*):
601
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
602
+ for more detail.
603
+ return_dict (`bool`, *optional*):
604
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
605
+ """
606
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
607
+ output_hidden_states = (
608
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
609
+ )
610
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
611
+
612
+ encoder_states = () if output_hidden_states else None
613
+ all_attentions = () if output_attentions else None
614
+
615
+ hidden_states = inputs_embeds
616
+ for idx, encoder_layer in enumerate(self.layers):
617
+ if output_hidden_states:
618
+ encoder_states = encoder_states + (hidden_states,)
619
+ if self.gradient_checkpointing and self.training:
620
+
621
+ def create_custom_forward(module):
622
+ def custom_forward(*inputs):
623
+ return module(*inputs, output_attentions)
624
+
625
+ return custom_forward
626
+
627
+ layer_outputs = torch.utils.checkpoint.checkpoint(
628
+ create_custom_forward(encoder_layer),
629
+ hidden_states,
630
+ attention_mask,
631
+ )
632
+ else:
633
+ layer_outputs = encoder_layer(
634
+ hidden_states,
635
+ attention_mask,
636
+ output_attentions=output_attentions,
637
+ )
638
+
639
+ hidden_states = layer_outputs[0]
640
+
641
+ if output_attentions:
642
+ all_attentions = all_attentions + (layer_outputs[1],)
643
+
644
+ if output_hidden_states:
645
+ encoder_states = encoder_states + (hidden_states,)
646
+
647
+ if not return_dict:
648
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
649
+ return BaseModelOutput(
650
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
651
+ )
652
+
653
+
654
+ class MplugOwlVisionModel(MplugOwlPreTrainedModel):
655
+ main_input_name = "pixel_values"
656
+ config_class = MplugOwlVisionConfig
657
+
658
+ def __init__(self, config: MplugOwlVisionConfig):
659
+ super().__init__(config)
660
+ self.config = config
661
+ self.hidden_size = config.hidden_size
662
+
663
+ self.embeddings = MplugOwlVisionEmbeddings(config)
664
+ self.encoder = MplugOwlVisionEncoder(config)
665
+ self.post_layernorm = LayerNormFp32(self.hidden_size, eps=config.layer_norm_eps)
666
+
667
+ self.post_init()
668
+
669
+ @add_start_docstrings_to_model_forward(MPLUG_OWL_VISION_INPUTS_DOCSTRING)
670
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=MplugOwlVisionConfig)
671
+ def forward(
672
+ self,
673
+ pixel_values: Optional[torch.FloatTensor] = None,
674
+ output_attentions: Optional[bool] = None,
675
+ output_hidden_states: Optional[bool] = None,
676
+ return_dict: Optional[bool] = None,
677
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
678
+ r"""
679
+ Returns:
680
+
681
+ """
682
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
683
+ output_hidden_states = (
684
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
685
+ )
686
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
687
+
688
+ if pixel_values is None:
689
+ raise ValueError("You have to specify pixel_values")
690
+
691
+ hidden_states = self.embeddings(pixel_values) # [B, T, N, D]
692
+
693
+ encoder_outputs = self.encoder(
694
+ inputs_embeds=hidden_states,
695
+ output_attentions=output_attentions,
696
+ output_hidden_states=output_hidden_states,
697
+ return_dict=return_dict,
698
+ )
699
+
700
+ last_hidden_state = encoder_outputs[0]
701
+ last_hidden_state = self.post_layernorm(last_hidden_state)
702
+
703
+ pooled_output = last_hidden_state[:, :, 0, :].mean(1)
704
+ pooled_output = self.post_layernorm(pooled_output)
705
+
706
+ if not return_dict:
707
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
708
+
709
+ return BaseModelOutputWithPooling(
710
+ last_hidden_state=last_hidden_state,
711
+ pooler_output=pooled_output,
712
+ hidden_states=encoder_outputs.hidden_states,
713
+ attentions=encoder_outputs.attentions,
714
+ )
715
+
716
+ def get_input_embeddings(self):
717
+ return self.embeddings
718
+
719
+
720
+ class MplugOwlVisualAbstractorMLP(nn.Module):
721
+ def __init__(self, config: MplugOwlVisualAbstractorConfig):
722
+ super().__init__()
723
+ self.config = config
724
+ in_features = config.hidden_size
725
+ hidden_features = config.intermediate_size
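+ # SwiGLU-style feed-forward sizing (as in LLaMA): unless the width is already the special-cased 2816,
+ # it is shrunk to 2/3 of intermediate_size and rounded up to a multiple of 256.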
726
+ if hidden_features != 2816:
727
+ hidden_features = int(2 * hidden_features / 3)
728
+ multiple_of = 256
729
+ hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
730
+ self.act = nn.SiLU()
731
+
732
+ self.w1 = nn.Linear(in_features, hidden_features)
733
+ self.w2 = nn.Linear(hidden_features, in_features)
734
+ self.w3 = nn.Linear(in_features, hidden_features)
735
+ self.ffn_ln = LayerNormFp32(hidden_features, eps=config.layer_norm_eps)
736
+
737
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
738
+ hidden_states = self.act(self.w1(hidden_states)) * self.w3(hidden_states)
739
+ hidden_states = self.ffn_ln(hidden_states)
740
+ hidden_states = self.w2(hidden_states)
741
+ return hidden_states
742
+
743
+
744
+ class MplugOwlVisualAbstractorMultiHeadAttention(nn.Module):
745
+ def __init__(self, config: MplugOwlVisualAbstractorConfig):
746
+ super().__init__()
747
+ self.config = config
748
+ if config.hidden_size % config.num_attention_heads != 0:
749
+ raise ValueError(
750
+ "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
751
+ % (config.hidden_size, config.num_attention_heads)
752
+ )
753
+
754
+ self.num_attention_heads = config.num_attention_heads
755
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
756
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
757
+
758
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
759
+ self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
760
+ self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
761
+
762
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
763
+ self.save_attention = False
764
+
765
+ def save_attn_gradients(self, attn_gradients):
766
+ self.attn_gradients = attn_gradients
767
+
768
+ def get_attn_gradients(self):
769
+ return self.attn_gradients
770
+
771
+ def save_attention_map(self, attention_map):
772
+ self.attention_map = attention_map
773
+
774
+ def get_attention_map(self):
775
+ return self.attention_map
776
+
777
+ def transpose_for_scores(self, x):
778
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
779
+ x = x.view(*new_x_shape)
780
+ return x.permute(0, 2, 1, 3)
781
+
782
+ def forward(
783
+ self,
784
+ hidden_states,
785
+ attention_mask=None,
786
+ head_mask=None,
787
+ encoder_hidden_states=None,
788
+ encoder_attention_mask=None,
789
+ past_key_value=None,
790
+ output_attentions=False,
791
+ ):
792
+ # If this is instantiated as a cross-attention module, the keys
793
+ # and values come from an encoder; the attention mask needs to be
794
+ # such that the encoder's padding tokens are not attended to.
795
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
796
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
797
+ attention_mask = encoder_attention_mask
798
+
799
+ mixed_query_layer = self.query(hidden_states)
800
+
801
+ query_layer = self.transpose_for_scores(mixed_query_layer)
802
+
803
+ past_key_value = (key_layer, value_layer)
804
+
805
+ # Take the dot product between "query" and "key" to get the raw attention scores.
806
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
807
+
808
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
809
+
810
+ if attention_mask is not None:
811
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
812
+ attention_scores = attention_scores + attention_mask
813
+
814
+ # Normalize the attention scores to probabilities.
815
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
816
+
817
+ if self.save_attention:
818
+ self.save_attention_map(attention_probs)
819
+ attention_probs.register_hook(self.save_attn_gradients)
820
+
821
+ # This is actually dropping out entire tokens to attend to, which might
822
+ # seem a bit unusual, but is taken from the original Transformer paper.
823
+ attention_probs_dropped = self.dropout(attention_probs)
824
+
825
+ # Mask heads if we want to
826
+ if head_mask is not None:
827
+ attention_probs_dropped = attention_probs_dropped * head_mask
828
+
829
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
830
+
831
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
832
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
833
+ context_layer = context_layer.view(*new_context_layer_shape)
834
+
835
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
836
+
837
+ outputs = outputs + (past_key_value,)
838
+ return outputs
839
+
840
+
841
+ class MplugOwlVisualAbstractorCrossOutput(nn.Module):
842
+ def __init__(self, config: MplugOwlVisualAbstractorConfig):
843
+ super().__init__()
844
+ dim = config.hidden_size
845
+ self.out_proj = nn.Linear(dim, dim, bias=True)
846
+ self.norm2 = LayerNormFp32(dim)
847
+ self.mlp = MplugOwlVisualAbstractorMLP(config)
848
+
849
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
850
+ input_tensor = input_tensor + self.out_proj(hidden_states)
851
+ input_tensor = input_tensor + self.mlp(self.norm2(input_tensor))
852
+ return input_tensor
853
+
854
+
855
+ class MplugOwlVisualAbstractorAttention(nn.Module):
856
+ def __init__(self, config: MplugOwlVisualAbstractorConfig):
857
+ super().__init__()
858
+ self.attention = MplugOwlVisualAbstractorMultiHeadAttention(config)
859
+ self.output = MplugOwlVisualAbstractorCrossOutput(config)
860
+ self.pruned_heads = set()
861
+ self.norm1 = LayerNormFp32(config.hidden_size)
862
+ self.normk = LayerNormFp32(config.hidden_size)
863
+
864
+ def prune_heads(self, heads):
865
+ if len(heads) == 0:
866
+ return
867
+ heads, index = find_pruneable_heads_and_indices(
868
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
869
+ )
870
+
871
+ # Prune linear layers
872
+ self.attention.query = prune_linear_layer(self.attention.query, index)
873
+ self.attention.key = prune_linear_layer(self.attention.key, index)
874
+ self.attention.value = prune_linear_layer(self.attention.value, index)
875
+ self.output.out_proj = prune_linear_layer(self.output.out_proj, index, dim=1)
876
+
877
+ # Update hyper params and store pruned heads
878
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
879
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
880
+ self.pruned_heads = self.pruned_heads.union(heads)
881
+
882
+ def forward(
883
+ self,
884
+ hidden_states: torch.Tensor,
885
+ attention_mask: Optional[torch.FloatTensor] = None,
886
+ head_mask: Optional[torch.FloatTensor] = None,
887
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
888
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
889
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
890
+ output_attentions: Optional[bool] = False,
891
+ ) -> Tuple[torch.Tensor]:
892
+ # HACK we apply norm on q and k
893
+ hidden_states = self.norm1(hidden_states)
894
+ encoder_hidden_states = self.normk(encoder_hidden_states)
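+ # The learnable query tokens are prepended to the image features, so keys/values cover both the queries
+ # themselves and the encoder output; the attention mask is extended in the same way.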
895
+ encoder_hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
896
+ encoder_attention_mask = torch.cat([attention_mask, encoder_attention_mask], dim=-1)
897
+ self_outputs = self.attention(
898
+ hidden_states,
899
+ attention_mask,
900
+ head_mask,
901
+ encoder_hidden_states,
902
+ encoder_attention_mask,
903
+ past_key_value,
904
+ output_attentions,
905
+ )
906
+ attention_output = self.output(self_outputs[0], hidden_states)
907
+ # add attentions if we output them
908
+ outputs = (attention_output,) + self_outputs[1:]
909
+ return outputs
910
+
911
+
912
+ class MplugOwlVisualAbstractorLayer(nn.Module):
913
+ def __init__(self, config, layer_idx):
914
+ super().__init__()
915
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
916
+ self.seq_len_dim = 1
917
+
918
+ self.layer_idx = layer_idx
919
+
920
+ self.crossattention = MplugOwlVisualAbstractorAttention(config)
921
+ self.has_cross_attention = True
922
+
923
+ def forward(
924
+ self,
925
+ hidden_states,
926
+ attention_mask=None,
927
+ head_mask=None,
928
+ encoder_hidden_states=None,
929
+ encoder_attention_mask=None,
930
+ output_attentions=False,
931
+ ):
932
+ if encoder_hidden_states is None:
933
+ raise ValueError("encoder_hidden_states must be given for cross-attention layers")
934
+ cross_attention_outputs = self.crossattention(
935
+ hidden_states,
936
+ attention_mask,
937
+ head_mask,
938
+ encoder_hidden_states,
939
+ encoder_attention_mask,
940
+ output_attentions=output_attentions,
941
+ )
942
+ query_attention_output = cross_attention_outputs[0]
943
+
944
+ outputs = (query_attention_output,)
945
+ return outputs
946
+
947
+
948
+ class MplugOwlVisualAbstractorEncoder(nn.Module):
949
+ def __init__(self, config):
950
+ super().__init__()
951
+ self.config = config
952
+ self.layers = nn.ModuleList(
953
+ [MplugOwlVisualAbstractorLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
954
+ )
955
+ self.gradient_checkpointing = False
956
+
957
+ def forward(
958
+ self,
959
+ hidden_states,
960
+ attention_mask=None,
961
+ head_mask=None,
962
+ encoder_hidden_states=None,
963
+ encoder_attention_mask=None,
964
+ past_key_values=None,
965
+ output_attentions=False,
966
+ output_hidden_states=False,
967
+ return_dict=True,
968
+ ):
969
+ all_hidden_states = () if output_hidden_states else None
970
+
971
+ for i in range(self.config.num_hidden_layers):
972
+ layer_module = self.layers[i]
973
+ if output_hidden_states:
974
+ all_hidden_states = all_hidden_states + (hidden_states,)
975
+
976
+ layer_head_mask = head_mask[i] if head_mask is not None else None
977
+ past_key_value = past_key_values[i] if past_key_values is not None else None
978
+
979
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
980
+
981
+ def create_custom_forward(module):
982
+ def custom_forward(*inputs):
983
+ return module(*inputs, past_key_value, output_attentions)
984
+
985
+ return custom_forward
986
+
987
+ layer_outputs = torch.utils.checkpoint.checkpoint(
988
+ create_custom_forward(layer_module),
989
+ hidden_states,
990
+ attention_mask,
991
+ layer_head_mask,
992
+ encoder_hidden_states,
993
+ encoder_attention_mask,
994
+ )
995
+ else:
996
+ layer_outputs = layer_module(
997
+ hidden_states,
998
+ attention_mask,
999
+ layer_head_mask,
1000
+ encoder_hidden_states,
1001
+ encoder_attention_mask,
1002
+ output_attentions,
1003
+ )
1004
+
1005
+ hidden_states = layer_outputs[0]
1006
+
1007
+ return BaseModelOutput(
1008
+ last_hidden_state=hidden_states,
1009
+ )
1010
+
1011
+
1012
+ class MplugOwlVisualAbstractorModel(MplugOwlPreTrainedModel):
1013
+ def __init__(self, config: MplugOwlVisualAbstractorConfig, language_hidden_size):
1014
+ super().__init__(config)
1015
+ self.config = config
1016
+
1017
+ self.encoder = MplugOwlVisualAbstractorEncoder(config)
1018
+ self.visual_fc = torch.nn.Linear(config.hidden_size, language_hidden_size)
1019
+ self.temporal_visual_fc = torch.nn.Linear(config.hidden_size, language_hidden_size)
1020
+ self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size))
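+ # visual_fc and temporal_visual_fc project abstractor outputs into the language model's hidden size;
+ # vit_eos is a learnable vector in that hidden size, named for an end-of-visual-tokens marker.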
1021
+ nn.init.trunc_normal_(self.vit_eos, mean=0.0, std=self.config.initializer_range)
1022
+ self.post_init()
1023
+
1024
+ def _prune_heads(self, heads_to_prune):
1025
+ """
1026
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1027
+ class PreTrainedModel
1028
+ """
1029
+ for layer, heads in heads_to_prune.items():
1030
+ self.encoder.layer[layer].attention.prune_heads(heads)
1031
+
1032
+ def get_extended_attention_mask(
1033
+ self,
1034
+ attention_mask: torch.Tensor,
1035
+ input_shape: Tuple[int],
1036
+ device: torch.device,
1037
+ ) -> torch.Tensor:
1038
+ """
1039
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
1040
+
1041
+ Arguments:
1042
+ attention_mask (`torch.Tensor`):
1043
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
1044
+ input_shape (`Tuple[int]`):
1045
+ The shape of the input to the model.
1046
+ device: (`torch.device`):
1047
+ The device of the input to the model.
1048
+
1049
+ Returns:
1050
+ `torch.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
1051
+ """
1052
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1053
+ # ourselves in which case we just need to make it broadcastable to all heads.
1054
+ if attention_mask.dim() == 3:
1055
+ extended_attention_mask = attention_mask[:, None, :, :]
1056
+ elif attention_mask.dim() == 2:
1057
+ # Provided a padding mask of dimensions [batch_size, seq_length]
1058
+ # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
1059
+ extended_attention_mask = attention_mask[:, None, None, :]
1060
+ else:
1061
+ raise ValueError(
1062
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
1063
+ input_shape, attention_mask.shape
1064
+ )
1065
+ )
1066
+
1067
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
1068
+ # masked positions, this operation will create a tensor which is 0.0 for
1069
+ # positions we want to attend and -10000.0 for masked positions.
1070
+ # Since we are adding it to the raw scores before the softmax, this is
1071
+ # effectively the same as removing these entirely.
1072
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
1073
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
1074
+ return extended_attention_mask
1075
+
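For reference, a minimal standalone sketch (not part of the commit) of the transformation this method performs on a 2-D padding mask, with `torch.float32` standing in for `self.dtype`:

```python
import torch

# toy padding mask: batch of 2, sequence length 4; zeros mark padded positions
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# broadcastable shape [batch_size, 1, 1, seq_length], as in the 2-D branch above
extended = attention_mask[:, None, None, :].to(torch.float32)

# 0.0 where attending, -10000.0 where masked, so it can be added to raw attention scores
extended = (1.0 - extended) * -10000.0
print(extended.shape)      # torch.Size([2, 1, 1, 4])
print(extended[0, 0, 0])   # tensor([-0., -0., -0., -10000.])
```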
1076
+ def forward(
1077
+ self,
1078
+ query_embeds,
1079
+ temporal_query_embeds=None,
1080
+ attention_mask=None,
1081
+ head_mask=None,
1082
+ encoder_hidden_states=None,
1083
+ encoder_attention_mask=None,
1084
+ past_key_values=None,
1085
+ output_attentions=None,
1086
+ output_hidden_states=None,
1087
+ return_dict=None,
1088
+ ):
1089
+ r"""
1090
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
1091
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1092
+ the model is configured as a decoder.
1093
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
1094
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1095
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1096
+ - 1 for tokens that are **not masked**,
1097
+ - 0 for tokens that are **masked**.
1098
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
1099
+ shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
1100
+ value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
1101
+ used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
1102
+ value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
1103
+ `(batch_size, sequence_length)`.
1104
+ """
1105
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1106
+ output_hidden_states = (
1107
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1108
+ )
1109
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1110
+
1111
+ T = encoder_hidden_states.size(1)
1112
+ if T == 1 or temporal_query_embeds is None:
1113
+ embedding_output = query_embeds
1114
+ else:
1115
+ embedding_output = torch.cat([query_embeds, temporal_query_embeds], dim=1)
1116
+ input_shape = embedding_output.size()[:-1]
1117
+ batch_size, seq_length = input_shape
1118
+ device = embedding_output.device
1119
+
1120
+ encoder_hidden_states = einops.rearrange(
1121
+ encoder_hidden_states, 'b t n d -> b (t n) d'
1122
+ )
1123
+
1124
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1125
+ # ourselves in which case we just need to make it broadcastable to all heads.
1126
+ if attention_mask is None:
1127
+ attention_mask = torch.ones(
1128
+ (embedding_output.shape[0], embedding_output.shape[1]), dtype=torch.long, device=embedding_output.device
1129
+ )
1130
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
1131
+
1132
+ # If a 2D or 3D attention mask is provided for the cross-attention
1133
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1134
+ if encoder_hidden_states is not None:
1135
+ if type(encoder_hidden_states) == list:
1136
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
1137
+ else:
1138
+ (
1139
+ encoder_batch_size,
1140
+ encoder_sequence_length,
1141
+ _,
1142
+ ) = encoder_hidden_states.size()
1143
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1144
+
1145
+ if type(encoder_attention_mask) == list:
1146
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
1147
+ elif encoder_attention_mask is None:
1148
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1149
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1150
+ else:
1151
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1152
+ else:
1153
+ encoder_extended_attention_mask = None
1154
+
1155
+ # Prepare head mask if needed
1156
+ # 1.0 in head_mask indicate we keep the head
1157
+ # attention_probs has shape bsz x n_heads x N x N
1158
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1159
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1160
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1161
+
1162
+ encoder_outputs = self.encoder(
1163
+ embedding_output,
1164
+ attention_mask=extended_attention_mask,
1165
+ head_mask=head_mask,
1166
+ encoder_hidden_states=encoder_hidden_states,
1167
+ encoder_attention_mask=encoder_extended_attention_mask,
1168
+ past_key_values=past_key_values,
1169
+ output_attentions=output_attentions,
1170
+ output_hidden_states=output_hidden_states,
1171
+ return_dict=return_dict,
1172
+ )
1173
+ sequence_output = encoder_outputs[0]
1174
+ pooled_output = sequence_output[:, 0, :]
1175
+
1176
+ if T == 1 or temporal_query_embeds is None:
1177
+ temporal_sequence_output = None
1178
+ else:
1179
+ temporal_sequence_output = sequence_output[:, query_embeds.size(1):]
1180
+ sequence_output = sequence_output[:, :query_embeds.size(1)]
1181
+
1182
+ sequence_output = self.visual_fc(sequence_output)
1183
+ if temporal_sequence_output is not None:
1184
+ sequence_output += self.temporal_visual_fc(temporal_sequence_output)
1185
+ sequence_output = torch.cat([sequence_output, self.vit_eos.repeat(sequence_output.shape[0], 1, 1)], dim=1)
1186
+
1187
+ return BaseModelOutputWithPooling(
1188
+ last_hidden_state=sequence_output,
1189
+ pooler_output=pooled_output,
1190
+ hidden_states=encoder_outputs.hidden_states,
1191
+ )
1192
+
1193
+
1194
+ @add_start_docstrings(
1195
+ """
1196
+ mPLUG-Owl Model for generating text and image features. The model consists of a vision encoder, Querying Transformer
1197
+ (Q-Former) and a language model.
1198
+ """,
1199
+ MPLUG_OWL_START_DOCSTRING,
1200
+ )
1201
+ class MplugOwlModel(MplugOwlPreTrainedModel):
1202
+ config_class = MplugOwlConfig
1203
+ main_input_name = "pixel_values"
1204
+
1205
+ def __init__(self, config: MplugOwlConfig, *inputs, **kwargs):
1206
+ super().__init__(config, *inputs, **kwargs)
1207
+
1208
+ self.vision_model = MplugOwlVisionModel(config.vision_config)
1209
+
1210
+ self.query_tokens = nn.Parameter(
1211
+ torch.zeros(1, config.num_query_tokens, config.visual_abstractor_config.hidden_size)
1212
+ )
1213
+ self.temporal_query_tokens = nn.Parameter(
1214
+ torch.zeros(1, config.num_query_tokens, config.visual_abstractor_config.hidden_size)
1215
+ )
1216
+ self.abstractor = MplugOwlVisualAbstractorModel(
1217
+ config.visual_abstractor_config, config.text_config.hidden_size
1218
+ )
1219
+
1220
+ # if config.use_decoder_only_language_model:
1221
+ # from llama.modeling_llama import LlamaForCausalLM
1222
+ language_model = AutoModelForCausalLM.from_config(config.text_config)
1223
+ # else:
1224
+ # language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
1225
+ self.language_model = language_model
1226
+
1227
+ # Initialize weights and apply final processing
1228
+ self.post_init()
1229
+
1230
+ def get_input_embeddings(self):
1231
+ return self.language_model.get_input_embeddings()
1232
+
1233
+ def set_input_embeddings(self, value):
1234
+ self.language_model.set_input_embeddings(value)
1235
+
1236
+ def set_output_embeddings(self, new_embeddings):
1237
+ self.language_model.set_output_embeddings(new_embeddings)
1238
+
1239
+ def get_output_embeddings(self) -> nn.Module:
1240
+ return self.language_model.get_output_embeddings()
1241
+
1242
+ def get_encoder(self):
1243
+ return self.language_model.get_encoder()
1244
+
1245
+ def get_decoder(self):
1246
+ return self.language_model.get_decoder()
1247
+
1248
+ def _tie_weights(self):
1249
+ if not self.config.use_decoder_only_language_model:
1250
+ self.language_model.encoder.embed_tokens = self.language_model.shared
1251
+ self.language_model.decoder.embed_tokens = self.language_model.shared
1252
+
1253
+ def get_text_features(
1254
+ self,
1255
+ input_ids: Optional[torch.Tensor] = None,
1256
+ attention_mask: Optional[torch.Tensor] = None,
1257
+ decoder_input_ids: Optional[torch.Tensor] = None,
1258
+ decoder_attention_mask: Optional[torch.Tensor] = None,
1259
+ labels: Optional[torch.Tensor] = None,
1260
+ output_attentions: Optional[bool] = None,
1261
+ output_hidden_states: Optional[bool] = None,
1262
+ return_dict: Optional[bool] = None,
1263
+ ):
1264
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1265
+ output_hidden_states = (
1266
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1267
+ )
1268
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1269
+
1270
+ if self.config.use_decoder_only_language_model:
1271
+ text_outputs = self.language_model(
1272
+ input_ids=input_ids,
1273
+ attention_mask=attention_mask,
1274
+ output_attentions=output_attentions,
1275
+ output_hidden_states=output_hidden_states,
1276
+ return_dict=return_dict,
1277
+ )
1278
+ else:
1279
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
1280
+
1281
+ text_outputs = self.language_model(
1282
+ inputs_embeds=inputs_embeds,
1283
+ attention_mask=attention_mask,
1284
+ decoder_input_ids=decoder_input_ids,
1285
+ decoder_attention_mask=decoder_attention_mask,
1286
+ output_attentions=output_attentions,
1287
+ output_hidden_states=output_hidden_states,
1288
+ return_dict=return_dict,
1289
+ labels=labels,
1290
+ )
1291
+
1292
+ return text_outputs
1293
+
1294
+ def get_image_features(
1295
+ self,
1296
+ pixel_values: Optional[torch.FloatTensor] = None,
1297
+ output_attentions: Optional[bool] = None,
1298
+ output_hidden_states: Optional[bool] = None,
1299
+ return_dict: Optional[bool] = None,
1300
+ ):
1301
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1302
+ output_hidden_states = (
1303
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1304
+ )
1305
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1306
+
1307
+ vision_outputs = self.vision_model(
1308
+ pixel_values=pixel_values,
1309
+ output_attentions=output_attentions,
1310
+ output_hidden_states=output_hidden_states,
1311
+ return_dict=return_dict,
1312
+ )
1313
+
1314
+ return vision_outputs
1315
+
1316
+
1317
+ def get_media_indices(my_list):
1318
+ if isinstance(my_list, torch.Tensor):
1319
+ my_list = my_list.cpu().tolist()
1320
+ result = []
1321
+ for i in range(len(my_list)):
1322
+ if i == 0 and my_list[i] < 0:
1323
+ result.append(i)
1324
+ elif my_list[i] != my_list[i - 1] and my_list[i] < 0:
1325
+ result.append(i)
1326
+ return result
1327
+
1328
+ def get_media_types(tensors, positions):
1329
+ if isinstance(tensors, torch.Tensor):
1330
+ tensors = tensors.cpu().tolist()
1331
+ result = []
1332
+ for pos in positions:
1333
+ result.append(tensors[pos])
1334
+ return result
1335
+
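For reference, a small standalone illustration of these two helpers (the token ids are made up; the only convention that matters, set up in `processing_mplug_owl.py`, is that `<image>` placeholders are encoded as -1 and `<|video|>` placeholders as -2, each repeated 65 times in real prompts):

```python
# assuming these helpers are importable from this module; ids below are illustrative
token_ids = [1, -2, -2, -2, 319, 25, -1, -1, 2]   # a video run, some text, an image run

print(get_media_indices(token_ids))           # [1, 6] -> start index of each media run
print(get_media_types(token_ids, [1, 6]))     # [-2, -1] -> video marker, then image marker
```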
1336
+
1337
+ @add_start_docstrings(
1338
+ """
1339
+ mPLUG-Owl Model for generating text given an image and an optional text prompt.
1340
+ """,
1341
+ MPLUG_OWL_START_DOCSTRING,
1342
+ )
1343
+ class MplugOwlForConditionalGeneration(MplugOwlPreTrainedModel):
1344
+ config_class = MplugOwlConfig
1345
+ main_input_name = "pixel_values"
1346
+
1347
+ def __init__(self, config: MplugOwlConfig):
1348
+ super().__init__(config)
1349
+
1350
+ self.vision_model = MplugOwlVisionModel(config.vision_config)
1351
+
1352
+ self.query_tokens = nn.Parameter(
1353
+ torch.zeros(1, config.num_query_tokens, config.visual_abstractor_config.hidden_size)
1354
+ )
1355
+ self.temporal_query_tokens = nn.Parameter(
1356
+ torch.zeros(1, config.num_query_tokens, config.visual_abstractor_config.hidden_size)
1357
+ )
1358
+ self.abstractor = MplugOwlVisualAbstractorModel(
1359
+ config.visual_abstractor_config, config.text_config.hidden_size
1360
+ )
1361
+
1362
+ # if config.use_decoder_only_language_model:
1363
+ # from llama.modeling_llama import LlamaForCausalLM
1364
+ language_model = AutoModelForCausalLM.from_config(config.text_config)
1365
+ # else:
1366
+ # language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
1367
+ self.language_model = language_model
1368
+
1369
+ # Initialize weights and apply final processing
1370
+ self.post_init()
1371
+ self.main_input_name = "input_ids"
1372
+ from transformers import GenerationConfig
1373
+
1374
+ self.generation_config = GenerationConfig(
1375
+ max_length=512, do_sample=True, top_k=3, pad_token_id=0, unk_token_id=0, bos_token_id=1, eos_token_id=2
1376
+ )
1377
+
1378
+ # Hack Bloom
1379
+ if config.text_config.model_type == 'bloom':
1380
+ bound_method = bloom_forward.__get__(self.language_model.transformer, self.language_model.transformer.__class__)
1381
+ setattr(self.language_model.transformer, 'forward', bound_method)
1382
+
1383
+ def get_input_embeddings(self):
1384
+ return self.language_model.get_input_embeddings()
1385
+
1386
+ def set_input_embeddings(self, value):
1387
+ self.language_model.set_input_embeddings(value)
1388
+
1389
+ def set_output_embeddings(self, new_embeddings):
1390
+ self.language_model.set_output_embeddings(new_embeddings)
1391
+
1392
+ def get_output_embeddings(self) -> nn.Module:
1393
+ return self.language_model.get_output_embeddings()
1394
+
1395
+ def get_encoder(self):
1396
+ return self.language_model.get_encoder()
1397
+
1398
+ def get_decoder(self):
1399
+ return self.language_model.get_decoder()
1400
+
1401
+ def _tie_weights(self):
1402
+ if not self.config.use_decoder_only_language_model:
1403
+ self.language_model.encoder.embed_tokens = self.language_model.shared
1404
+ self.language_model.decoder.embed_tokens = self.language_model.shared
1405
+
1406
+ def _preprocess_accelerate(self):
1407
+ r"""
1408
+ Some pre-processing hacks to make the model `accelerate` compatible. Check
1409
+ https://github.com/huggingface/transformers/pull/21707 for more details.
1410
+ """
1411
+ hf_device_map = self.hf_device_map
1412
+
1413
+ if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
1414
+ # warn users about unexpected behavior when using multi-GPU + mPLUG-Owl + `accelerate`.
1415
+ logger.warning(
1416
+ "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
1417
+ " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
1418
+ " Please pass a `device_map` that contains `language_model` to remove this warning."
1419
+ " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
1420
+ " more details on creating a `device_map` for large models.",
1421
+ )
1422
+
1423
+ if hasattr(self.language_model, "_hf_hook"):
1424
+ self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
1425
+
1426
+ @add_start_docstrings_to_model_forward(MPLUG_OWL_INPUTS_DOCSTRING)
1427
+ @replace_return_docstrings(
1428
+ output_type=MplugOwlForConditionalGenerationModelOutput, config_class=MplugOwlVisionConfig
1429
+ )
1430
+ def forward(
1431
+ self,
1432
+ pixel_values: torch.FloatTensor,
1433
+ video_pixel_values: torch.FloatTensor,
1434
+ input_ids: torch.FloatTensor,
1435
+ num_images,
1436
+ num_videos,
1437
+ non_padding_mask: Optional[torch.LongTensor] = None,
1438
+ non_media_mask: Optional[torch.LongTensor] = None,
1439
+ prompt_mask: Optional[torch.LongTensor] = None,
1440
+ attention_mask: Optional[torch.LongTensor] = None,
1441
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1442
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1443
+ output_attentions: Optional[bool] = None,
1444
+ output_hidden_states: Optional[bool] = None,
1445
+ labels: Optional[torch.LongTensor] = None,
1446
+ return_dict: Optional[bool] = None,
1447
+ **forward_kwargs,
1448
+ ) -> Union[Tuple, MplugOwlForConditionalGenerationModelOutput]:
1449
+ r"""
1450
+ Returns:
1451
+
1452
+ Examples:
1453
+
1454
+ Image captioning (without providing a text prompt):
1455
+
1456
+ ```python
1457
+ >>> from PIL import Image
1458
+ >>> import requests
1459
+ >>> from transformers import MplugOwlProcessor, MplugOwlForConditionalGeneration
1460
+ >>> import torch
1461
+
1462
+ >>> device = "cuda" if torch.cuda.is_available() else "cpu"
1463
+
1464
+ >>> processor = MplugOwlProcessor.from_pretrained("x-plug/x_plug-llama-7b")
1465
+ >>> model = MplugOwlForConditionalGeneration.from_pretrained(
1466
+ ... "x-plug/x_plug-llama-7b", torch_dtype=torch.float16
1467
+ ... )
1468
+ >>> model.to(device) # doctest: +IGNORE_RESULT
1469
+
1470
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1471
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1472
+
1473
+ >>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
1474
+
1475
+ >>> generated_ids = model.generate(**inputs)
1476
+ >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
1477
+ >>> print(generated_text)
1478
+ two cats laying on a couch
1479
+ ```
1480
+
1481
+ Visual question answering (prompt = question):
1482
+
1483
+ ```python
1484
+ >>> from PIL import Image
1485
+ >>> import requests
1486
+ >>> from transformers import MplugOwlProcessor, MplugOwlForConditionalGeneration
1487
+ >>> import torch
1488
+
1489
+ >>> device = "cuda" if torch.cuda.is_available() else "cpu"
1490
+
1491
+ >>> processor = MplugOwlProcessor.from_pretrained("x-plug/x_plug-llama-7b")
1492
+ >>> model = MplugOwlForConditionalGeneration.from_pretrained(
1493
+ ... "x-plug/x_plug-llama-7b", torch_dtype=torch.float16
1494
+ ... )
1495
+ >>> model.to(device) # doctest: +IGNORE_RESULT
1496
+
1497
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1498
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1499
+
1500
+ >>> prompt = "Question: how many cats are there? Answer:"
1501
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
1502
+
1503
+ >>> generated_ids = model.generate(**inputs)
1504
+ >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
1505
+ >>> print(generated_text)
1506
+ two
1507
+ ```"""
1508
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1509
+
1510
+ if attention_mask is None:
1511
+ attention_mask = input_ids.new_ones(*input_ids.shape)
1512
+
1513
+ # get text embedding
1514
+ text_tokens_ = input_ids.clone()
1515
+ batch_size = input_ids.shape[0]
1516
+
1517
+ media_token_indices = [
1518
+ # [:-1] since we would not use the last token for embedding
1519
+ get_media_indices(text_tokens_[i][:-1])
1520
+ for i in range(batch_size)
1521
+ ]
1522
+
1523
+ media_token_types = [
1524
+ get_media_types(text_tokens_[i][:-1], media_token_indices[i])
1525
+ for i in range(batch_size)
1526
+ ]
1527
+
1528
+ text_tokens_[text_tokens_ < 0] = 1 # Not used
1529
+ inputs_embeds = self.get_input_embeddings()(text_tokens_) # Temporary text embeddings; media positions are replaced below
1530
+
1531
+ if hasattr(self.language_model, 'transformer') and hasattr(self.language_model.transformer, 'word_embeddings_layernorm'):
1532
+ inputs_embeds = self.language_model.transformer.word_embeddings_layernorm(inputs_embeds)
1533
+
1534
+ if pixel_values is not None:
1535
+ image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
1536
+
1537
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
1538
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
1539
+ temporal_query_tokens = self.temporal_query_tokens.expand(image_embeds.shape[0], -1, -1)
1540
+
1541
+ query_features = self.abstractor(
1542
+ query_embeds=query_tokens,
1543
+ encoder_hidden_states=image_embeds,
1544
+ encoder_attention_mask=image_attention_mask,
1545
+ )["last_hidden_state"]
1546
+ img_seq_length = query_features.shape[1]
1547
+
1548
+ if video_pixel_values is not None:
1549
+ video_embeds = self.vision_model(video_pixel_values, return_dict=True).last_hidden_state
1550
+
1551
+ video_attention_mask = torch.ones(video_embeds.size()[:-1], dtype=torch.long, device=video_embeds.device)
1552
+ video_attention_mask = einops.rearrange(
1553
+ video_attention_mask, 'b t n -> b (t n)'
1554
+ )
1555
+ query_tokens = self.query_tokens.expand(video_embeds.shape[0], -1, -1)
1556
+ temporal_query_tokens = self.temporal_query_tokens.expand(video_embeds.shape[0], -1, -1)
1557
+
1558
+ video_query_features = self.abstractor(
1559
+ query_embeds=query_tokens,
1560
+ temporal_query_embeds=temporal_query_tokens,
1561
+ encoder_hidden_states=video_embeds,
1562
+ encoder_attention_mask=video_attention_mask,
1563
+ )["last_hidden_state"]
1564
+ video_embeds = video_query_features
1565
+ vid_seq_length = video_query_features.shape[1]
1566
+
1567
+ num_images_per_sample = num_images.long().cpu().tolist()
1568
+ num_videos_per_sample = num_videos.long().cpu().tolist()
1569
+
1570
+ text_chunk_embeds = []
1571
+ text_chunk_attns = []
1572
+ img_idx = 0
1573
+ vid_idx = 0
1574
+ for b in range(batch_size):
1575
+ start = 0
1576
+ result = []
1577
+ result_attn = []
1578
+ curr_image_idx, curr_video_idx = 0, 0 # per-sample media counters (reset once per sample, not per media token)
1579
+ for i, pos in enumerate(media_token_indices[b]):
1580
+ if pos > start:
1581
+ result.append(inputs_embeds[b, start:pos])
1582
+ result_attn.append(attention_mask[b, start:pos])
1583
+ if media_token_types[b][i] == -1:
1584
+ result.append(image_embeds[img_idx + curr_image_idx])
1585
+ result_attn.append(torch.ones(image_embeds[img_idx + curr_image_idx].shape[0], device=inputs_embeds.device))
1586
+ start = pos + img_seq_length
1587
+ curr_image_idx += 1
1588
+ else:
1589
+ result.append(video_embeds[vid_idx + curr_video_idx])
1590
+ result_attn.append(torch.ones(video_embeds[vid_idx + curr_video_idx].shape[0], device=inputs_embeds.device))
1591
+ start = pos + vid_seq_length
1592
+ curr_video_idx += 1
1593
+ if start < inputs_embeds.shape[1]:
1594
+ result.append(inputs_embeds[b, start:])
1595
+ result_attn.append(attention_mask[b, start:])
1596
+
1597
+ img_idx += num_images_per_sample[b]
1598
+ vid_idx += num_videos_per_sample[b]
1599
+ text_chunk_embeds.append(torch.cat(result, dim=0))
1600
+ text_chunk_attns.append(torch.cat(result_attn, dim=0))
1601
+
1602
+ inputs_embeds = torch.stack(text_chunk_embeds, dim=0)
1603
+ attention_mask = torch.stack(text_chunk_attns, dim=0)
1604
+
1605
+ if labels is not None:
1606
+ # Create causal mask and position ids
1607
+ _, loss_mask, position_ids = get_ltor_masks_and_position_ids_from_embeddings(inputs_embeds)
1608
+
1609
+ # Calculate the loss_mask
1610
+ non_padding_mask = non_padding_mask.long()
1611
+ non_media_mask = non_media_mask.long()
1612
+ prompt_mask = prompt_mask.long() # TODO How to deal with prompt mask
1613
+ loss_mask = loss_mask[:, :-1]
1614
+
1615
+ loss_mask = loss_mask * non_padding_mask * non_media_mask * prompt_mask
1616
+ labels[:, 1:][loss_mask != 1] = -100
1617
+
1618
+ # Forward into GPT
1619
+ outputs = self.language_model(
1620
+ inputs_embeds=inputs_embeds,
1621
+ attention_mask=attention_mask,
1622
+ labels=labels,
1623
+ return_dict=return_dict,
1624
+ output_attentions=self.config.output_attentions,
1625
+ )
1626
+
1627
+ return outputs
1628
+
1629
+ @torch.no_grad()
1630
+ def generate(
1631
+ self,
1632
+ pixel_values: torch.FloatTensor = None,
1633
+ video_pixel_values: torch.FloatTensor = None,
1634
+ input_ids: Optional[torch.LongTensor] = None,
1635
+ attention_mask: Optional[torch.LongTensor] = None,
1636
+ isdecoder=True,
1637
+ get_logits_only=False,
1638
+ **generate_kwargs,
1639
+ ) -> torch.LongTensor:
1640
+ """
1641
+ Overrides `generate` function to be able to use the model as a conditional generator.
1642
+
1643
+ Args:
1644
+ pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
1645
+ Input images to be processed.
1646
+ input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
1647
+ The sequence used as a prompt for the generation.
1648
+ attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
1649
+ Mask to avoid performing attention on padding token indices
1650
+
1651
+ Returns:
1652
+ outputs: Generated token ids (`torch.LongTensor`), or the language-model output when `get_logits_only=True`; decode the ids with the tokenizer to obtain captions.
1653
+ """
1654
+ if input_ids is None:
1655
+ return self.language_model.generate(attention_mask=attention_mask, **generate_kwargs)
1656
+
1657
+ if attention_mask is None:
1658
+ attention_mask = input_ids.new_ones(*input_ids.shape)
1659
+
1660
+ batch_size = input_ids.size(0)
1661
+ media_token_indices = [get_media_indices(input_ids[i]) for i in range(batch_size)]
1662
+ media_token_types = [
1663
+ get_media_types(input_ids[i], media_token_indices[i])
1664
+ for i in range(batch_size)
1665
+ ]
1666
+ num_images_per_sample = [len([y for y in x if y==-1]) for x in media_token_types]
1667
+ num_videos_per_sample = [len([y for y in x if y<-1]) for x in media_token_types]
1668
+ input_ids = input_ids.clone() # prevent inplace modify
1669
+ input_ids[input_ids < 0] = 0 # Not used
1670
+
1671
+ if hasattr(self, "hf_device_map"):
1672
+ # preprocess for `accelerate`
1673
+ self._preprocess_accelerate()
1674
+
1675
+ batch_size = input_ids.shape[0]
1676
+ # get text embedding
1677
+ inputs_embeds = self.get_input_embeddings()(input_ids)
1678
+ if hasattr(self.language_model, 'transformer') and hasattr(self.language_model.transformer, 'word_embeddings_layernorm'):
1679
+ inputs_embeds = self.language_model.transformer.word_embeddings_layernorm(inputs_embeds)
1680
+ # get visual embedding
1681
+ if pixel_values is not None:
1682
+ pixel_values = pixel_values.to(input_ids.device)
1683
+ with torch.no_grad():
1684
+ image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
1685
+ image_attention_mask = torch.ones(
1686
+ image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device
1687
+ )
1688
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
1689
+ query_outputs = self.abstractor(
1690
+ query_embeds=query_tokens,
1691
+ encoder_hidden_states=image_embeds,
1692
+ encoder_attention_mask=image_attention_mask,
1693
+ return_dict=True,
1694
+ )
1695
+ query_output = query_outputs["last_hidden_state"]
1696
+ image_embeds = query_output
1697
+ img_seq_length = image_embeds.shape[1]
1698
+
1699
+ if video_pixel_values is not None:
1700
+ video_pixel_values = video_pixel_values.to(input_ids.device)
1701
+ with torch.no_grad():
1702
+ video_embeds = self.vision_model(video_pixel_values, return_dict=True).last_hidden_state
1703
+ video_attention_mask = torch.ones(
1704
+ video_embeds.size()[:-1], dtype=torch.long, device=video_embeds.device
1705
+ )
1706
+ video_attention_mask = einops.rearrange(
1707
+ video_attention_mask, 'b t n -> b (t n)'
1708
+ )
1709
+ query_tokens = self.query_tokens.expand(video_embeds.shape[0], -1, -1)
1710
+ temporal_query_tokens = self.temporal_query_tokens.expand(video_embeds.shape[0], -1, -1)
1711
+ query_outputs = self.abstractor(
1712
+ query_embeds=query_tokens,
1713
+ temporal_query_embeds=temporal_query_tokens,
1714
+ encoder_hidden_states=video_embeds,
1715
+ encoder_attention_mask=video_attention_mask,
1716
+ return_dict=True,
1717
+ )
1718
+ query_output = query_outputs["last_hidden_state"]
1719
+ video_embeds = query_output
1720
+ vid_seq_length = video_embeds.shape[1]
1721
+
1722
+ # ===================
1723
+ # Get actual input embeddings
1724
+ # ===================
1725
+ text_chunk_embeds = []
1726
+ text_chunk_attns = []
1727
+ img_idx = 0
1728
+ vid_idx = 0
1729
+
1730
+ for b in range(batch_size):
1731
+ start = 0
1732
+ result = []
1733
+ result_attn = []
1734
+ curr_image_idx, curr_video_idx = 0, 0 # per-sample media counters (reset once per sample, not per media token)
1735
+ for i, pos in enumerate(media_token_indices[b]):
1736
+ if pos > start:
1737
+ result.append(inputs_embeds[b, start:pos])
1738
+ result_attn.append(attention_mask[b, start:pos])
1739
+ if media_token_types[b][i] == -1:
1740
+ result.append(image_embeds[img_idx + curr_image_idx])
1741
+ result_attn.append(torch.ones(image_embeds[img_idx + curr_image_idx].shape[0], device=inputs_embeds.device))
1742
+ start = pos + img_seq_length
1743
+ curr_image_idx += 1
1744
+ else:
1745
+ result.append(video_embeds[vid_idx + curr_video_idx])
1746
+ result_attn.append(torch.ones(video_embeds[vid_idx + curr_video_idx].shape[0], device=inputs_embeds.device))
1747
+ start = pos + vid_seq_length
1748
+ curr_video_idx += 1
1749
+ if start < inputs_embeds.shape[1]:
1750
+ result.append(inputs_embeds[b, start:])
1751
+ result_attn.append(attention_mask[b, start:])
1752
+
1753
+ img_idx += num_images_per_sample[b]
1754
+ vid_idx += num_videos_per_sample[b]
1755
+ text_chunk_embeds.append(torch.cat(result, dim=0))
1756
+ text_chunk_attns.append(torch.cat(result_attn, dim=0))
1757
+ inputs_embeds = torch.stack(text_chunk_embeds, dim=0)
1758
+ attention_mask = torch.stack(text_chunk_attns, dim=0)
1759
+
1760
+ if get_logits_only:
1761
+ outputs = self.language_model(
1762
+ inputs_embeds=inputs_embeds,
1763
+ attention_mask=attention_mask,
1764
+ return_dict=True,
1765
+ output_attentions=self.config.output_attentions,
1766
+ )
1767
+ else:
1768
+ outputs = self.language_model.generate(
1769
+ inputs_embeds=inputs_embeds,
1770
+ attention_mask=attention_mask,
1771
+ **generate_kwargs,
1772
+ )
1773
+
1774
+ return outputs
1775
+
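For reference, a hedged sketch of end-to-end video generation through `MplugOwlProcessor` and `generate` (checkpoint name and video path are placeholders; the flow mirrors `nle_inference.py` in this commit):

```python
import torch
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor

ckpt = "MAGAer13/mplug-owl-llama-7b-pt"  # placeholder checkpoint
model = MplugOwlForConditionalGeneration.from_pretrained(ckpt, torch_dtype=torch.bfloat16, device_map={"": 0})
processor = MplugOwlProcessor(MplugOwlImageProcessor.from_pretrained(ckpt), LlamaTokenizer.from_pretrained(ckpt))

prompt = ["Human: <|video|>\nHuman: Describe this video.\nAI: "]
inputs = processor(text=prompt, videos=["example.mp4"], num_frames=4, return_tensors="pt")  # placeholder path
inputs = {k: (v.bfloat16() if v.dtype == torch.float else v) for k, v in inputs.items()}
inputs = {k: v.to(model.device) for k, v in inputs.items()}

with torch.no_grad():
    generated = model.generate(**inputs, do_sample=True, top_k=5, max_length=512)
print(processor.tokenizer.decode(generated[0].tolist(), skip_special_tokens=True))
```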
1776
+ def prepare_inputs_for_generation(
1777
+ self, input_ids, pixel_values=None, video_pixel_values=None,
1778
+ past_key_values=None, attention_mask=None, **model_kwargs
1779
+ ):
1780
+ input_shape = input_ids.shape
1781
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1782
+ if attention_mask is None:
1783
+ attention_mask = input_ids.new_ones(input_shape)
1784
+
1785
+ # # cut decoder_input_ids if past_key_values is used
1786
+ # if past_key_values is not None:
1787
+ # input_ids = input_ids[:, -1:]
1788
+
1789
+ return {
1790
+ "input_ids": input_ids,
1791
+ "pixel_values": pixel_values,
1792
+ "video_pixel_values": video_pixel_values,
1793
+ "attention_mask": attention_mask,
1794
+ # "past_key_values": past_key_values,
1795
+ # "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1796
+ # "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1797
+ "is_decoder": True,
1798
+ }
1799
+
1800
+
1801
+ def bloom_forward(
1802
+ self,
1803
+ input_ids: Optional[torch.LongTensor] = None,
1804
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1805
+ attention_mask: Optional[torch.Tensor] = None,
1806
+ head_mask: Optional[torch.LongTensor] = None,
1807
+ inputs_embeds: Optional[torch.LongTensor] = None,
1808
+ use_cache: Optional[bool] = None,
1809
+ output_attentions: Optional[bool] = None,
1810
+ output_hidden_states: Optional[bool] = None,
1811
+ return_dict: Optional[bool] = None,
1812
+ **deprecated_arguments,
1813
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
1814
+ if deprecated_arguments.pop("position_ids", False) is not False:
1815
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
1816
+ warnings.warn(
1817
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
1818
+ " passing `position_ids`.",
1819
+ FutureWarning,
1820
+ )
1821
+ if len(deprecated_arguments) > 0:
1822
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
1823
+
1824
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1825
+ output_hidden_states = (
1826
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1827
+ )
1828
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1829
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1830
+
1831
+ if input_ids is not None and inputs_embeds is not None:
1832
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1833
+ elif input_ids is not None:
1834
+ batch_size, seq_length = input_ids.shape
1835
+ elif inputs_embeds is not None:
1836
+ batch_size, seq_length, _ = inputs_embeds.shape
1837
+ else:
1838
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1839
+
1840
+ if past_key_values is None:
1841
+ past_key_values = tuple([None] * len(self.h))
1842
+
1843
+ # Prepare head mask if needed
1844
+ # 1.0 in head_mask indicate we keep the head
1845
+ # attention_probs has shape batch_size x num_heads x N x N
1846
+ # head_mask has shape n_layer x batch x num_heads x N x N
1847
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
1848
+
1849
+ if inputs_embeds is None:
1850
+ inputs_embeds = self.word_embeddings(input_ids)
1851
+ inputs_embeds = self.word_embeddings_layernorm(inputs_embeds)
1852
+
1853
+ hidden_states = inputs_embeds
1854
+
1855
+ presents = () if use_cache else None
1856
+ all_self_attentions = () if output_attentions else None
1857
+ all_hidden_states = () if output_hidden_states else None
1858
+
1859
+ if self.gradient_checkpointing and self.training:
1860
+ if use_cache:
1861
+ logger.warning_once(
1862
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1863
+ )
1864
+ use_cache = False
1865
+
1866
+ # Compute alibi tensor: check build_alibi_tensor documentation
1867
+ seq_length_with_past = seq_length
1868
+ past_key_values_length = 0
1869
+ if past_key_values[0] is not None:
1870
+ past_key_values_length = past_key_values[0][0].shape[2]
1871
+ seq_length_with_past = seq_length_with_past + past_key_values_length
1872
+ if attention_mask is None:
1873
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
1874
+ else:
1875
+ attention_mask = attention_mask.to(hidden_states.device)
1876
+
1877
+ alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
1878
+
1879
+ causal_mask = self._prepare_attn_mask(
1880
+ attention_mask,
1881
+ input_shape=(batch_size, seq_length),
1882
+ past_key_values_length=past_key_values_length,
1883
+ )
1884
+
1885
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
1886
+ if output_hidden_states:
1887
+ all_hidden_states = all_hidden_states + (hidden_states,)
1888
+
1889
+ if self.gradient_checkpointing and self.training:
1890
+
1891
+ def create_custom_forward(module):
1892
+ def custom_forward(*inputs):
1893
+ # None for past_key_value
1894
+ return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
1895
+
1896
+ return custom_forward
1897
+
1898
+ outputs = torch.utils.checkpoint.checkpoint(
1899
+ create_custom_forward(block),
1900
+ hidden_states,
1901
+ alibi,
1902
+ causal_mask,
1903
+ layer_past,
1904
+ head_mask[i],
1905
+ )
1906
+ else:
1907
+ outputs = block(
1908
+ hidden_states,
1909
+ layer_past=layer_past,
1910
+ attention_mask=causal_mask,
1911
+ head_mask=head_mask[i],
1912
+ use_cache=use_cache,
1913
+ output_attentions=output_attentions,
1914
+ alibi=alibi,
1915
+ )
1916
+
1917
+ hidden_states = outputs[0]
1918
+ if use_cache is True:
1919
+ presents = presents + (outputs[1],)
1920
+
1921
+ if output_attentions:
1922
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
1923
+
1924
+ # Add last hidden state
1925
+ hidden_states = self.ln_f(hidden_states)
1926
+
1927
+ if output_hidden_states:
1928
+ all_hidden_states = all_hidden_states + (hidden_states,)
1929
+
1930
+ if not return_dict:
1931
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
1932
+
1933
+ return BaseModelOutputWithPastAndCrossAttentions(
1934
+ last_hidden_state=hidden_states,
1935
+ past_key_values=presents,
1936
+ hidden_states=all_hidden_states,
1937
+ attentions=all_self_attentions,
1938
+ )
pipeline_video/mplug_owl_video/processing_mplug_owl.py ADDED
@@ -0,0 +1,246 @@
1
+ import re
2
+ import random
3
+ import torch
4
+ import torch.utils.checkpoint
5
+
6
+ from transformers.processing_utils import ProcessorMixin
7
+ from transformers.tokenization_utils_base import BatchEncoding
8
+ from transformers.models.clip.image_processing_clip import CLIPImageProcessor
9
+ from .tokenization_mplug_owl import MplugOwlTokenizer
10
+
11
+ from decord import VideoReader
12
+ import numpy as np
13
+ from PIL import Image
14
+
15
+ def get_index(num_frames, num_segments):
16
+ seg_size = float(num_frames - 1) / num_segments
17
+ start = int(seg_size / 2)
18
+ offsets = np.array([
19
+ start + int(np.round(seg_size * idx)) for idx in range(num_segments)
20
+ ])
21
+ return offsets
22
+
23
+ def load_video(path, num_frames=4):
24
+ vr = VideoReader(path, height=224, width=224)
25
+ total_frames = len(vr)
26
+ frame_indices = get_index(total_frames, num_frames)
27
+ images_group = list()
28
+ for frame_index in frame_indices:
29
+ img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
30
+ images_group.append(img)
31
+ return images_group
32
+
33
+ class MplugOwlProcessor(ProcessorMixin):
34
+ attributes = []
35
+ tokenizer_class = ("MplugOwlTokenizer")
36
+
37
+ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
38
+
39
+ super().__init__(**kwargs)
40
+ self.tokens_to_generate = 0
41
+ self.image_processor = image_processor
42
+ self.tokenizer = tokenizer
43
+ self.add_BOS = True
44
+
45
+ def __call__(self, videos=None, text=None, num_frames=4, return_tensors=None, **kwargs):
46
+
47
+ if text is not None:
48
+ encoding = tokenize_prompts(
49
+ prompts=text,
50
+ tokens_to_generate=self.tokens_to_generate,
51
+ add_BOS=self.add_BOS,
52
+ tokenizer=self.tokenizer,
53
+ ignore_dist=True,
54
+ **kwargs,
55
+ )
56
+
57
+ if videos is not None:
58
+ video_features = []
59
+ for video in videos:
60
+ video_frames = load_video(video, num_frames)
61
+ video_feature = self.image_processor(video_frames, return_tensors=return_tensors, **kwargs)['pixel_values']
62
+ video_features.append(video_feature)
63
+ video_features = torch.stack(video_features, dim=0)
64
+ video_features = video_features.permute(0, 2, 1, 3, 4)
65
+
66
+ if text is not None and videos is not None:
67
+ encoding["video_pixel_values"] = video_features
68
+ return encoding
69
+ if text is not None and videos is None:
70
+ return encoding
71
+
72
+ return video_features
73
+
74
+ def batch_decode(self, skip_special_tokens=True, *args, **kwargs):
75
+ """
76
+ This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
77
+ refer to the docstring of this method for more information.
78
+ """
79
+ return self.tokenizer.batch_decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)
80
+
81
+ def decode(self, skip_special_tokens=True, *args, **kwargs):
82
+ """
83
+ This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
84
+ the docstring of this method for more information.
85
+ """
86
+ return self.tokenizer.decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)
87
+
88
+
89
+ class MplugOwlImageProcessor(CLIPImageProcessor):
90
+ pass
91
+
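For reference, a hedged note on what `MplugOwlProcessor.__call__` returns on the video path (checkpoint name and video path are placeholders): each video is sampled to `num_frames` frames, processed frame by frame, then stacked and permuted so that `video_pixel_values` has shape (batch, channels, frames, height, width):

```python
from transformers.models.llama.tokenization_llama import LlamaTokenizer

ckpt = "MAGAer13/mplug-owl-llama-7b-pt"  # placeholder
processor = MplugOwlProcessor(MplugOwlImageProcessor.from_pretrained(ckpt), LlamaTokenizer.from_pretrained(ckpt))

inputs = processor(text=["Human: <|video|>\nAI: "], videos=["example.mp4"], num_frames=4, return_tensors="pt")
print(inputs["input_ids"].shape)           # (1, prompt_len), including 65 video placeholder ids of -2
print(inputs["video_pixel_values"].shape)  # (1, 3, 4, 224, 224) -> batch, channels, frames, H, W
```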
92
+
93
+ def detokenize_generations(tokens_gpu_tensor, lengths_gpu_tensor, return_segments, tokenizer):
94
+ """Detokenize the generated tokens."""
95
+
96
+ prompts_plus_generations = []
97
+ if return_segments:
98
+ prompts_plus_generations_segments = []
99
+
100
+ tokens = tokens_gpu_tensor.cpu().numpy().tolist()
101
+ lengths = lengths_gpu_tensor.cpu().numpy().tolist()
102
+ for sequence_tokens, length in zip(tokens, lengths):
103
+ sequence_tokens = sequence_tokens[:length]
104
+ prompts_plus_generations.append(tokenizer.detokenize(sequence_tokens))
105
+ if return_segments:
106
+ from tokenizers.decoders import Metaspace
107
+
108
+ if hasattr(tokenizer, "tokenizer"):
109
+ if isinstance(tokenizer.tokenizer.decoder, Metaspace):
110
+ words = tokenizer.tokenizer.decode(sequence_tokens)
111
+ else:
112
+ words = []
113
+ for token in sequence_tokens:
114
+ word = tokenizer.tokenizer.decoder[token]
115
+ word = bytearray([tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
116
+ "utf-8", errors="replace"
117
+ )
118
+ words.append(word)
119
+ prompts_plus_generations_segments.append(words)
120
+ else:
121
+ words = tokenizer.detokenize(sequence_tokens)
122
+ # else:
123
+ # words = []
124
+ # for token in sequence_tokens:
125
+ # word = tokenizer.tokenizer.decoder[token]
126
+ # word = bytearray(
127
+ # [tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
128
+ # 'utf-8', errors='replace')
129
+ # words.append(word)
130
+ prompts_plus_generations_segments.append(words)
131
+
132
+ if return_segments:
133
+ return tokens, prompts_plus_generations, prompts_plus_generations_segments
134
+
135
+ return tokens, prompts_plus_generations
136
+
137
+
138
+ def tokenize_prompts(
139
+ prompts=None, tokens_to_generate=None, add_BOS=None, rank=0, tokenizer=None, ignore_dist=False, **kwargs
140
+ ):
141
+ """Tokenize prompts and make them avaiable on all ranks."""
142
+
143
+ # On all ranks set to None so we can pass them to functions
144
+ prompts_tokens_cuda_long_tensor = None
145
+ prompts_length_cuda_long_tensor = None
146
+
147
+ # On the specified rank, build the above.
148
+ attention_mask = None
149
+ if ignore_dist or torch.distributed.get_rank() == rank:
150
+ assert prompts is not None
151
+ assert tokens_to_generate is not None
152
+ # Tensor of tokens padded and their unpadded length.
153
+ prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor, attention_mask = _tokenize_prompts_and_batch(
154
+ prompts, tokens_to_generate, add_BOS, tokenizer, **kwargs
155
+ )
156
+ # We need the sizes of these tensors for the broadcast
157
+ [
158
+ prompts_tokens_cuda_long_tensor.size(0), # Batch size
159
+ prompts_tokens_cuda_long_tensor.size(1),
160
+ ] # Sequence length
161
+
162
+ return {
163
+ "input_ids": prompts_tokens_cuda_long_tensor,
164
+ "attention_mask": attention_mask,
165
+ # "prompt_length": prompts_length_cuda_long_tensor,
166
+ }
167
+
168
+
169
+ def _tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS, tokenizer, **kwargs):
170
+ """Given a set of prompts and number of tokens to generate:
171
+ - tokenize prompts
172
+ - set the sequence length to be the max of length of prompts
173
+ plus the number of tokens we would like to generate
174
+ - pad all the sequences to this length so we can convert them
175
+ into a 2D tensor.
176
+ """
177
+
178
+ # Tokenize all the prompts.
179
+ # if add_BOS:
180
+ # prompts_tokens = [[tokenizer.bos] + tokenizer.tokenize(prompt)
181
+ # for prompt in prompts]
182
+ # else:
183
+ # prompts_tokens = [tokenizer.tokenize(prompt) for prompt in prompts]
184
+
185
+ prompts_tokens = [_tokenize_prompt(prompt, tokenizer, add_BOS, **kwargs) for prompt in prompts]
186
+
187
+ # Now we have a list of list of tokens which each list has a different
188
+ # size. We want to extend this list to:
189
+ # - incorporate the tokens that need to be generated
190
+ # - make all the sequences equal length.
191
+ # Get the prompts length.
192
+ prompts_length = [len(prompt_tokens) for prompt_tokens in prompts_tokens]
193
+ # Get the max prompts length.
194
+ max_prompt_len = max(prompts_length)
195
+ # Number of tokens in the each sample of the batch.
196
+ samples_length = max_prompt_len + tokens_to_generate
197
+ # Now update the list of list to be of the same size: samples_length.
198
+ for prompt_tokens, prompt_length in zip(prompts_tokens, prompts_length):
199
+ padding_size = samples_length - prompt_length
200
+ prompt_tokens.extend([tokenizer.eos_token_id] * padding_size)
201
+
202
+ # Now we are in a structured format, we can convert to tensors.
203
+ prompts_tokens_tensor = torch.LongTensor(prompts_tokens)
204
+ prompts_length_tensor = torch.LongTensor(prompts_length)
205
+ attention_mask = torch.zeros(prompts_tokens_tensor.shape[:2])
206
+ for i, l in enumerate(prompts_length_tensor):
207
+ attention_mask[i, :l] = 1
208
+ return prompts_tokens_tensor, prompts_length_tensor, attention_mask
209
+
210
+
211
+ def _tokenize_prompt(
212
+ prompt, tokenizer, add_BOS=False,
213
+ media_info={"<image>": 65, "<|video|>": 65},
214
+ **kwargs
215
+ ):
216
+ media_tokens = {k: -int(i + 1) for i, k in enumerate(media_info.keys())}
217
+ media_lengths = media_info.copy()
218
+
219
+ if add_BOS:
220
+ prompt_chunk = [tokenizer.bos_token_id]
221
+ else:
222
+ prompt_chunk = []
223
+
224
+ # Pure Text
225
+ if all([media_token not in prompt for media_token in media_tokens.keys()]):
226
+ enc_chunk = prompt_chunk + tokenizer(prompt, add_special_tokens=False, **kwargs)["input_ids"]
227
+
228
+ # Multi-Modal Text
229
+ else:
230
+ enc_chunk = prompt_chunk
231
+ pattern = "|".join(map(re.escape, list(media_tokens.keys())))
232
+ chunk_strs = re.split(f"({pattern})", prompt)
233
+ chunk_strs = [x for x in chunk_strs if len(x) > 0]
234
+ for idx, chunk_str in enumerate(chunk_strs):
235
+ if chunk_str in media_tokens:
236
+ enc_chunk += [media_tokens[chunk_str]] * media_lengths[chunk_str]
237
+ else:
238
+ tmp_chunk = tokenizer(chunk_str, add_special_tokens=False)["input_ids"]
239
+ # if idx < len(chunk_strs) - 1: # Last chunk should not have eos
240
+ # tmp_chunk += [tokenizer.eod_id]
241
+ enc_chunk += tmp_chunk
242
+ return enc_chunk
243
+
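For reference, a standalone sketch of the encoding convention above (the stand-in tokenizer is purely illustrative; only the negative placeholder ids reflect the real convention):

```python
# Minimal stand-in tokenizer so the sketch runs without loading real weights (illustrative only).
class DummyTokenizer:
    bos_token_id = 1
    def __call__(self, text, add_special_tokens=False, **kwargs):
        return {"input_ids": [100 + (ord(c) % 50) for c in text]}  # fake positive ids

ids = _tokenize_prompt("Human: <|video|>\nHuman: Describe it.\nAI: ", DummyTokenizer(), add_BOS=True)
print(ids[0])         # 1  -> BOS
print(ids.count(-2))  # 65 -> "<|video|>" expands to 65 placeholder ids of -2
print(ids.count(-1))  # 0  -> no "<image>" tokens in this prompt
```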
244
+
245
+ if __name__ == "__main__":
246
+ pass
pipeline_video/mplug_owl_video/tokenization_mplug_owl.py ADDED
@@ -0,0 +1,62 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 x-plug and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for MplugOwl."""
16
+
17
+ from transformers.utils import logging
18
+ from transformers.models.llama.tokenization_llama import LlamaTokenizer
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
24
+
25
+ PRETRAINED_VOCAB_FILES_MAP = {
26
+ "vocab_file": {
27
+ "MAGAer13/mplug-owl-llama-7b": "https://huggingface.co/MAGAer13/mplug-owl-llama-7b/resolve/main/vocab.txt",
28
+ },
29
+ }
30
+
31
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
32
+ "MAGAer13/mplug-owl-llama-7b": 2048,
33
+ }
34
+
35
+
36
+ class MplugOwlTokenizer(LlamaTokenizer):
37
+ def __init__(
38
+ self,
39
+ vocab_file,
40
+ unk_token="<unk>",
41
+ bos_token="<s>",
42
+ eos_token="</s>",
43
+ pad_token="<unk>",
44
+ sp_model_kwargs=None,
45
+ add_bos_token=False,
46
+ add_eos_token=False,
47
+ clean_up_tokenization_spaces=False,
48
+ **kwargs,
49
+ ):
50
+ super().__init__(
51
+ vocab_file,
52
+ unk_token,
53
+ bos_token,
54
+ eos_token,
55
+ pad_token,
56
+ sp_model_kwargs,
57
+ add_bos_token,
58
+ add_eos_token,
59
+ clean_up_tokenization_spaces,
60
+ **kwargs,
61
+ )
62
+ self.eod_id = self.eos_token_id
pipeline_video/nle_inference.py ADDED
@@ -0,0 +1,126 @@
1
+ import os
2
+ import csv
3
+ import json
4
+ import torch
5
+ import argparse
6
+ import pandas as pd
7
+ from tqdm import tqdm
8
+ from peft import LoraConfig, get_peft_model
9
+ from torch.utils.data import Dataset, DataLoader
10
+ from transformers.models.llama.tokenization_llama import LlamaTokenizer
11
+ from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration
12
+ from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
13
+
14
+ parser = argparse.ArgumentParser()
15
+
16
+ parser.add_argument('--input_file', type = str, required = True, help = 'input csv file')
17
+ parser.add_argument('--output_file', type = str, help = 'output csv file')
18
+ parser.add_argument('--pretrained_ckpt', type = str, required = True, help = 'pretrained ckpt')
19
+ parser.add_argument('--trained_ckpt', type = str, help = 'trained ckpt')
20
+ parser.add_argument('--lora_r', type = int, default = 32)
21
+ parser.add_argument('--use_lora', action = 'store_true', help = 'lora model')
22
+ parser.add_argument('--all_params', action = 'store_true', help = 'all params')
23
+ parser.add_argument('--batch_size', type = int, default = 1)
24
+ parser.add_argument('--num_frames', type = int, default = 32)
25
+
26
+ args = parser.parse_args()
27
+
28
+ PROMPT_FEEDBACK = '''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
29
+ Human: <|video|>
30
+ Human: What is the misalignment between this video and the description: "{caption}"?
31
+ AI: '''
32
+
33
+ generate_kwargs = {
34
+ 'do_sample': True,
35
+ 'top_k': 5,
36
+ 'max_length': 512
37
+ }
38
+
39
+ class VideoCaptionDataset(Dataset):
40
+
41
+ def __init__(self, input_file):
42
+ self.data = pd.read_csv(input_file)
43
+
44
+ def __len__(self):
45
+ return len(self.data)
46
+
47
+ def __getitem__(self, index):
48
+ item = {}
49
+ item['videopath'] = self.data.iloc[index]['videopath']
50
+ item['neg_caption'] = self.data.iloc[index]['neg_caption']
51
+ return item
52
+
53
+ def get_nle(args, model, processor, tokenizer, dataloader):
54
+
55
+ with torch.no_grad():
56
+ for _, batch in tqdm(enumerate(dataloader)):
57
+ videopaths = batch['videopath']
58
+ neg_caption = batch['neg_caption'][0]
59
+ prompts = [PROMPT_FEEDBACK.format(caption = neg_caption)]
60
+ inputs = processor(text=prompts, videos=videopaths, num_frames=args.num_frames, return_tensors='pt')
61
+ inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()}
62
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
63
+ res = model.generate(**inputs, **generate_kwargs)
64
+ generated_nle = tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
65
+
66
+ with open(args.output_file, 'a') as f:
67
+ writer = csv.writer(f)
68
+ writer.writerow([videopaths[0], neg_caption, generated_nle])
69
+
70
+ def main():
71
+
72
+ # Create dataloader
73
+ dataset = VideoCaptionDataset(args.input_file)
74
+ dataloader = DataLoader(dataset, batch_size = args.batch_size)
75
+
76
+ pretrained_ckpt = args.pretrained_ckpt
77
+
78
+ # Processors
79
+ tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
80
+ image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
81
+ processor = MplugOwlProcessor(image_processor, tokenizer)
82
+
83
+ # Instantiate model
84
+ model = MplugOwlForConditionalGeneration.from_pretrained(
85
+ pretrained_ckpt,
86
+ torch_dtype=torch.bfloat16,
87
+ device_map={'':0}
88
+ )
89
+
90
+ if args.use_lora:
91
+ for name, param in model.named_parameters():
92
+ param.requires_grad = False
93
+ if args.all_params:
94
+ peft_config = LoraConfig(
95
+ target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)',
96
+ inference_mode=True,
97
+ r=args.lora_r,
98
+ lora_alpha=16,
99
+ lora_dropout=0.05
100
+ )
101
+ else:
102
+ peft_config = LoraConfig(
103
+ target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj)',
104
+ inference_mode=True,
105
+ r=args.lora_r,
106
+ lora_alpha=16,
107
+ lora_dropout=0.05
108
+ )
109
+
110
+ model = get_peft_model(model, peft_config)
111
+ model.print_trainable_parameters()
112
+ with open(args.trained_ckpt, 'rb') as f:
113
+ ckpt = torch.load(f, map_location = torch.device(f"cuda:0"))
114
+ model.load_state_dict(ckpt)
115
+ model = model.to(torch.bfloat16)
116
+ print('Model Loaded')
117
+
118
+ model.eval()
119
+
120
+ # get nle
121
+ get_nle(args, model, processor, tokenizer, dataloader)
122
+
123
+
124
+
125
+ if __name__ == "__main__":
126
+ main()
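For orientation, a hedged example of invoking this script (all paths are placeholders; only flags defined above are used, and the input CSV needs the `videopath` and `neg_caption` columns read by `VideoCaptionDataset`):

```python
import subprocess

# Placeholder paths; adjust to your data and checkpoints.
subprocess.run([
    "python", "nle_inference.py",
    "--input_file", "data/neg_captions.csv",
    "--output_file", "outputs/nle_predictions.csv",
    "--pretrained_ckpt", "MAGAer13/mplug-owl-llama-7b-pt",
    "--trained_ckpt", "checkpoints/owl_video_lora.pt",
    "--use_lora", "--lora_r", "32",
    "--num_frames", "32",
], check=True)
```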
pipeline_video/train.py ADDED
@@ -0,0 +1,263 @@
1
+ import argparse
+ from functools import partial
+ import os
+ import sys  # needed for the platform check before torch.compile in main()
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader
+ from torch.utils.data.distributed import DistributedSampler
+
+ from sconf import Config
+ from icecream import ic
+ from peft import LoraConfig, get_peft_model
+ from transformers import Trainer
+ from transformers.training_args import TrainingArguments
+
+ from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration
+ from transformers.models.llama.tokenization_llama import LlamaTokenizer
+ from data_utils import train_valid_test_datasets_provider
+ from utils import batchify, set_args
+
+
+ parser = argparse.ArgumentParser()
+ # Model
+ parser.add_argument('--pretrained-ckpt', type=str, default='MAGAer13/mplug-owl-llama-7b-pt',
+                     help='Path to the pretrained checkpoint.')
+ parser.add_argument('--finetuned-ckpt', type=str, default=None,
+                     help='Path to the finetuned checkpoint.')
+ parser.add_argument('--inference_mode', type=bool, default=False,
+                     help='The inference mode.')
+ parser.add_argument('--seq-length', type=int, default=1024,
+                     help='Maximum sequence length to process.')
+
+ parser.add_argument('--use-lora', action='store_true', help='Use LoRA.')
+ parser.add_argument('--all-params', action='store_true', help='Apply LoRA to all projection layers (attention and MLP).')
+ parser.add_argument('--lora-r', type=int, default=8,
+                     help='Rank of the LoRA update matrices.')
+ parser.add_argument('--lora-alpha', type=int, default=32,
+                     help='LoRA alpha scaling coefficient.')
+ parser.add_argument('--lora-dropout', type=float, default=0.05,
+                     help='LoRA dropout probability.')
+ parser.add_argument('--bf16', action='store_true', default=False,
+                     help='Run model in bfloat16 mode.')
+
+ parser.add_argument('--wandb_run_name', type=str, default="test", help='wandb run name.')
+
+ # Data
+ parser.add_argument('--mm-config', type=str, default=None, help='Multimodal config.')
+ parser.add_argument('--num-workers', type=int, default=8,
+                     help="Dataloader number of workers.")
+
+ # Training hyperparameters
+ parser.add_argument('--train-epochs', type=int, default=3,
+                     help='Total number of epochs to train over all training runs.')
+ parser.add_argument('--micro-batch-size', type=int, default=None,
+                     help='Batch size per model instance (local batch size). '
+                          'Global batch size is local batch size times data '
+                          'parallel size times number of micro batches.')
+ parser.add_argument('--lr', type=float, default=None,
+                     help='Initial learning rate. Depending on decay style '
+                          'and initial warmup, the learning rate at each '
+                          'iteration would be different.')
+ parser.add_argument('--min-lr', type=float, default=1e-6,
+                     help='Minimum value for the learning rate; the scheduler '
+                          'clips values below this threshold.')
+ parser.add_argument('--weight-decay', type=float, default=0.01,
+                     help='Weight decay coefficient for L2 regularization.')
+ parser.add_argument('--gradient-accumulation-steps', type=int, default=8,
+                     help='The gradient accumulation steps.')
+ parser.add_argument('--clip-grad', type=float, default=1.0,
+                     help='Gradient clipping based on global L2 norm.')
+ parser.add_argument('--adam-beta1', type=float, default=0.9,
+                     help='First coefficient for computing running averages '
+                          'of gradient and its square.')
+ parser.add_argument('--adam-beta2', type=float, default=0.999,
+                     help='Second coefficient for computing running averages '
+                          'of gradient and its square.')
+ parser.add_argument('--adam-eps', type=float, default=1e-08,
+                     help='Term added to the denominator to improve '
+                          'numerical stability.')
+
+ parser.add_argument('--num-warmup-steps', type=int, default=50,
+                     help='The number of warmup steps.')
+ parser.add_argument('--num-training-steps', type=int, default=4236,
+                     help='The number of total training steps for the lr scheduler.')
+ parser.add_argument('--loss_objective', default='sequential', choices=['sequential'], help='Toggle loss objectives.')
+
+ # Evaluation & Save
+ parser.add_argument('--save-path', type=str, default=None,
+                     help='Output directory to save checkpoints to.')
+ parser.add_argument('--save-interval', type=int, default=None,
+                     help='Number of iterations between checkpoint saves.')
+ parser.add_argument('--eval-iters', type=int, default=100,
+                     help='Number of iterations to run evaluation (validation/test) for.')
+
+ # Other
+ parser.add_argument('--gradient-checkpointing', action='store_true',
+                     help='Enable gradient checkpointing.')
+ parser.add_argument('--logging-nan-inf-filter', action='store_true',
+                     help='Enable the logging nan/inf filter.')
+ parser.add_argument('--ddp-find-unused-parameters', action='store_true',
+                     help='Enable unused-parameter finding in DDP.')
+ parser.add_argument('--do-train', action='store_true', default=True,
+                     help='Whether to do training.')
+ parser.add_argument('--local_rank', type=int, default=-1,
+                     help='Local rank.')
+
+ softmax = nn.Softmax(dim=2)
+ sigm = torch.nn.Sigmoid()
+
+
+ class CustomTrainer(Trainer):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+     def get_train_dataloader(self) -> DataLoader:
+         dataset = self.train_dataset
+         sampler = DistributedSampler(dataset)
+         return torch.utils.data.DataLoader(
+             dataset, batch_size=self._train_batch_size,
+             sampler=sampler,
+             num_workers=self.args.dataloader_num_workers,
+             drop_last=True,
+             pin_memory=True,
+             collate_fn=batchify)
+
+     def get_eval_dataloader(self, eval_dataset) -> DataLoader:
+         dataset = self.eval_dataset
+         sampler = DistributedSampler(dataset, shuffle=False)
+         return torch.utils.data.DataLoader(
+             dataset, batch_size=self._train_batch_size,
+             sampler=sampler,
+             num_workers=self.args.dataloader_num_workers,
+             drop_last=True,
+             pin_memory=True,
+             collate_fn=batchify)
+
+     def compute_loss(self, model, inputs, return_outputs=False):
+         outputs = model(pixel_values=inputs['pixel_values'], video_pixel_values=inputs['video_pixel_values'], labels=inputs['labels'],
+                         num_images=inputs['num_images'], num_videos=inputs['num_videos'], input_ids=inputs['input_ids'], non_padding_mask=inputs['non_padding_mask'],
+                         non_media_mask=inputs['non_media_mask'], prompt_mask=inputs['prompt_mask'])
+         loss = outputs.loss
+         return (loss, outputs) if return_outputs else loss
+
+     def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
+         # Cast floating-point inputs to bfloat16 and move all tensors to the model's device.
+         for k, v in inputs.items():
+             if torch.is_tensor(v):
+                 if v.dtype == torch.float:
+                     inputs[k] = v.bfloat16()
+                 inputs[k] = inputs[k].to(model.device)
+         with torch.no_grad():
+             loss = self.compute_loss(model, inputs)
+             loss = loss.detach()
+         return loss, None, None
+
+ def main():
+     args, left_argv = parser.parse_known_args()
+     ic(left_argv)
+     config = Config(args.mm_config)
+
+     set_args(args)
+     print(args.pretrained_ckpt)
+     model = MplugOwlForConditionalGeneration.from_pretrained(
+         args.pretrained_ckpt,
+         torch_dtype=torch.bfloat16 if args.bf16 else torch.half,
+     )
+     tokenizer = LlamaTokenizer.from_pretrained(args.pretrained_ckpt)
+     if args.use_lora:
+         for name, param in model.named_parameters():
+             param.requires_grad = False
+         if args.all_params:
+             peft_config = LoraConfig(
+                 target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)',
+                 inference_mode=args.inference_mode,
+                 r=args.lora_r,
+                 lora_alpha=args.lora_alpha,
+                 lora_dropout=args.lora_dropout
+             )
+         else:
+             peft_config = LoraConfig(
+                 target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj)',
+                 inference_mode=args.inference_mode,
+                 r=args.lora_r,
+                 lora_alpha=args.lora_alpha,
+                 lora_dropout=args.lora_dropout
+             )
+         model = get_peft_model(model, peft_config)
+         model.print_trainable_parameters()
+
+         if args.gradient_checkpointing:
+             def make_inputs_require_grad(module, input, output):
+                 output.requires_grad_(True)
+             model.language_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+             model.language_model.apply(
+                 partial(model.language_model._set_gradient_checkpointing, value=True))
+
+     else:
+         for name, param in model.named_parameters():
+             if 'language_model' in name:
+                 param.requires_grad = True
+             else:
+                 param.requires_grad = False
+         if args.gradient_checkpointing:
+             model.language_model.apply(
+                 partial(model.language_model._set_gradient_checkpointing, value=True))
+
+     model.train()
+
+     train_data, valid_data = train_valid_test_datasets_provider(
+         config.data_files, config=config,
+         tokenizer=tokenizer, seq_length=args.seq_length, loss_objective=args.loss_objective
+     )
+
+     if len(valid_data) > 500:
+         valid_data = torch.utils.data.Subset(valid_data, range(500))
+
+     trainer = CustomTrainer(
+         model=model,
+         train_dataset=train_data,
+         eval_dataset=valid_data,
+         args=TrainingArguments(
+             learning_rate=args.lr,
+             warmup_steps=args.num_warmup_steps,
+             do_train=args.do_train,
+             do_eval=True,
+             num_train_epochs=args.train_epochs,
+             output_dir=args.save_path,
+             save_strategy='epoch',
+             evaluation_strategy='steps',
+             eval_steps=args.eval_iters,
+             per_device_train_batch_size=args.micro_batch_size,
+             max_grad_norm=args.clip_grad,
+             weight_decay=args.weight_decay,
+             bf16=args.bf16,
+             fp16=not args.bf16,
+             gradient_accumulation_steps=args.gradient_accumulation_steps,
+             gradient_checkpointing=args.gradient_checkpointing,
+             logging_steps=args.eval_iters // 10,
+             logging_dir=args.save_path,
+             logging_nan_inf_filter=args.logging_nan_inf_filter,
+             ddp_find_unused_parameters=args.ddp_find_unused_parameters,
+             run_name=args.wandb_run_name,
+             prediction_loss_only=True,
+         ),
+     )
+     trainer.loss_objective = args.loss_objective
+     trainer.tokenizer = tokenizer
+
+     if torch.__version__ >= "2" and sys.platform != "win32":
+         model = torch.compile(model)
+
+     if args.local_rank == 0:
+         with open(os.path.join(args.save_path, "params.txt"), "w") as file:
+             for key in sorted(vars(args)):
+                 value = getattr(args, key)
+                 file.write(f"{key}: {value}\n")
+
+     trainer.train()
+
+     model.save_pretrained(args.save_path)
+
+ if __name__ == '__main__':
+     main()
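A note on the two target_modules patterns used above: PEFT treats a string value as a regular expression and applies LoRA to every module whose full name matches it, so the --all-params variant additionally adapts the MLP projections. Below is a minimal sketch of that selection logic; the module names are made up purely for illustration and are not taken from the actual mPLUG-Owl module tree.

import re

ATTN_ONLY = r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj)'
ALL_PARAMS = r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)'

# Hypothetical module names, for illustration only.
names = [
    'language_model.model.layers.0.self_attn.q_proj',
    'language_model.model.layers.0.mlp.up_proj',
    'vision_model.encoder.layers.0.self_attn.q_proj',
]
for name in names:
    print(name,
          'attn-only' if re.fullmatch(ATTN_ONLY, name) else '-',
          'all-params' if re.fullmatch(ALL_PARAMS, name) else '-')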
pipeline_video/utils.py ADDED
@@ -0,0 +1,160 @@
+ import math
+ import random
+ import torch
+ import numpy as np
+ from torch import distributed as dist
+ from torch.optim.lr_scheduler import LambdaLR
+ from icecream import ic
+
+
+ def print_rank_0(message):
+     """If distributed is initialized, print only on rank 0."""
+     if torch.distributed.is_initialized():
+         if torch.distributed.get_rank() == 0:
+             print(message, flush=True)
+     else:
+         print(message, flush=True)
+
+
+ ARGS = None
+ def set_args(args):
+     global ARGS
+     ARGS = args
+
+ def get_args():
+     return ARGS
+
+
+ TOKENIZER = None
+ def set_tokenizer(tokenizer):
+     global TOKENIZER
+     TOKENIZER = tokenizer
+
+ def get_tokenizer():
+     return TOKENIZER
+
+
+ class worker_init:
+     def __init__(self, epoch_id):
+         self.epoch_id = epoch_id
+
+     def _worker_init_fn(self, worker_id):
+         random.seed(worker_id + self.epoch_id*1e4 + dist.get_rank()*1e8)
+
+
+ def batchify(batch):
+     # collate_fn: merge per-sample video tensors and text fields into a single batch dict
+     video = [data["video"] for data in batch]
+     if all([img is None for img in video]):
+         video = None
+     else:
+         video = torch.cat([img for img in video if img is not None], dim=0)
+     num_videos_per_sample = torch.LongTensor([data["video"].size(0) if data['video'] is not None else 0 for data in batch])
+     num_images_per_sample = torch.LongTensor([0 for data in batch])
+
+     text = torch.stack([torch.LongTensor(data["text"]['input_ids']) for data in batch], dim=0)
+     non_padding_mask = torch.stack([torch.LongTensor(data["text"]['non_padding_mask']) for data in batch], dim=0)
+     non_media_mask = torch.stack([torch.LongTensor(data["text"]['non_media_mask']) for data in batch], dim=0)
+     prompt_mask = torch.stack([torch.LongTensor(data["text"]['prompt_mask']) for data in batch], dim=0)
+     videopaths = [data["videopath"] for data in batch]
+     captions = [data["caption"] for data in batch]
+     output_batch = {
+         "pixel_values": None,
+         "video_pixel_values": video,
+         "input_ids": text.long(),
+         "labels": text.long().clone(),
+         "num_images": num_images_per_sample.long(),
+         "num_videos": num_videos_per_sample.long(),
+         "non_padding_mask": non_padding_mask.long(),
+         "non_media_mask": non_media_mask.long(),
+         "prompt_mask": prompt_mask.long(),
+         "videopaths": videopaths,
+         "captions": captions,
+     }
+
+     return output_batch
+
+
+ def get_param_groups(modules,
+                      no_weight_decay_cond,
+                      scale_lr_cond,
+                      lr_mult):
+     """creates param groups based on weight decay condition (regularized vs non regularized)
+     and learning rate scale condition (args.lr vs lr_mult * args.lr)
+     scale_lr_cond is used during finetuning where head of the network requires a scaled
+     version of the base learning rate.
+     """
+     wd_no_scale_lr = []
+     wd_scale_lr = []
+     no_wd_no_scale_lr = []
+     no_wd_scale_lr = []
+     for module in modules:
+         for name, param in module.named_parameters():
+             if not param.requires_grad:
+                 continue
+
+             if no_weight_decay_cond is not None:
+                 no_wd = no_weight_decay_cond(name, param)
+             else:
+                 # do not regularize biases nor Norm parameters
+                 no_wd = name.endswith(".bias") or len(param.shape) == 1
+
+             if scale_lr_cond is not None:
+                 scale_lr = scale_lr_cond(name, param)
+             else:
+                 scale_lr = False
+
+             if not no_wd and not scale_lr:
+                 wd_no_scale_lr.append(param)
+             elif not no_wd and scale_lr:
+                 wd_scale_lr.append(param)
+             elif no_wd and not scale_lr:
+                 no_wd_no_scale_lr.append(param)
+             else:
+                 no_wd_scale_lr.append(param)
+
+     param_groups = []
+     if len(wd_no_scale_lr):
+         param_groups.append(
+             {'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0})
+     if len(wd_scale_lr):
+         param_groups.append(
+             {'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult})
+     if len(no_wd_no_scale_lr):
+         param_groups.append({'params': no_wd_no_scale_lr,
+                              'wd_mult': 0.0, 'lr_mult': 1.0})
+     if len(no_wd_scale_lr):
+         param_groups.append(
+             {'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult})
+
+     return param_groups
+
+ def get_cosine_schedule_with_warmup(
+     optimizer, lr, min_lr, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
+ ):
+     """
+     Create a schedule with a learning rate that follows a cosine curve from the initial lr set in the
+     optimizer down to min_lr, after a warmup period during which it increases linearly from min_lr to
+     the initial lr set in the optimizer.
+
+     Args:
+         optimizer ([`~torch.optim.Optimizer`]):
+             The optimizer for which to schedule the learning rate.
+         lr (`float`):
+             The base learning rate set in the optimizer.
+         min_lr (`float`):
+             The learning rate at the start of warmup and at the end of the cosine decay.
+         num_warmup_steps (`int`):
+             The number of steps for the warmup phase.
+         num_training_steps (`int`):
+             The total number of training steps.
+         num_cycles (`float`, *optional*, defaults to 0.5):
+             The number of waves in the cosine schedule (the default is to just decrease from the max value
+             to the minimum following a half-cosine).
+         last_epoch (`int`, *optional*, defaults to -1):
+             The index of the last epoch when resuming training.
+
+     Return:
+         `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+     """
+
+     delta_min_lr = (lr - min_lr) / lr  # e.g. 0.95 for lr=2e-5, min_lr=1e-6
+
+     def lr_lambda(current_step):
+         if current_step < num_warmup_steps:
+             # linear warmup from min_lr/lr up to 1.0
+             return (1 - delta_min_lr) + delta_min_lr * float(current_step) / float(max(1, num_warmup_steps))
+         progress = float(current_step - num_warmup_steps) / \
+             float(max(1, num_training_steps - num_warmup_steps))
+         # cosine decay from 1.0 back down to the warmup floor of min_lr/lr
+         return (1 - delta_min_lr) + delta_min_lr * max(0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+     return LambdaLR(optimizer, lr_lambda, last_epoch)
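A minimal usage sketch of the scheduler above, with a throwaway linear model and SGD optimizer. The warmup and total-step counts mirror the train.py defaults; the learning rates are illustrative, and the `utils` import assumes this file is importable under that name, as train.py does.

import torch
from utils import get_cosine_schedule_with_warmup

model = torch.nn.Linear(4, 4)  # throwaway model, just to have parameters
optimizer = torch.optim.SGD(model.parameters(), lr=2e-5)
scheduler = get_cosine_schedule_with_warmup(
    optimizer, lr=2e-5, min_lr=1e-6, num_warmup_steps=50, num_training_steps=4236)

for step in range(4236):
    # ... forward / backward / optimizer.step() would go here ...
    scheduler.step()
    if step in (0, 49, 2000, 4235):
        print(step, scheduler.get_last_lr()[0])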