evaldas-leliuga commited on
Commit
63ac37e
1 Parent(s): 6094794
Files changed (2) hide show
  1. inference.py +254 -0
  2. vocab_v20230424.txt +0 -0
inference.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ np.set_printoptions(precision=4, suppress=True, linewidth=200)
4
+ import types, torch
5
+ from torch.nn import functional as F
6
+
7
+ MyModule = torch.jit.ScriptModule
8
+ MyFunction = torch.jit.script_method
9
+
10
+
11
+ class RWKV_TOKENIZER():
12
+ table: list[list[list[bytes]]]
13
+ good: list[set[int]]
14
+ wlen: list[int]
15
+
16
+ def __init__(self, file_name):
17
+ self.idx2token = {}
18
+ sorted = [] # must be already sorted
19
+ lines = open(file_name, "r", encoding="utf-8").readlines()
20
+ for l in lines:
21
+ idx = int(l[:l.index(' ')])
22
+ x = eval(l[l.index(' '):l.rindex(' ')])
23
+ x = x.encode("utf-8") if isinstance(x, str) else x
24
+ assert isinstance(x, bytes)
25
+ assert len(x) == int(l[l.rindex(' '):])
26
+ sorted += [x]
27
+ self.idx2token[idx] = x
28
+
29
+ self.token2idx = {}
30
+ for k, v in self.idx2token.items():
31
+ self.token2idx[v] = int(k)
32
+
33
+ # precompute some tables for fast matching
34
+ self.table = [[[] for j in range(256)] for i in range(256)]
35
+ self.good = [set() for i in range(256)]
36
+ self.wlen = [0 for i in range(256)]
37
+
38
+ for i in reversed(range(len(sorted))): # reverse order - match longer tokens first
39
+ s = sorted[i]
40
+ if len(s) >= 2:
41
+ s0 = int(s[0])
42
+ s1 = int(s[1])
43
+ self.table[s0][s1] += [s]
44
+ self.wlen[s0] = max(self.wlen[s0], len(s))
45
+ self.good[s0].add(s1)
46
+
47
+ def encodeBytes(self, src: bytes) -> list[int]:
48
+ src_len: int = len(src)
49
+ tokens: list[int] = []
50
+ i: int = 0
51
+ while i < src_len:
52
+ s: bytes = src[i: i + 1]
53
+
54
+ if i < src_len - 1:
55
+ s1: int = int(src[i + 1])
56
+ s0: int = int(src[i])
57
+ if s1 in self.good[s0]:
58
+ sss: bytes = src[i: i + self.wlen[s0]]
59
+ try:
60
+ s = next(filter(sss.startswith, self.table[s0][s1]))
61
+ except:
62
+ pass
63
+ tokens.append(self.token2idx[s])
64
+ i += len(s)
65
+
66
+ return tokens
67
+
68
+ def decodeBytes(self, tokens):
69
+ return b''.join(map(lambda i: self.idx2token[i], tokens))
70
+
71
+ def encode(self, src: str):
72
+ return self.encodeBytes(src.encode("utf-8"))
73
+
74
+ def decode(self, tokens):
75
+ return self.decodeBytes(tokens).decode('utf-8')
76
+
77
+ def printTokens(self, tokens):
78
+ for i in tokens:
79
+ s = self.idx2token[i]
80
+ try:
81
+ s = s.decode('utf-8')
82
+ except:
83
+ pass
84
+ print(f'{repr(s)}{i}', end=' ')
85
+ # print(repr(s), i)
86
+ print()
87
+
88
+
89
+ ########################################################################################################
90
+
91
+ def sample_logits(out, temperature=1.0, top_p=0.8):
92
+ probs = F.softmax(out, dim=-1).numpy()
93
+ sorted_probs = np.sort(probs)[::-1]
94
+ cumulative_probs = np.cumsum(sorted_probs)
95
+ cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
96
+ probs[probs < cutoff] = 0
97
+ if temperature != 1.0:
98
+ probs = probs.pow(1.0 / temperature)
99
+ probs = probs / np.sum(probs)
100
+ out = np.random.choice(a=len(probs), p=probs)
101
+ return out
102
+
103
+
104
+ ########################################################################################################
105
+ class RWKV_RNN(MyModule):
106
+ def __init__(self, args):
107
+ super().__init__()
108
+ self.args = args
109
+ self.eval() # set torch to inference mode
110
+
111
+ w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
112
+
113
+ for k in w.keys():
114
+ w[k] = w[k].float() # convert to f32 type
115
+ if '.time_' in k: w[k] = w[k].squeeze()
116
+ if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
117
+
118
+ self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
119
+ self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
120
+
121
+ self.w = types.SimpleNamespace() # set self.w from w
122
+ self.w.blocks = {}
123
+ for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
124
+ parts = k.split('.')
125
+ last = parts.pop()
126
+ here = self.w
127
+ for p in parts:
128
+ if p.isdigit():
129
+ p = int(p)
130
+ if p not in here: here[p] = types.SimpleNamespace()
131
+ here = here[p]
132
+ else:
133
+ if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
134
+ here = getattr(here, p)
135
+ setattr(here, last, w[k])
136
+
137
+ def layer_norm(self, x, w):
138
+ return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
139
+
140
+ @MyFunction
141
+ def channel_mixing(self, x, state, i: int, time_maa_k, time_maa_r, kw, vw, rw):
142
+ i0 = (2 + self.head_size) * i + 0
143
+ sx = state[i0] - x
144
+ xk = x + sx * time_maa_k
145
+ xr = x + sx * time_maa_r
146
+ state[i0] = x
147
+ r = torch.sigmoid(rw @ xr)
148
+ k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
149
+ return r * (vw @ k)
150
+
151
+ @MyFunction
152
+ def time_mixing(self, x, state, i: int, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2,
153
+ time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
154
+ H = self.n_head
155
+ S = self.head_size
156
+
157
+ i1 = (2 + S) * i + 1
158
+ sx = state[i1] - x
159
+ state[i1] = x
160
+ xxx = x + sx * x_maa
161
+ xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)
162
+ xxx = torch.bmm(xxx, tm_w2).view(5, -1)
163
+ mw, mk, mv, mr, mg = xxx.unbind(dim=0)
164
+
165
+ xw = x + sx * (w_maa + mw)
166
+ xk = x + sx * (k_maa + mk)
167
+ xv = x + sx * (v_maa + mv)
168
+ xr = x + sx * (r_maa + mr)
169
+ xg = x + sx * (g_maa + mg)
170
+
171
+ w = (time_decay + (torch.tanh(xw @ td_w1) @ td_w2).float()).view(H, S, 1)
172
+ w = torch.exp(-torch.exp(w.float()))
173
+
174
+ r = (rw @ xr).view(H, 1, S)
175
+ k = (kw @ xk).view(H, S, 1)
176
+ v = (vw @ xv).view(H, 1, S)
177
+ g = F.silu(gw @ xg)
178
+
179
+ s = state[(2 + S) * i + 2:(2 + S) * (i + 1), :].reshape(H, S, S)
180
+
181
+ x = torch.zeros(H, S)
182
+ a = k @ v
183
+ x = r @ (time_first * a + s)
184
+ s = a + w * s
185
+
186
+ state[(2 + S) * i + 2:(2 + S) * (i + 1), :] = s.reshape(S, -1)
187
+ x = x.flatten()
188
+
189
+ x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps=64e-5).squeeze(
190
+ 0) * g # same as gn(x/8, eps=1e-5)
191
+ return ow @ x
192
+
193
+ def forward(self, token, state):
194
+ with torch.no_grad():
195
+ if state == None:
196
+ state = torch.zeros(self.args.n_layer * (2 + self.head_size), self.args.n_embd)
197
+
198
+ x = self.w.emb.weight[token]
199
+ x = self.layer_norm(x, self.w.blocks[0].ln0)
200
+ for i in range(self.args.n_layer):
201
+ att = self.w.blocks[i].att
202
+ x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
203
+ att.time_maa_x, att.time_maa_w, att.time_maa_k, att.time_maa_v, att.time_maa_r,
204
+ att.time_maa_g, att.time_maa_w1, att.time_maa_w2,
205
+ att.time_decay_w1, att.time_decay_w2, att.time_faaaa, att.time_decay,
206
+ att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight,
207
+ att.output.weight,
208
+ att.ln_x.weight, att.ln_x.bias)
209
+ ffn = self.w.blocks[i].ffn
210
+ x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
211
+ ffn.time_maa_k, ffn.time_maa_r,
212
+ ffn.key.weight, ffn.value.weight, ffn.receptance.weight)
213
+
214
+ x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
215
+ return x.float(), state
216
+
217
+
218
+ tokenizer = RWKV_TOKENIZER("vocab_v20230424.txt")
219
+
220
+ args = types.SimpleNamespace()
221
+ args.MODEL_NAME = 'rwkv-30'
222
+ args.n_layer = 12
223
+ args.n_embd = 768
224
+ args.vocab_size = 65536
225
+
226
+ context = "Today is a beautiful"
227
+ NUM_TRIALS = 3
228
+ LENGTH_PER_TRIAL = 50
229
+ TEMPERATURE = 1.0
230
+ TOP_P = 0.7
231
+
232
+ print(f'model= {args.MODEL_NAME}\n')
233
+ model = RWKV_RNN(args)
234
+ init_state = None
235
+ for token in tokenizer.encode(context):
236
+ init_out, init_state = model.forward(token, init_state)
237
+
238
+ for TRIAL in range(NUM_TRIALS):
239
+ print(f'Trial {TRIAL + 1}=', context, end="")
240
+ all_tokens = []
241
+ n_sampled = 0
242
+ out, state = init_out.clone(), init_state.clone()
243
+ for i in range(LENGTH_PER_TRIAL):
244
+ token = sample_logits(out, TEMPERATURE, TOP_P)
245
+ all_tokens += [token]
246
+ try:
247
+ tmp = tokenizer.decode(all_tokens[n_sampled:])
248
+ if '\ufffd' not in tmp: # only print when we have a valid utf-8 string
249
+ print(tmp, end="", flush=True)
250
+ n_sampled = i + 1
251
+ except:
252
+ pass
253
+ out, state = model.forward(token, state)
254
+ print("\nSampled tokens=", n_sampled, "out of", LENGTH_PER_TRIAL, "tokens\n")
vocab_v20230424.txt ADDED
The diff for this file is too large to render. See raw diff