KpLBaTMaN commited on
Commit
8e8b282
1 Parent(s): 85923e1

Modified modeling_GOT.py - load_image

Browse files
assets/got_logo.png ADDED
assets/got_support.jpg ADDED
assets/train_sample.jpg ADDED
config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "ucaslcl/GOT-OCR2_0",
3
+ "architectures": [
4
+ "GOTQwenForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "modeling_GOT.GOTConfig",
8
+ "AutoModel": "modeling_GOT.GOTQwenForCausalLM"
9
+ },
10
+ "attention_dropout": 0.0,
11
+ "bos_token_id": 151643,
12
+ "eos_token_id": 151643,
13
+ "freeze_vision_tower": false,
14
+ "hidden_act": "silu",
15
+ "hidden_size": 1024,
16
+ "im_end_token": 151858,
17
+ "im_patch_token": 151859,
18
+ "im_start_token": 151857,
19
+ "image_token_len": 256,
20
+ "initializer_range": 0.02,
21
+ "intermediate_size": 2816,
22
+ "max_position_embeddings": 32768,
23
+ "max_window_layers": 21,
24
+ "model_type": "GOT",
25
+ "num_attention_heads": 16,
26
+ "num_hidden_layers": 24,
27
+ "num_key_value_heads": 16,
28
+ "rms_norm_eps": 1e-06,
29
+ "rope_theta": 1000000.0,
30
+ "sliding_window": 32768,
31
+ "tie_word_embeddings": true,
32
+ "torch_dtype": "bfloat16",
33
+ "transformers_version": "4.37.2",
34
+ "use_cache": true,
35
+ "use_im_start_end": true,
36
+ "use_sliding_window": false,
37
+ "vocab_size": 151860
38
+ }
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "eos_token_id": 151643,
4
+ "max_new_tokens": 2048,
5
+ "transformers_version": "4.37.2"
6
+ }
got_vision_b.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from typing import Optional, Tuple, Type
4
+ from functools import partial
5
+ import torch.nn as nn
6
+ from typing import Type
7
+
8
+
9
+
10
+ class MLPBlock(nn.Module):
11
+ def __init__(
12
+ self,
13
+ embedding_dim: int,
14
+ mlp_dim: int,
15
+ act: Type[nn.Module] = nn.GELU,
16
+ ) -> None:
17
+ super().__init__()
18
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
19
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
20
+ self.act = act()
21
+
22
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
23
+ return self.lin2(self.act(self.lin1(x)))
24
+
25
+
26
+
27
+ class LayerNorm2d(nn.Module):
28
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
29
+ super().__init__()
30
+ self.weight = nn.Parameter(torch.ones(num_channels))
31
+ self.bias = nn.Parameter(torch.zeros(num_channels))
32
+ self.eps = eps
33
+
34
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
35
+ u = x.mean(1, keepdim=True)
36
+ s = (x - u).pow(2).mean(1, keepdim=True)
37
+ x = (x - u) / torch.sqrt(s + self.eps)
38
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
39
+ return x
40
+
41
+
42
+
43
+ class ImageEncoderViT(nn.Module):
44
+ def __init__(
45
+ self,
46
+ img_size: int = 1024,
47
+ patch_size: int = 16,
48
+ in_chans: int = 3,
49
+ embed_dim: int = 768,
50
+ depth: int = 12,
51
+ num_heads: int = 12,
52
+ mlp_ratio: float = 4.0,
53
+ out_chans: int = 256,
54
+ qkv_bias: bool = True,
55
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
56
+ act_layer: Type[nn.Module] = nn.GELU,
57
+ use_abs_pos: bool = True,
58
+ use_rel_pos: bool = False,
59
+ rel_pos_zero_init: bool = True,
60
+ window_size: int = 0,
61
+ global_attn_indexes: Tuple[int, ...] = (),
62
+ ) -> None:
63
+ """
64
+ Args:
65
+ img_size (int): Input image size.
66
+ patch_size (int): Patch size.
67
+ in_chans (int): Number of input image channels.
68
+ embed_dim (int): Patch embedding dimension.
69
+ depth (int): Depth of ViT.
70
+ num_heads (int): Number of attention heads in each ViT block.
71
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
72
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
73
+ norm_layer (nn.Module): Normalization layer.
74
+ act_layer (nn.Module): Activation layer.
75
+ use_abs_pos (bool): If True, use absolute positional embeddings.
76
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
77
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
78
+ window_size (int): Window size for window attention blocks.
79
+ global_attn_indexes (list): Indexes for blocks using global attention.
80
+ """
81
+ super().__init__()
82
+ self.img_size = img_size
83
+
84
+ self.patch_embed = PatchEmbed(
85
+ kernel_size=(patch_size, patch_size),
86
+ stride=(patch_size, patch_size),
87
+ in_chans=in_chans,
88
+ embed_dim=embed_dim,
89
+ )
90
+
91
+ self.pos_embed: Optional[nn.Parameter] = None
92
+ if use_abs_pos:
93
+ # Initialize absolute positional embedding with pretrain image size.
94
+ self.pos_embed = nn.Parameter(
95
+ torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
96
+ )
97
+
98
+ self.blocks = nn.ModuleList()
99
+ for i in range(depth):
100
+ block = Block(
101
+ dim=embed_dim,
102
+ num_heads=num_heads,
103
+ mlp_ratio=mlp_ratio,
104
+ qkv_bias=qkv_bias,
105
+ norm_layer=norm_layer,
106
+ act_layer=act_layer,
107
+ use_rel_pos=use_rel_pos,
108
+ rel_pos_zero_init=rel_pos_zero_init,
109
+ window_size=window_size if i not in global_attn_indexes else 0,
110
+ input_size=(img_size // patch_size, img_size // patch_size),
111
+ )
112
+ self.blocks.append(block)
113
+
114
+ self.neck = nn.Sequential(
115
+ nn.Conv2d(
116
+ embed_dim,
117
+ out_chans,
118
+ kernel_size=1,
119
+ bias=False,
120
+ ),
121
+ LayerNorm2d(out_chans),
122
+ nn.Conv2d(
123
+ out_chans,
124
+ out_chans,
125
+ kernel_size=3,
126
+ padding=1,
127
+ bias=False,
128
+ ),
129
+ LayerNorm2d(out_chans),
130
+ )
131
+
132
+
133
+ self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
134
+ self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
135
+
136
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
137
+ x = self.patch_embed(x)
138
+ if self.pos_embed is not None:
139
+ x = x + self.pos_embed
140
+
141
+ for blk in self.blocks:
142
+ x = blk(x)
143
+
144
+ x = self.neck(x.permute(0, 3, 1, 2))
145
+ x = self.net_2(x)
146
+ x = self.net_3(x)
147
+
148
+
149
+ return x
150
+
151
+
152
+ class Block(nn.Module):
153
+ """Transformer blocks with support of window attention and residual propagation blocks"""
154
+
155
+ def __init__(
156
+ self,
157
+ dim: int,
158
+ num_heads: int,
159
+ mlp_ratio: float = 4.0,
160
+ qkv_bias: bool = True,
161
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
162
+ act_layer: Type[nn.Module] = nn.GELU,
163
+ use_rel_pos: bool = False,
164
+ rel_pos_zero_init: bool = True,
165
+ window_size: int = 0,
166
+ input_size: Optional[Tuple[int, int]] = None,
167
+ ) -> None:
168
+ """
169
+ Args:
170
+ dim (int): Number of input channels.
171
+ num_heads (int): Number of attention heads in each ViT block.
172
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
173
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
174
+ norm_layer (nn.Module): Normalization layer.
175
+ act_layer (nn.Module): Activation layer.
176
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
177
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
178
+ window_size (int): Window size for window attention blocks. If it equals 0, then
179
+ use global attention.
180
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
181
+ positional parameter size.
182
+ """
183
+ super().__init__()
184
+ self.norm1 = norm_layer(dim)
185
+ self.attn = Attention(
186
+ dim,
187
+ num_heads=num_heads,
188
+ qkv_bias=qkv_bias,
189
+ use_rel_pos=use_rel_pos,
190
+ rel_pos_zero_init=rel_pos_zero_init,
191
+ input_size=input_size if window_size == 0 else (window_size, window_size),
192
+ )
193
+
194
+ self.norm2 = norm_layer(dim)
195
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
196
+
197
+ self.window_size = window_size
198
+
199
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
200
+ shortcut = x
201
+ x = self.norm1(x)
202
+ # Window partition
203
+ if self.window_size > 0:
204
+ H, W = x.shape[1], x.shape[2]
205
+ x, pad_hw = window_partition(x, self.window_size)
206
+
207
+ x = self.attn(x)
208
+ # Reverse window partition
209
+ if self.window_size > 0:
210
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
211
+
212
+ x = shortcut + x
213
+ x = x + self.mlp(self.norm2(x))
214
+
215
+ return x
216
+
217
+
218
+ class Attention(nn.Module):
219
+ """Multi-head Attention block with relative position embeddings."""
220
+
221
+ def __init__(
222
+ self,
223
+ dim: int,
224
+ num_heads: int = 8,
225
+ qkv_bias: bool = True,
226
+ use_rel_pos: bool = False,
227
+ rel_pos_zero_init: bool = True,
228
+ input_size: Optional[Tuple[int, int]] = None,
229
+ ) -> None:
230
+ """
231
+ Args:
232
+ dim (int): Number of input channels.
233
+ num_heads (int): Number of attention heads.
234
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
235
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
236
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
237
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
238
+ positional parameter size.
239
+ """
240
+ super().__init__()
241
+ self.num_heads = num_heads
242
+ head_dim = dim // num_heads
243
+ self.scale = head_dim**-0.5
244
+
245
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
246
+ self.proj = nn.Linear(dim, dim)
247
+
248
+ self.use_rel_pos = use_rel_pos
249
+ if self.use_rel_pos:
250
+ assert (
251
+ input_size is not None
252
+ ), "Input size must be provided if using relative positional encoding."
253
+ # initialize relative positional embeddings
254
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
255
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
256
+
257
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
258
+ B, H, W, _ = x.shape
259
+ # qkv with shape (3, B, nHead, H * W, C)
260
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
261
+ # q, k, v with shape (B * nHead, H * W, C)
262
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
263
+
264
+ attn = (q * self.scale) @ k.transpose(-2, -1)
265
+
266
+ if self.use_rel_pos:
267
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
268
+
269
+ attn = attn.softmax(dim=-1)
270
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
271
+ x = self.proj(x)
272
+
273
+ return x
274
+
275
+
276
+ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
277
+ """
278
+ Partition into non-overlapping windows with padding if needed.
279
+ Args:
280
+ x (tensor): input tokens with [B, H, W, C].
281
+ window_size (int): window size.
282
+
283
+ Returns:
284
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
285
+ (Hp, Wp): padded height and width before partition
286
+ """
287
+ B, H, W, C = x.shape
288
+
289
+ pad_h = (window_size - H % window_size) % window_size
290
+ pad_w = (window_size - W % window_size) % window_size
291
+ if pad_h > 0 or pad_w > 0:
292
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
293
+ Hp, Wp = H + pad_h, W + pad_w
294
+
295
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
296
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
297
+ return windows, (Hp, Wp)
298
+
299
+
300
+ def window_unpartition(
301
+ windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
302
+ ) -> torch.Tensor:
303
+ """
304
+ Window unpartition into original sequences and removing padding.
305
+ Args:
306
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
307
+ window_size (int): window size.
308
+ pad_hw (Tuple): padded height and width (Hp, Wp).
309
+ hw (Tuple): original height and width (H, W) before padding.
310
+
311
+ Returns:
312
+ x: unpartitioned sequences with [B, H, W, C].
313
+ """
314
+ Hp, Wp = pad_hw
315
+ H, W = hw
316
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
317
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
318
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
319
+
320
+ if Hp > H or Wp > W:
321
+ x = x[:, :H, :W, :].contiguous()
322
+ return x
323
+
324
+
325
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
326
+ """
327
+ Get relative positional embeddings according to the relative positions of
328
+ query and key sizes.
329
+ Args:
330
+ q_size (int): size of query q.
331
+ k_size (int): size of key k.
332
+ rel_pos (Tensor): relative position embeddings (L, C).
333
+
334
+ Returns:
335
+ Extracted positional embeddings according to relative positions.
336
+ """
337
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
338
+ # Interpolate rel pos if needed.
339
+ if rel_pos.shape[0] != max_rel_dist:
340
+ # Interpolate rel pos.
341
+ rel_pos_resized = F.interpolate(
342
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
343
+ size=max_rel_dist,
344
+ mode="linear",
345
+ )
346
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
347
+ else:
348
+ rel_pos_resized = rel_pos
349
+
350
+ # Scale the coords with short length if shapes for q and k are different.
351
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
352
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
353
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
354
+
355
+ return rel_pos_resized[relative_coords.long()]
356
+
357
+
358
+ def add_decomposed_rel_pos(
359
+ attn: torch.Tensor,
360
+ q: torch.Tensor,
361
+ rel_pos_h: torch.Tensor,
362
+ rel_pos_w: torch.Tensor,
363
+ q_size: Tuple[int, int],
364
+ k_size: Tuple[int, int],
365
+ ) -> torch.Tensor:
366
+ """
367
+ Args:
368
+ attn (Tensor): attention map.
369
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
370
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
371
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
372
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
373
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
374
+
375
+ Returns:
376
+ attn (Tensor): attention map with added relative positional embeddings.
377
+ """
378
+ q_h, q_w = q_size
379
+ k_h, k_w = k_size
380
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
381
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
382
+
383
+ B, _, dim = q.shape
384
+ r_q = q.reshape(B, q_h, q_w, dim)
385
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
386
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
387
+
388
+ attn = (
389
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
390
+ ).view(B, q_h * q_w, k_h * k_w)
391
+
392
+ return attn
393
+
394
+
395
+ class PatchEmbed(nn.Module):
396
+ """
397
+ Image to Patch Embedding.
398
+ """
399
+
400
+ def __init__(
401
+ self,
402
+ kernel_size: Tuple[int, int] = (16, 16),
403
+ stride: Tuple[int, int] = (16, 16),
404
+ padding: Tuple[int, int] = (0, 0),
405
+ in_chans: int = 3,
406
+ embed_dim: int = 768,
407
+ ) -> None:
408
+ """
409
+ Args:
410
+ kernel_size (Tuple): kernel size of the projection layer.
411
+ stride (Tuple): stride of the projection layer.
412
+ padding (Tuple): padding size of the projection layer.
413
+ in_chans (int): Number of input image channels.
414
+ embed_dim (int): Patch embedding dimension.
415
+ """
416
+ super().__init__()
417
+
418
+ self.proj = nn.Conv2d(
419
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
420
+ )
421
+
422
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
423
+ x = self.proj(x)
424
+ # B C H W -> B H W C
425
+ x = x.permute(0, 2, 3, 1)
426
+ return x
427
+
428
+
429
+
430
+ def build_GOT_vit_b(checkpoint=None):
431
+ return _build_GOT_vision(
432
+ encoder_embed_dim=768,
433
+ encoder_depth=12,
434
+ encoder_num_heads=12,
435
+ encoder_global_attn_indexes=[2, 5, 8, 11],
436
+ checkpoint=checkpoint,
437
+ )
438
+
439
+
440
+ def _build_GOT_vision(
441
+ encoder_embed_dim,
442
+ encoder_depth,
443
+ encoder_num_heads,
444
+ encoder_global_attn_indexes,
445
+ checkpoint=None,
446
+ ):
447
+ prompt_embed_dim = 256
448
+ image_size = 1024
449
+ vit_patch_size = 16
450
+ image_embedding_size = image_size // vit_patch_size
451
+ image_encoder=ImageEncoderViT(
452
+ depth=encoder_depth,
453
+ embed_dim=encoder_embed_dim,
454
+ img_size=image_size,
455
+ mlp_ratio=4,
456
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
457
+ num_heads=encoder_num_heads,
458
+ patch_size=vit_patch_size,
459
+ qkv_bias=True,
460
+ use_rel_pos=True,
461
+ global_attn_indexes=encoder_global_attn_indexes,
462
+ window_size=14,
463
+ out_chans=prompt_embed_dim,
464
+ )
465
+
466
+
467
+ return image_encoder
468
+
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77d6144039548b14253176b6eb264896bc39eba532f8894700f210a7fd2a5956
3
+ size 1432121416
modeling_GOT.py ADDED
@@ -0,0 +1,898 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, StoppingCriteria, TextStreamer
2
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
3
+ from typing import List, Optional, Tuple, Union
4
+ from transformers.cache_utils import Cache
5
+ import requests
6
+ from PIL import Image
7
+ from io import BytesIO
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import CrossEntropyLoss
11
+ from .got_vision_b import build_GOT_vit_b
12
+ from torchvision import transforms
13
+ from torchvision.transforms.functional import InterpolationMode
14
+ import dataclasses
15
+ import numpy as np
16
+ import cv2
17
+ from io import BytesIO
18
+ ###
19
+
20
+ DEFAULT_IMAGE_TOKEN = "<image>"
21
+ DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
22
+ DEFAULT_IM_START_TOKEN = '<img>'
23
+ DEFAULT_IM_END_TOKEN = '</img>'
24
+
25
+ from enum import auto, Enum
26
+ class SeparatorStyle(Enum):
27
+ """Different separator style."""
28
+ SINGLE = auto()
29
+ TWO = auto()
30
+ MPT = auto()
31
+
32
+
33
+ @dataclasses.dataclass
34
+ class Conversation:
35
+ """A class that keeps all conversation history."""
36
+ system: str
37
+ roles: List[str]
38
+ messages: List[List[str]]
39
+ offset: int
40
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
41
+ sep: str = "<|im_end|>"
42
+ sep2: str = None
43
+ version: str = "Unknown"
44
+
45
+ skip_next: bool = False
46
+
47
+ def get_prompt(self):
48
+ if self.sep_style == SeparatorStyle.SINGLE:
49
+ ret = self.system + self.sep + '\n'
50
+ for role, message in self.messages:
51
+ if message:
52
+ if type(message) is tuple:
53
+ message, _, _ = message
54
+ ret += role + ": " + message + self.sep
55
+ else:
56
+ ret += role + ":"
57
+ return ret
58
+ elif self.sep_style == SeparatorStyle.TWO:
59
+ seps = [self.sep, self.sep2]
60
+ ret = self.system + seps[0]
61
+ for i, (role, message) in enumerate(self.messages):
62
+ if message:
63
+ if type(message) is tuple:
64
+ message, _, _ = message
65
+ ret += role + ": " + message + seps[i % 2]
66
+ else:
67
+ ret += role + ":"
68
+ return ret
69
+ if self.sep_style == SeparatorStyle.MPT:
70
+ if self.system:
71
+ ret = self.system + self.sep
72
+ else:
73
+ ret = ''
74
+ for role, message in self.messages:
75
+ if message:
76
+ if type(message) is tuple:
77
+ message, _, _ = message
78
+ ret += role + message + self.sep
79
+ else:
80
+ ret += role
81
+ return ret
82
+ else:
83
+ raise ValueError(f"Invalid style: {self.sep_style}")
84
+
85
+
86
+ def append_message(self, role, message):
87
+ self.messages.append([role, message])
88
+
89
+ def copy(self):
90
+ return Conversation(
91
+ system=self.system,
92
+ roles=self.roles,
93
+ messages=[[x, y] for x, y in self.messages],
94
+ offset=self.offset,
95
+ sep_style=self.sep_style,
96
+ sep=self.sep,
97
+ sep2=self.sep2)
98
+
99
+
100
+
101
+ class KeywordsStoppingCriteria(StoppingCriteria):
102
+ def __init__(self, keywords, tokenizer, input_ids):
103
+ self.keywords = keywords
104
+ self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
105
+ self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
106
+ self.tokenizer = tokenizer
107
+ self.start_len = None
108
+ self.input_ids = input_ids
109
+
110
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
111
+ if self.start_len is None:
112
+ self.start_len = self.input_ids.shape[1]
113
+ else:
114
+ for keyword_id in self.keyword_ids:
115
+ if output_ids[0, -1] == keyword_id:
116
+ return True
117
+ outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
118
+ for keyword in self.keywords:
119
+ if keyword in outputs:
120
+ return True
121
+ return False
122
+
123
+
124
+ class GOTImageEvalProcessor:
125
+ def __init__(self, image_size=384, mean=None, std=None):
126
+ if mean is None:
127
+ mean = (0.48145466, 0.4578275, 0.40821073)
128
+ if std is None:
129
+ std = (0.26862954, 0.26130258, 0.27577711)
130
+
131
+ self.normalize = transforms.Normalize(mean, std)
132
+
133
+ self.transform = transforms.Compose(
134
+ [
135
+ transforms.Resize(
136
+ (image_size, image_size), interpolation=InterpolationMode.BICUBIC
137
+ ),
138
+ transforms.ToTensor(),
139
+ self.normalize,
140
+ ]
141
+ )
142
+ def __call__(self, item):
143
+ return self.transform(item)
144
+
145
+
146
+
147
+ class GOTConfig(Qwen2Config):
148
+ model_type = "GOT"
149
+
150
+
151
+ class GOTQwenModel(Qwen2Model):
152
+ config_class = GOTConfig
153
+
154
+ def __init__(self, config: Qwen2Config):
155
+ super(GOTQwenModel, self).__init__(config)
156
+
157
+ self.vision_tower_high = build_GOT_vit_b()
158
+
159
+ self.mm_projector_vary = nn.Linear(1024, 1024)
160
+
161
+
162
+ def initialize_vision_modules(
163
+ self,
164
+ vision_tower,
165
+ pretrained_stage1_model=None,
166
+ freeze_vision_tower=False,
167
+ use_im_start_end=False,
168
+ vision_select_layer=-1,
169
+ dtype=torch.float16,
170
+ device="cuda"
171
+ ):
172
+
173
+
174
+ image_processor_high = GOTImageEvalProcessor(image_size=1024)
175
+
176
+ self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
177
+
178
+ self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)
179
+
180
+
181
+ image_token_len = 256
182
+
183
+ self.config.vision_tower = vision_tower
184
+ self.config.image_token_len = image_token_len
185
+
186
+ self.config.use_im_start_end = True
187
+
188
+ self.config.vision_select_layer = vision_select_layer
189
+ self.config.freeze_vision_tower = freeze_vision_tower
190
+
191
+ return dict(
192
+ image_processor_high=image_processor_high,
193
+ image_token_len=image_token_len,
194
+ )
195
+
196
+
197
+ def forward(
198
+ self,
199
+ input_ids: torch.LongTensor = None,
200
+ attention_mask: Optional[torch.Tensor] = None,
201
+ position_ids: Optional[torch.LongTensor] = None,
202
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
203
+ inputs_embeds: Optional[torch.FloatTensor] = None,
204
+ use_cache: Optional[bool] = None,
205
+ output_attentions: Optional[bool] = None,
206
+ output_hidden_states: Optional[bool] = None,
207
+ images: Optional[torch.FloatTensor] = None,
208
+ return_dict: Optional[bool] = None,
209
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
210
+
211
+ # HACK: replace back original embeddings for LLaVA pretraining
212
+ orig_embeds_params = getattr(self, 'orig_embeds_params', None)
213
+ if orig_embeds_params is not None:
214
+ with torch.no_grad():
215
+ self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data
216
+
217
+ if inputs_embeds is None:
218
+ inputs_embeds = self.embed_tokens(input_ids)
219
+
220
+
221
+ vision_tower_high = getattr(self, 'vision_tower_high', None)
222
+
223
+
224
+ if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
225
+ use_im_start_end = getattr(self.config, "use_im_start_end", -1)
226
+
227
+ vision_select_layer = getattr(self.config, "vision_select_layer", -1)
228
+ im_patch_token = getattr(self.config, "im_patch_token", -1)
229
+ im_start_token = getattr(self.config, "im_start_token", -1)
230
+ im_end_token = getattr(self.config, "im_end_token", -1)
231
+ freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False)
232
+
233
+ im_patch_token = 151859
234
+
235
+ im_start_token = 151857
236
+
237
+ im_end_token = 151858
238
+
239
+ image_features = []
240
+
241
+ for image in images:
242
+ P, C, H, W = image.shape
243
+ if P == 1:
244
+ with torch.set_grad_enabled(False):
245
+ cnn_feature = vision_tower_high(image)
246
+ cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
247
+ image_feature = self.mm_projector_vary(cnn_feature)
248
+ image_features.append(image_feature)
249
+
250
+ else:
251
+ image_patches = torch.unbind(image)
252
+ image_patches_features = []
253
+ for image_patch in image_patches:
254
+ image_p = torch.stack([image_patch])
255
+
256
+ with torch.set_grad_enabled(False):
257
+ cnn_feature_p = vision_tower_high(image_p)
258
+ cnn_feature_p = cnn_feature_p.flatten(2).permute(0, 2, 1)
259
+ image_feature_p = self.mm_projector_vary(cnn_feature_p)
260
+ image_patches_features.append(image_feature_p)
261
+ image_feature = torch.cat(image_patches_features, dim=1)
262
+ image_features.append(image_feature)
263
+
264
+
265
+ dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
266
+ dummy_image_features = dummy_image_features_2
267
+ use_im_start_end = True
268
+ new_input_embeds = []
269
+ for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
270
+ if (cur_input_ids == im_patch_token).sum() == 0:
271
+ cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
272
+ new_input_embeds.append(cur_input_embeds)
273
+ continue
274
+
275
+ if use_im_start_end:
276
+ if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
277
+ raise ValueError("The number of image start tokens and image end tokens should be the same.")
278
+
279
+ image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
280
+ for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
281
+ per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
282
+ num_patches = per_cur_image_features.shape[0]
283
+
284
+ if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
285
+ raise ValueError("The image end token should follow the image start token.")
286
+
287
+ cur_input_embeds = torch.cat(
288
+ (
289
+ cur_input_embeds[:image_start_token_pos+1],
290
+ per_cur_image_features,
291
+ cur_input_embeds[image_start_token_pos + num_patches + 1:]
292
+ ),
293
+ dim=0
294
+ )
295
+
296
+
297
+ new_input_embeds.append(cur_input_embeds)
298
+ else:
299
+ raise NotImplementedError
300
+
301
+ inputs_embeds = torch.stack(new_input_embeds, dim=0)
302
+
303
+ return super(GOTQwenModel, self).forward(
304
+ input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
305
+ inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,
306
+ output_attentions=output_attentions, output_hidden_states=output_hidden_states,
307
+ return_dict=return_dict
308
+ )
309
+
310
+
311
+
312
+ class GOTQwenForCausalLM(Qwen2ForCausalLM):
313
+ config_class = GOTConfig
314
+ # supports_gradient_checkpointing = True
315
+
316
+ def __init__(self, config):
317
+ super(Qwen2ForCausalLM, self).__init__(config)
318
+ self.model = GOTQwenModel(config)
319
+
320
+ self.vocab_size = config.vocab_size
321
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
322
+
323
+ # Initialize weights and apply final processing
324
+ self.post_init()
325
+
326
+ def get_model(self):
327
+ return self.model
328
+
329
+ def forward(
330
+ self,
331
+ input_ids: torch.LongTensor = None,
332
+ attention_mask: Optional[torch.Tensor] = None,
333
+ position_ids: Optional[torch.LongTensor] = None,
334
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
335
+ inputs_embeds: Optional[torch.FloatTensor] = None,
336
+ labels: Optional[torch.LongTensor] = None,
337
+ use_cache: Optional[bool] = None,
338
+ output_attentions: Optional[bool] = None,
339
+ output_hidden_states: Optional[bool] = None,
340
+ images: Optional[torch.FloatTensor] = None,
341
+ return_dict: Optional[bool] = None,
342
+
343
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
344
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
345
+ output_hidden_states = (
346
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
347
+ )
348
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
349
+
350
+ outputs = self.model(
351
+ input_ids=input_ids,
352
+ past_key_values=past_key_values,
353
+ attention_mask=attention_mask,
354
+ position_ids=position_ids,
355
+ inputs_embeds=inputs_embeds,
356
+ use_cache=use_cache,
357
+ output_attentions=output_attentions,
358
+ output_hidden_states=output_hidden_states,
359
+ images=images,
360
+ return_dict=return_dict
361
+
362
+ )
363
+
364
+ hidden_states = outputs[0]
365
+ logits = self.lm_head(hidden_states)
366
+ logits = logits.float()
367
+
368
+ # logits
369
+
370
+ loss = None
371
+ if labels is not None:
372
+ # Shift so that tokens < n predict n
373
+ shift_logits = logits[..., :-1, :].contiguous()
374
+ shift_labels = labels[..., 1:].contiguous()
375
+ # Flatten the tokens
376
+ loss_fct = CrossEntropyLoss()
377
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
378
+ shift_labels = shift_labels.view(-1)
379
+ # Enable model parallelism
380
+ shift_labels = shift_labels.to(shift_logits.device)
381
+ loss = loss_fct(shift_logits, shift_labels)
382
+
383
+ if not return_dict:
384
+ output = (logits,) + outputs[1:]
385
+ return (loss,) + output if loss is not None else output
386
+
387
+ return CausalLMOutputWithPast(
388
+ loss=loss,
389
+ logits=logits,
390
+ past_key_values=outputs.past_key_values,
391
+ hidden_states=outputs.hidden_states,
392
+ attentions=outputs.attentions,
393
+ )
394
+
395
+
396
+ def prepare_inputs_for_generation(
397
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
398
+ ):
399
+ # Omit tokens covered by past_key_values
400
+ if past_key_values is not None:
401
+ if isinstance(past_key_values, Cache):
402
+ cache_length = past_key_values.get_seq_length()
403
+ past_length = past_key_values.seen_tokens
404
+ max_cache_length = past_key_values.get_max_length()
405
+ else:
406
+ cache_length = past_length = past_key_values[0][0].shape[2]
407
+ max_cache_length = None
408
+
409
+ # Keep only the unprocessed tokens:
410
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
411
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
412
+ # input)
413
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
414
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
415
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
416
+ # input_ids based on the past_length.
417
+ elif past_length < input_ids.shape[1]:
418
+ input_ids = input_ids[:, past_length:]
419
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
420
+
421
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
422
+ if (
423
+ max_cache_length is not None
424
+ and attention_mask is not None
425
+ and cache_length + input_ids.shape[1] > max_cache_length
426
+ ):
427
+ attention_mask = attention_mask[:, -max_cache_length:]
428
+
429
+ position_ids = kwargs.get("position_ids", None)
430
+ if attention_mask is not None and position_ids is None:
431
+ # create position_ids on the fly for batch generation
432
+ position_ids = attention_mask.long().cumsum(-1) - 1
433
+ position_ids.masked_fill_(attention_mask == 0, 1)
434
+ if past_key_values:
435
+ position_ids = position_ids[:, -input_ids.shape[1] :]
436
+
437
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
438
+ if inputs_embeds is not None and past_key_values is None:
439
+ model_inputs = {"inputs_embeds": inputs_embeds}
440
+ else:
441
+ model_inputs = {"input_ids": input_ids}
442
+
443
+ model_inputs.update(
444
+ {
445
+ "position_ids": position_ids,
446
+ "past_key_values": past_key_values,
447
+ "use_cache": kwargs.get("use_cache"),
448
+ "attention_mask": attention_mask,
449
+ "images": kwargs.get("images", None),
450
+ }
451
+ )
452
+ return model_inputs
453
+
454
+ def initialize_vision_tokenizer(
455
+ self,
456
+ tokenizer,
457
+ freeze_lm_model=False,
458
+ pretrained_stage1_model=None,
459
+ device="cuda"
460
+ ):
461
+ config = self.get_model().config
462
+
463
+
464
+ self.resize_token_embeddings(len(tokenizer))
465
+
466
+ config.im_patch_token = 151859
467
+
468
+ config.use_im_start_end = True
469
+
470
+ if config.use_im_start_end:
471
+ self.resize_token_embeddings(len(tokenizer))
472
+ config.im_start_token, config.im_end_token = 151857, 151858
473
+
474
+ def load_image(self, image_input):
475
+ if isinstance(image_input, Image.Image):
476
+ # If it's already a PIL Image, return it directly
477
+ return image_input
478
+ elif isinstance(image_input, np.ndarray):
479
+ # If it's a NumPy array (e.g., from OpenCV), convert it to a PIL Image
480
+ return Image.fromarray(cv2.cvtColor(image_input, cv2.COLOR_BGR2RGB))
481
+ elif isinstance(image_input, bytes):
482
+ # If it's bytes, convert it to a PIL Image
483
+ image = Image.open(BytesIO(image_input)).convert('RGB')
484
+ return image
485
+ elif isinstance(image_input, str):
486
+ # If it's a URL or file path, load the image accordingly
487
+ if image_input.startswith('http://') or image_input.startswith('https://'):
488
+ response = requests.get(image_input)
489
+ image = Image.open(BytesIO(response.content)).convert('RGB')
490
+ else:
491
+ image = Image.open(image_input).convert('RGB')
492
+ return image
493
+ else:
494
+ raise ValueError("Invalid image input. Must be a file path, URL, PIL Image, NumPy array, or bytes.")
495
+
496
+ def disable_torch_init(self):
497
+ """
498
+ Disable the redundant torch default initialization to accelerate model creation.
499
+ """
500
+ import torch
501
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
502
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
503
+
504
+ def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
505
+
506
+ self.disable_torch_init()
507
+
508
+
509
+ image_processor_high = GOTImageEvalProcessor(image_size=1024)
510
+
511
+ use_im_start_end = True
512
+
513
+ image_token_len = 256
514
+
515
+ if gradio_input:
516
+ image = image_file.copy()
517
+ else:
518
+ image = self.load_image(image_file)
519
+
520
+ w, h = image.size
521
+
522
+ if ocr_type == 'format':
523
+ qs = 'OCR with format: '
524
+ else:
525
+ qs = 'OCR: '
526
+
527
+ if ocr_box:
528
+ bbox = eval(ocr_box)
529
+ if len(bbox) == 2:
530
+ bbox[0] = int(bbox[0]/w*1000)
531
+ bbox[1] = int(bbox[1]/h*1000)
532
+ if len(bbox) == 4:
533
+ bbox[0] = int(bbox[0]/w*1000)
534
+ bbox[1] = int(bbox[1]/h*1000)
535
+ bbox[2] = int(bbox[2]/w*1000)
536
+ bbox[3] = int(bbox[3]/h*1000)
537
+ if ocr_type == 'format':
538
+ qs = str(bbox) + ' ' + 'OCR with format: '
539
+ else:
540
+ qs = str(bbox) + ' ' + 'OCR: '
541
+
542
+ if ocr_color:
543
+ if ocr_type == 'format':
544
+ qs = '[' + ocr_color + ']' + ' ' + 'OCR with format: '
545
+ else:
546
+ qs = '[' + ocr_color + ']' + ' ' + 'OCR: '
547
+
548
+ if use_im_start_end:
549
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
550
+ else:
551
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
552
+
553
+
554
+ conv_mpt = Conversation(
555
+ system="""<|im_start|>system
556
+ You should follow the instructions carefully and explain your answers in detail.""",
557
+ # system = None,
558
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
559
+ version="mpt",
560
+ messages=(),
561
+ offset=0,
562
+ sep_style=SeparatorStyle.MPT,
563
+ sep="<|im_end|>",
564
+ )
565
+
566
+ conv = conv_mpt.copy()
567
+ conv.append_message(conv.roles[0], qs)
568
+ conv.append_message(conv.roles[1], None)
569
+ prompt = conv.get_prompt()
570
+
571
+ if print_prompt:
572
+ print(prompt)
573
+
574
+ inputs = tokenizer([prompt])
575
+
576
+ image_tensor_1 = image_processor_high(image)
577
+
578
+ input_ids = torch.as_tensor(inputs.input_ids).cuda()
579
+
580
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
581
+ keywords = [stop_str]
582
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
583
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
584
+
585
+ if stream_flag:
586
+ with torch.autocast("cuda", dtype=torch.bfloat16):
587
+ output_ids = self.generate(
588
+ input_ids,
589
+ images=[image_tensor_1.unsqueeze(0).half().cuda()],
590
+ do_sample=False,
591
+ num_beams = 1,
592
+ no_repeat_ngram_size = 20,
593
+ streamer=streamer,
594
+ max_new_tokens=4096,
595
+ stopping_criteria=[stopping_criteria]
596
+ )
597
+ else:
598
+ with torch.autocast("cuda", dtype=torch.bfloat16):
599
+ output_ids = self.generate(
600
+ input_ids,
601
+ images=[image_tensor_1.unsqueeze(0).half().cuda()],
602
+ do_sample=False,
603
+ num_beams = 1,
604
+ no_repeat_ngram_size = 20,
605
+ # streamer=streamer,
606
+ max_new_tokens=4096,
607
+ stopping_criteria=[stopping_criteria]
608
+ )
609
+
610
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
611
+
612
+ if outputs.endswith(stop_str):
613
+ outputs = outputs[:-len(stop_str)]
614
+ outputs = outputs.strip()
615
+ response_str = outputs
616
+
617
+ if render:
618
+ print('==============rendering===============')
619
+ from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
620
+
621
+ if '**kern' in outputs:
622
+ import verovio
623
+ tk = verovio.toolkit()
624
+ tk.loadData(outputs)
625
+ tk.setOptions({"pageWidth": 2100, "footer": 'none',
626
+ 'barLineWidth': 0.5, 'beamMaxSlope': 15,
627
+ 'staffLineWidth': 0.2, 'spacingStaff': 6})
628
+ tk.getPageCount()
629
+ svg = tk.renderToSVG()
630
+ svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
631
+
632
+ svg_to_html(svg, save_render_file)
633
+
634
+ if ocr_type == 'format' and '**kern' not in outputs:
635
+
636
+
637
+ if '\\begin{tikzpicture}' not in outputs:
638
+ html_path_2 = save_render_file
639
+ right_num = outputs.count('\\right')
640
+ left_num = outputs.count('\left')
641
+
642
+ if right_num != left_num:
643
+ outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
644
+
645
+
646
+ outputs = outputs.replace('"', '``').replace('$', '')
647
+
648
+ outputs_list = outputs.split('\n')
649
+ gt= ''
650
+ for out in outputs_list:
651
+ gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
652
+
653
+ gt = gt[:-2]
654
+
655
+
656
+ lines = content_mmd_to_html
657
+ lines = lines.split("const text =")
658
+ new_web = lines[0] + 'const text =' + gt + lines[1]
659
+
660
+ else:
661
+ html_path_2 = save_render_file
662
+ outputs = outputs.translate(translation_table)
663
+ outputs_list = outputs.split('\n')
664
+ gt= ''
665
+ for out in outputs_list:
666
+ if out:
667
+ if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
668
+ while out[-1] == ' ':
669
+ out = out[:-1]
670
+ if out is None:
671
+ break
672
+
673
+ if out:
674
+ if out[-1] != ';':
675
+ gt += out[:-1] + ';\n'
676
+ else:
677
+ gt += out + '\n'
678
+ else:
679
+ gt += out + '\n'
680
+
681
+
682
+ lines = tik_html
683
+ lines = lines.split("const text =")
684
+ new_web = lines[0] + gt + lines[1]
685
+
686
+ with open(html_path_2, 'w') as web_f_new:
687
+ web_f_new.write(new_web)
688
+ return response_str
689
+
690
+ def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
691
+
692
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
693
+ best_ratio_diff = float('inf')
694
+ best_ratio = (1, 1)
695
+ area = width * height
696
+ for ratio in target_ratios:
697
+ target_aspect_ratio = ratio[0] / ratio[1]
698
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
699
+ if ratio_diff < best_ratio_diff:
700
+ best_ratio_diff = ratio_diff
701
+ best_ratio = ratio
702
+ elif ratio_diff == best_ratio_diff:
703
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
704
+ best_ratio = ratio
705
+ # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
706
+ return best_ratio
707
+
708
+ orig_width, orig_height = image.size
709
+ aspect_ratio = orig_width / orig_height
710
+
711
+ # calculate the existing image aspect ratio
712
+ target_ratios = set(
713
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
714
+ i * j <= max_num and i * j >= min_num)
715
+ # print(target_ratios)
716
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
717
+
718
+ # find the closest aspect ratio to the target
719
+ target_aspect_ratio = find_closest_aspect_ratio(
720
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
721
+
722
+ # print(target_aspect_ratio)
723
+ # calculate the target width and height
724
+ target_width = image_size * target_aspect_ratio[0]
725
+ target_height = image_size * target_aspect_ratio[1]
726
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
727
+
728
+ # resize the image
729
+ resized_img = image.resize((target_width, target_height))
730
+ processed_images = []
731
+ for i in range(blocks):
732
+ box = (
733
+ (i % (target_width // image_size)) * image_size,
734
+ (i // (target_width // image_size)) * image_size,
735
+ ((i % (target_width // image_size)) + 1) * image_size,
736
+ ((i // (target_width // image_size)) + 1) * image_size
737
+ )
738
+ # split the image
739
+ split_img = resized_img.crop(box)
740
+ processed_images.append(split_img)
741
+ assert len(processed_images) == blocks
742
+ if use_thumbnail and len(processed_images) != 1:
743
+ thumbnail_img = image.resize((image_size, image_size))
744
+ processed_images.append(thumbnail_img)
745
+ return processed_images
746
+
747
+
748
+ def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
749
+ # Model
750
+ self.disable_torch_init()
751
+ multi_page=False
752
+
753
+
754
+ image_processor_high = GOTImageEvalProcessor(image_size=1024)
755
+
756
+ use_im_start_end = True
757
+
758
+
759
+ image_token_len = 256
760
+
761
+ image_list = []
762
+
763
+ # if len(image_file_list)>1:
764
+ # multi_page = True
765
+
766
+ if multi_page:
767
+ qs = 'OCR with format across multi pages: '
768
+ # only for png files
769
+ # import glob
770
+ # from natsort import natsorted
771
+ # patches = glob.glob(image_file + '/*png')
772
+ patches = image_file
773
+ # patches = natsorted(patches)
774
+ sub_images = []
775
+ for sub_image in patches:
776
+ sub_images.append(self.load_image(sub_image))
777
+
778
+ ll = len(patches)
779
+ # print(patches)
780
+ # print("len ll: ", ll)
781
+
782
+ else:
783
+ if ocr_type == 'format':
784
+ qs = 'OCR with format upon the patch reference: '
785
+ else:
786
+ qs = 'OCR upon the patch reference: '
787
+ if gradio_input:
788
+ img = image_file.copy()
789
+ else:
790
+ img = self.load_image(image_file)
791
+ sub_images = self.dynamic_preprocess(img)
792
+ ll = len(sub_images)
793
+
794
+ for image in sub_images:
795
+ image_tensor_1 = image_processor_high(image)
796
+ image_list.append(image_tensor_1)
797
+
798
+
799
+ image_list = torch.stack(image_list)
800
+
801
+ print('====new images batch size======: \n',image_list.shape)
802
+
803
+
804
+ if use_im_start_end:
805
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs
806
+ else:
807
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
808
+
809
+
810
+ conv_mpt = Conversation(
811
+ system="""<|im_start|>system
812
+ You should follow the instructions carefully and explain your answers in detail.""",
813
+ # system = None,
814
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
815
+ version="mpt",
816
+ messages=(),
817
+ offset=0,
818
+ sep_style=SeparatorStyle.MPT,
819
+ sep="<|im_end|>",
820
+ )
821
+
822
+ conv = conv_mpt.copy()
823
+ conv.append_message(conv.roles[0], qs)
824
+ conv.append_message(conv.roles[1], None)
825
+ prompt = conv.get_prompt()
826
+
827
+ if print_prompt:
828
+ print(prompt)
829
+
830
+ inputs = tokenizer([prompt])
831
+
832
+ input_ids = torch.as_tensor(inputs.input_ids).cuda()
833
+
834
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
835
+ keywords = [stop_str]
836
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
837
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
838
+
839
+ if stream_flag:
840
+ with torch.autocast("cuda", dtype=torch.bfloat16):
841
+ output_ids = self.generate(
842
+ input_ids,
843
+ images=[image_list.half().cuda()],
844
+ do_sample=False,
845
+ num_beams = 1,
846
+ # no_repeat_ngram_size = 20,
847
+ streamer=streamer,
848
+ max_new_tokens=4096,
849
+ stopping_criteria=[stopping_criteria]
850
+ )
851
+ else:
852
+ with torch.autocast("cuda", dtype=torch.bfloat16):
853
+ output_ids = self.generate(
854
+ input_ids,
855
+ images=[image_list.half().cuda()],
856
+ do_sample=False,
857
+ num_beams = 1,
858
+ # no_repeat_ngram_size = 20,
859
+ # streamer=streamer,
860
+ max_new_tokens=4096,
861
+ stopping_criteria=[stopping_criteria]
862
+ )
863
+
864
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
865
+
866
+ if outputs.endswith(stop_str):
867
+ outputs = outputs[:-len(stop_str)]
868
+ outputs = outputs.strip()
869
+ response_str = outputs
870
+
871
+ if render:
872
+ print('==============rendering===============')
873
+ from .render_tools import content_mmd_to_html
874
+ html_path_2 = save_render_file
875
+ right_num = outputs.count('\\right')
876
+ left_num = outputs.count('\left')
877
+
878
+ if right_num != left_num:
879
+ outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
880
+
881
+
882
+ outputs = outputs.replace('"', '``').replace('$', '')
883
+
884
+ outputs_list = outputs.split('\n')
885
+ gt= ''
886
+ for out in outputs_list:
887
+ gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
888
+
889
+ gt = gt[:-2]
890
+
891
+ lines = content_mmd_to_html
892
+ lines = lines.split("const text =")
893
+ new_web = lines[0] + 'const text =' + gt + lines[1]
894
+
895
+ with open(html_path_2, 'w') as web_f_new:
896
+ web_f_new.write(new_web)
897
+
898
+ return response_str
qwen.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
render_tools.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ punctuation_dict = {
3
+ ",": ",",
4
+ "。": ".",
5
+
6
+ }
7
+ translation_table = str.maketrans(punctuation_dict)
8
+
9
+ def svg_to_html(svg_content, output_filename):
10
+
11
+ html_content = f"""
12
+ <!DOCTYPE html>
13
+ <html lang="en">
14
+ <head>
15
+ <meta charset="UTF-8">
16
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
17
+ <title>SVG Embedded in HTML</title>
18
+ </head>
19
+ <body>
20
+ <svg width="2100" height="15000" xmlns="http://www.w3.org/2000/svg">
21
+ {svg_content}
22
+ </svg>
23
+ </body>
24
+ </html>
25
+ """
26
+
27
+ with open(output_filename, 'w') as file:
28
+ file.write(html_content)
29
+
30
+
31
+
32
+ content_mmd_to_html = """<!DOCTYPE html>
33
+ <html lang="en" data-lt-installed="true"><head>
34
+ <meta charset="UTF-8">
35
+ <title>Title</title>
36
+ <script>
37
+ const text =
38
+ </script>
39
+ <style>
40
+ #content {
41
+ max-width: 800px;
42
+ margin: auto;
43
+ }
44
+ </style>
45
+ <script>
46
+ let script = document.createElement('script');
47
+ script.src = "https://cdn.jsdelivr.net/npm/[email protected]/es5/bundle.js";
48
+ document.head.append(script);
49
+
50
+ script.onload = function() {
51
+ const isLoaded = window.loadMathJax();
52
+ if (isLoaded) {
53
+ console.log('Styles loaded!')
54
+ }
55
+
56
+ const el = window.document.getElementById('content-text');
57
+ if (el) {
58
+ const options = {
59
+ htmlTags: true
60
+ };
61
+ const html = window.render(text, options);
62
+ el.outerHTML = html;
63
+ }
64
+ };
65
+ </script>
66
+ </head>
67
+ <body>
68
+ <div id="content"><div id="content-text"></div></div>
69
+ </body>
70
+ </html>
71
+ """
72
+
73
+
74
+
75
+ tik_html = """
76
+ <!DOCTYPE html>
77
+
78
+ <html>
79
+
80
+ <head>
81
+ <meta charset="UTF-8">
82
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
83
+ <title>Document</title>
84
+ <link rel="stylesheet" type="text/css" href="https://tikzjax.com/v1/fonts.css">
85
+ <script src="https://tikzjax.com/v1/tikzjax.js"></script>
86
+ </head>
87
+ <body>
88
+ <script type="text/tikz">
89
+ const text =
90
+ </script>
91
+ </body>
92
+ </html>"""
93
+
94
+
95
+
96
+ # print(tik_html)
special_tokens_map.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "pad_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ }
9
+ }
tokenization_qwen.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Tokenization classes for QWen."""
7
+
8
+ import base64
9
+ import logging
10
+ import os
11
+ import unicodedata
12
+ from typing import Collection, Dict, List, Set, Tuple, Union
13
+
14
+ import tiktoken
15
+ from transformers import PreTrainedTokenizer, AddedToken
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}
21
+
22
+ PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
23
+ ENDOFTEXT = "<|endoftext|>"
24
+ IMSTART = "<|im_start|>"
25
+ IMEND = "<|im_end|>"
26
+ # as the default behavior is changed to allow special tokens in
27
+ # regular texts, the surface forms of special tokens need to be
28
+ # as different as possible to minimize the impact
29
+ EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
30
+ SPECIAL_TOKENS = (
31
+ ENDOFTEXT,
32
+ IMSTART,
33
+ IMEND,
34
+ ) + EXTRAS
35
+
36
+
37
+ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
38
+ with open(tiktoken_bpe_file, "rb") as f:
39
+ contents = f.read()
40
+ return {
41
+ base64.b64decode(token): int(rank)
42
+ for token, rank in (line.split() for line in contents.splitlines() if line)
43
+ }
44
+
45
+ class QWenTokenizer(PreTrainedTokenizer):
46
+ """QWen tokenizer."""
47
+
48
+ vocab_files_names = VOCAB_FILES_NAMES
49
+
50
+ def __init__(
51
+ self,
52
+ vocab_file,
53
+ errors="replace",
54
+ image_start_tag='<img>',
55
+ image_end_tag='</img>',
56
+ image_pad_tag='<imgpad>',
57
+ ref_start_tag='<ref>',
58
+ ref_end_tag='</ref>',
59
+ box_start_tag='<box>',
60
+ box_end_tag='</box>',
61
+ quad_start_tag='<quad>',
62
+ quad_end_tag='</quad>',
63
+ **kwargs,
64
+ ):
65
+ super().__init__(**kwargs)
66
+
67
+ self.image_start_tag = image_start_tag
68
+ self.image_end_tag = image_end_tag
69
+ self.image_pad_tag = image_pad_tag
70
+ self.ref_start_tag = ref_start_tag
71
+ self.ref_end_tag = ref_end_tag
72
+ self.box_start_tag = box_start_tag
73
+ self.box_end_tag = box_end_tag
74
+ self.quad_start_tag = quad_start_tag
75
+ self.quad_end_tag = quad_end_tag
76
+ self.IMAGE_ST = (
77
+ ref_start_tag, ref_end_tag,
78
+ box_start_tag, box_end_tag,
79
+ quad_start_tag, quad_end_tag,
80
+ image_start_tag, image_end_tag,
81
+ image_pad_tag
82
+ )
83
+
84
+ self.errors = errors # how to handle errors in decoding
85
+
86
+ self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
87
+ self.special_tokens = {
88
+ token: index
89
+ for index, token in enumerate(
90
+ SPECIAL_TOKENS + self.IMAGE_ST, start=len(self.mergeable_ranks)
91
+ )
92
+ }
93
+
94
+ self.img_start_id = self.special_tokens[self.image_start_tag]
95
+ self.img_end_id = self.special_tokens[self.image_end_tag]
96
+ self.img_pad_id = self.special_tokens[self.image_pad_tag]
97
+ self.ref_start_id = self.special_tokens[self.ref_start_tag]
98
+ self.ref_end_id = self.special_tokens[self.ref_end_tag]
99
+ self.box_start_id = self.special_tokens[self.box_start_tag]
100
+ self.box_end_id = self.special_tokens[self.box_end_tag]
101
+ self.quad_start_id = self.special_tokens[self.quad_start_tag]
102
+ self.quad_end_id = self.special_tokens[self.quad_end_tag]
103
+
104
+ enc = tiktoken.Encoding(
105
+ "Qwen",
106
+ pat_str=PAT_STR,
107
+ mergeable_ranks=self.mergeable_ranks,
108
+ special_tokens=self.special_tokens,
109
+ )
110
+ assert (
111
+ len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
112
+ ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
113
+
114
+ self.decoder = {
115
+ v: k for k, v in self.mergeable_ranks.items()
116
+ } # type: dict[int, bytes|str]
117
+ self.decoder.update({v: k for k, v in self.special_tokens.items()})
118
+
119
+ self.tokenizer = enc # type: tiktoken.Encoding
120
+
121
+ self.eod_id = self.tokenizer.eot_token
122
+ self.im_start_id = self.special_tokens[IMSTART]
123
+ self.im_end_id = self.special_tokens[IMEND]
124
+
125
+ def __len__(self) -> int:
126
+ return self.tokenizer.n_vocab
127
+
128
+ def get_vocab(self) -> Dict[bytes, int]:
129
+ return self.mergeable_ranks
130
+
131
+ def convert_tokens_to_ids(
132
+ self, tokens: Union[bytes, str, List[Union[bytes, str]]]
133
+ ) -> List[int]:
134
+ ids = []
135
+ if isinstance(tokens, (str, bytes)):
136
+ if tokens in self.special_tokens:
137
+ return self.special_tokens[tokens]
138
+ else:
139
+ return self.mergeable_ranks.get(tokens)
140
+ for token in tokens:
141
+ if token in self.special_tokens:
142
+ ids.append(self.special_tokens[token])
143
+ else:
144
+ ids.append(self.mergeable_ranks.get(token))
145
+ return ids
146
+
147
+ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
148
+ if not special_tokens and new_tokens:
149
+ raise ValueError('Adding regular tokens is not supported')
150
+ for token in new_tokens:
151
+ surface_form = token.content if isinstance(token, AddedToken) else token
152
+ if surface_form not in SPECIAL_TOKENS:
153
+ raise ValueError('Adding unknown special tokens is not supported')
154
+ return 0
155
+
156
+ def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
157
+ """
158
+ Save only the vocabulary of the tokenizer (vocabulary).
159
+
160
+ Returns:
161
+ `Tuple(str)`: Paths to the files saved.
162
+ """
163
+ file_path = os.path.join(save_directory, "qwen.tiktoken")
164
+ with open(file_path, "w", encoding="utf8") as w:
165
+ for k, v in self.mergeable_ranks.items():
166
+ line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
167
+ w.write(line)
168
+ return (file_path,)
169
+
170
+ def tokenize(
171
+ self,
172
+ text: str,
173
+ allowed_special: Union[Set, str] = "all",
174
+ disallowed_special: Union[Collection, str] = (),
175
+ **kwargs,
176
+ ) -> List[Union[bytes, str]]:
177
+ """
178
+ Converts a string in a sequence of tokens.
179
+
180
+ Args:
181
+ text (`str`):
182
+ The sequence to be encoded.
183
+ allowed_special (`Literal["all"]` or `set`):
184
+ The surface forms of the tokens to be encoded as special tokens in regular texts.
185
+ Default to "all".
186
+ disallowed_special (`Literal["all"]` or `Collection`):
187
+ The surface forms of the tokens that should not be in regular texts and trigger errors.
188
+ Default to an empty tuple.
189
+
190
+ kwargs (additional keyword arguments, *optional*):
191
+ Will be passed to the underlying model specific encode method.
192
+
193
+ Returns:
194
+ `List[bytes|str]`: The list of tokens.
195
+ """
196
+ tokens = []
197
+ text = unicodedata.normalize("NFC", text)
198
+
199
+ # this implementation takes a detour: text -> token id -> token surface forms
200
+ for t in self.tokenizer.encode(
201
+ text, allowed_special=allowed_special, disallowed_special=disallowed_special
202
+ ):
203
+ tokens.append(self.decoder[t])
204
+ return tokens
205
+
206
+ def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
207
+ """
208
+ Converts a sequence of tokens in a single string.
209
+ """
210
+ text = ""
211
+ temp = b""
212
+ for t in tokens:
213
+ if isinstance(t, str):
214
+ if temp:
215
+ text += temp.decode("utf-8", errors=self.errors)
216
+ temp = b""
217
+ text += t
218
+ elif isinstance(t, bytes):
219
+ temp += t
220
+ else:
221
+ raise TypeError("token should only be of type types or str")
222
+ if temp:
223
+ text += temp.decode("utf-8", errors=self.errors)
224
+ return text
225
+
226
+ @property
227
+ def vocab_size(self):
228
+ return self.tokenizer.n_vocab
229
+
230
+ def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
231
+ """Converts an id to a token, special tokens included"""
232
+ if index in self.decoder:
233
+ return self.decoder[index]
234
+ raise ValueError("unknown ids")
235
+
236
+ def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
237
+ """Converts a token to an id using the vocab, special tokens included"""
238
+ if token in self.special_tokens:
239
+ return self.special_tokens[token]
240
+ if token in self.mergeable_ranks:
241
+ return self.mergeable_ranks[token]
242
+ raise ValueError("unknown token")
243
+
244
+ def _tokenize(self, text: str, **kwargs):
245
+ """
246
+ Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
247
+ vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
248
+
249
+ Do NOT take care of added tokens.
250
+ """
251
+ raise NotImplementedError
252
+
253
+ def _decode(
254
+ self,
255
+ token_ids: Union[int, List[int]],
256
+ skip_special_tokens: bool = False,
257
+ errors: str = None,
258
+ **kwargs,
259
+ ) -> str:
260
+ if isinstance(token_ids, int):
261
+ token_ids = [token_ids]
262
+ if skip_special_tokens:
263
+ token_ids = [i for i in token_ids if i < self.eod_id]
264
+ return self.tokenizer.decode(token_ids, errors=errors or self.errors)
tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {},
3
+ "auto_map": {
4
+ "AutoTokenizer": [
5
+ "tokenization_qwen.QWenTokenizer",
6
+ null
7
+ ]
8
+ },
9
+ "clean_up_tokenization_spaces": true,
10
+ "model_max_length": 8000,
11
+ "pad_token": "<|endoftext|>",
12
+ "padding_side": "right",
13
+ "tokenizer_class": "QWenTokenizer"
14
+ }