scampion committed
Commit c480b17
Parent: 7b5fef3

Upload 4 files

Files changed (4)
  1. eurovoc.py +212 -0
  2. handler.py +75 -0
  3. mlb.pickle +3 -0
  4. requirements.txt +8 -0
eurovoc.py ADDED
@@ -0,0 +1,212 @@
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ import numpy as np
+ import pytorch_lightning as pl
+ import torch.nn as nn
+ from transformers import BertTokenizerFast as BertTokenizer, AutoTokenizer, AutoModel
+ from huggingface_hub import PyTorchModelHubMixin
+
+
+ class EurovocDataset(Dataset):
+
+     def __init__(
+         self,
+         text: np.array,
+         labels: np.array,
+         tokenizer: BertTokenizer,
+         max_token_len: int = 128
+     ):
+         self.tokenizer = tokenizer
+         self.text = text
+         self.labels = labels
+         self.max_token_len = max_token_len
+
+     def __len__(self):
+         return len(self.labels)
+
+     def __getitem__(self, index: int):
+         text = self.text[index][0]
+         labels = self.labels[index]
+
+         encoding = self.tokenizer.encode_plus(
+             text,
+             add_special_tokens=True,
+             max_length=self.max_token_len,
+             return_token_type_ids=False,
+             padding="max_length",
+             truncation=True,
+             return_attention_mask=True,
+             return_tensors='pt',
+         )
+
+         return dict(
+             text=text,
+             input_ids=encoding["input_ids"].flatten(),
+             attention_mask=encoding["attention_mask"].flatten(),
+             labels=torch.FloatTensor(labels)
+         )
+
+
+ class EuroVocLongTextDataset(Dataset):
+
+     @staticmethod
+     def __splitter__(text, max_length):
+         # Yield successive chunks of at most max_length words, re-joined into strings.
+         words = text.split()
+         for i in range(0, len(words), max_length):
+             yield " ".join(words[i:i + max_length])
+
+     def __init__(
+         self,
+         text: np.array,
+         labels: np.array,
+         tokenizer: BertTokenizer,
+         max_token_len: int = 128
+     ):
+         self.tokenizer = tokenizer
+         self.text = text
+         self.labels = labels
+         self.max_token_len = max_token_len
+
+         # Every chunk of a document inherits that document's label vector.
+         self.chunks_and_labels = [(c, l) for t, l in zip(self.text, self.labels)
+                                   for c in self.__splitter__(t, self.max_token_len)]
+
+         self.encoding = self.tokenizer.batch_encode_plus(
+             [c for c, _ in self.chunks_and_labels],
+             add_special_tokens=True,
+             max_length=self.max_token_len,
+             return_token_type_ids=False,
+             padding="max_length",
+             truncation=True,
+             return_attention_mask=True,
+             return_tensors='pt',
+         )
+
+     def __len__(self):
+         return len(self.chunks_and_labels)
+
+     def __getitem__(self, index: int):
+         text, labels = self.chunks_and_labels[index]
+
+         return dict(
+             text=text,
+             # index into the batch tensors, not into the BatchEncoding itself
+             input_ids=self.encoding["input_ids"][index],
+             attention_mask=self.encoding["attention_mask"][index],
+             labels=torch.FloatTensor(labels)
+         )
+
+
+ class EurovocDataModule(pl.LightningDataModule):
+
+     def __init__(self, bert_model_name, x_tr, y_tr, x_test, y_test, batch_size=8, max_token_len=512):
+         super().__init__()
+
+         self.batch_size = batch_size
+         self.x_tr = x_tr
+         self.y_tr = y_tr
+         self.x_test = x_test
+         self.y_test = y_test
+         self.tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
+         self.max_token_len = max_token_len
+
+     def setup(self, stage=None):
+         self.train_dataset = EurovocDataset(
+             self.x_tr,
+             self.y_tr,
+             self.tokenizer,
+             self.max_token_len
+         )
+
+         self.test_dataset = EurovocDataset(
+             self.x_test,
+             self.y_test,
+             self.tokenizer,
+             self.max_token_len
+         )
+
+     def train_dataloader(self):
+         return DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             shuffle=True,
+             num_workers=2
+         )
+
+     def val_dataloader(self):
+         return DataLoader(
+             self.test_dataset,
+             batch_size=self.batch_size,
+             num_workers=2
+         )
+
+     def test_dataloader(self):
+         return DataLoader(
+             self.test_dataset,
+             batch_size=self.batch_size,
+             num_workers=2
+         )
+
+
+ class EurovocTagger(pl.LightningModule, PyTorchModelHubMixin):
+
+     def __init__(self, bert_model_name, n_classes, lr=2e-5, eps=1e-8):
+         super().__init__()
+         self.bert = AutoModel.from_pretrained(bert_model_name)
+         self.dropout = nn.Dropout(p=0.2)
+         self.classifier1 = nn.Linear(self.bert.config.hidden_size, n_classes)
+         self.criterion = nn.BCELoss()
+         self.lr = lr
+         self.eps = eps
+
+     def forward(self, input_ids, attention_mask, labels=None):
+         output = self.bert(input_ids, attention_mask=attention_mask)
+         output = self.dropout(output.pooler_output)
+         output = self.classifier1(output)
+         output = torch.sigmoid(output)
+         loss = 0
+         if labels is not None:
+             loss = self.criterion(output, labels)
+         return loss, output
+
+     def training_step(self, batch, batch_idx):
+         input_ids = batch["input_ids"]
+         attention_mask = batch["attention_mask"]
+         labels = batch["labels"]
+         loss, outputs = self(input_ids, attention_mask, labels)
+         self.log("train_loss", loss, prog_bar=True, logger=True)
+         return {"loss": loss, "predictions": outputs, "labels": labels}
+
+     def validation_step(self, batch, batch_idx):
+         input_ids = batch["input_ids"]
+         attention_mask = batch["attention_mask"]
+         labels = batch["labels"]
+         loss, outputs = self(input_ids, attention_mask, labels)
+         self.log("val_loss", loss, prog_bar=True, logger=True)
+         return loss
+
+     def test_step(self, batch, batch_idx):
+         input_ids = batch["input_ids"]
+         attention_mask = batch["attention_mask"]
+         labels = batch["labels"]
+         loss, outputs = self(input_ids, attention_mask, labels)
+         self.log("test_loss", loss, prog_bar=True, logger=True)
+         return loss
+
+     def on_train_epoch_end(self, *args, **kwargs):
+         # Per-class AUROC logging is disabled for now.
+         return
+         # labels = []
+         # predictions = []
+         # for output in args['outputs']:
+         #     for out_labels in output["labels"].detach().cpu():
+         #         labels.append(out_labels)
+         #     for out_predictions in output["predictions"].detach().cpu():
+         #         predictions.append(out_predictions)
+
+         # labels = torch.stack(labels).int()
+         # predictions = torch.stack(predictions)
+
+         # for i, name in enumerate(mlb.classes_):
+         #     class_roc_auc = auroc(predictions[:, i], labels[:, i])
+         #     self.logger.experiment.add_scalar(f"{name}_roc_auc/Train", class_roc_auc, self.current_epoch)
+
+     def configure_optimizers(self):
+         return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=self.eps)
+
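
For orientation, a minimal training sketch using the classes above. This is a hypothetical example, not part of the commit: the toy arrays, batch_size=2, and max_epochs=1 are placeholder assumptions, and the checkpoint name is the one handler.py uses below. Note that EurovocDataset reads self.text[index][0], so the text array is expected to have shape (n, 1).

    import numpy as np
    import pytorch_lightning as pl
    from eurovoc import EurovocDataModule, EurovocTagger

    # Placeholder data: 2-D text array of shape (n, 1) and multi-hot label rows.
    x_tr = np.array([["first training document"], ["second training document"]])
    y_tr = np.array([[1, 0, 0], [0, 1, 1]])
    x_test = np.array([["a held-out document"]])
    y_test = np.array([[0, 0, 1]])

    dm = EurovocDataModule("nlpaueb/legal-bert-base-uncased",
                           x_tr, y_tr, x_test, y_test,
                           batch_size=2, max_token_len=512)
    model = EurovocTagger("nlpaueb/legal-bert-base-uncased", n_classes=y_tr.shape[1])

    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model, dm)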
handler.py ADDED
@@ -0,0 +1,75 @@
+ from typing import Any, Dict
+ import numpy as np
+ import pickle
+
+ # scikit-learn must be importable so that mlb.pickle can be deserialized
+ from sklearn.preprocessing import MultiLabelBinarizer
+ from transformers import AutoTokenizer
+ import torch
+
+ from eurovoc import EurovocTagger
+
+ BERT_MODEL_NAME = "nlpaueb/legal-bert-base-uncased"
+ MAX_LEN = 512
+ TEXT_MAX_LEN = MAX_LEN * 50  # cap inference at 50 chunks per document
+ tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
+
+
+ class EndpointHandler:
+
+     def __init__(self, path=""):
+         with open(f"{path}/mlb.pickle", "rb") as f:
+             self.mlb = pickle.load(f)
+
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model = EurovocTagger.from_pretrained(path,
+                                                    bert_model_name=BERT_MODEL_NAME,
+                                                    n_classes=len(self.mlb.classes_),
+                                                    map_location=self.device)
+         self.model.eval()
+         self.model.freeze()
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         data args:
+             inputs (:obj:`str`): the document text to classify
+             topk (:obj:`int`, defaults to 5): maximum number of labels returned
+             threshold (:obj:`float`, defaults to 0.16): minimum score to keep a label
+             debug (:obj:`bool`, defaults to False): also return raw scores and the input
+         Return:
+             A :obj:`dict` with the top-scoring EuroVoc labels; serialized and returned
+         """
+         text = data.pop("inputs", data)
+         topk = data.pop("topk", 5)
+         threshold = data.pop("threshold", 0.16)
+         debug = data.pop("debug", False)
+         prediction = self.get_prediction(text)
+         results = [{"label": label, "score": float(score)} for label, score in
+                    zip(self.mlb.classes_, prediction[0].tolist())]
+         results = sorted(results, key=lambda x: x["score"], reverse=True)
+         results = [r for r in results if r["score"] > threshold]
+         results = results[:topk]
+         if debug:
+             return {"results": results, "values": prediction, "input": text}
+         else:
+             return {"results": results}
+
+     def get_prediction(self, text):
+         # Split the text into chunks of MAX_LEN characters (at most TEXT_MAX_LEN
+         # characters in total) and average the per-chunk predictions.
+         chunks = [text[i:i + MAX_LEN] for i in range(0, min(len(text), TEXT_MAX_LEN), MAX_LEN)]
+         predictions = [self._get_prediction(chunk) for chunk in chunks]
+         predictions = np.array(predictions).mean(axis=0)
+         return predictions
+
+     def _get_prediction(self, text):
+         item = tokenizer.encode_plus(
+             text,
+             add_special_tokens=True,
+             max_length=MAX_LEN,
+             return_token_type_ids=False,
+             padding="max_length",
+             truncation=True,
+             return_attention_mask=True,
+             return_tensors='pt')
+         _, prediction = self.model(item["input_ids"], item["attention_mask"])
+         prediction = prediction.cpu().detach().numpy()
+         return prediction
+
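
And a hypothetical local smoke test for the handler, again not part of the commit. It assumes the current directory contains the trained weights and mlb.pickle (the layout an Inference Endpoint provides via path); the payload keys match those read in __call__ above.

    from handler import EndpointHandler

    handler = EndpointHandler(path=".")
    payload = {"inputs": "Regulation on the protection of natural persons "
                         "with regard to the processing of personal data.",
               "topk": 5, "threshold": 0.16}
    print(handler(payload))
    # -> {"results": [{"label": "...", "score": 0.42}, ...]}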
mlb.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:35015ecbd09a8524d555feb303f81788fc8be9dd28ae2eae9f4e05f7417b1d71
+ size 122082
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ datasets==2.13.1
+ ipykernel==6.24.0
+ lightning==2.0.5
+ pip-chill==1.0.3
+ scikit-learn==1.3.0
+ scikit-multilearn==0.2.0
+ transformers==4.30.2
+