Elron commited on
Commit
88a9416
1 Parent(s): 2109a58

Upload templates.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. templates.py +69 -42
templates.py CHANGED
@@ -1,6 +1,5 @@
1
  import json
2
  from abc import abstractmethod
3
- from dataclasses import field
4
  from typing import Any, Dict, List, Optional, Tuple
5
 
6
  from .collections import ListCollection
@@ -14,12 +13,21 @@ class Template(StreamInstanceOperator):
14
  """The role of template is to take the fields of every instance and verbalize it.
15
 
16
  Meaning the template is taking the instance and generating source, target and references.
 
 
 
 
 
 
 
17
  """
18
 
19
  skip_rendered_instance: bool = NonPositionalField(default=True)
20
  postprocessors: List[str] = NonPositionalField(
21
  default_factory=lambda: ["processors.to_string_stripped"]
22
  )
 
 
23
 
24
  def process(
25
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
@@ -35,7 +43,7 @@ class Template(StreamInstanceOperator):
35
  inputs = instance.get("inputs")
36
  outputs = instance.get("outputs")
37
 
38
- source = self.inputs_to_source(inputs)
39
  target, references = self.outputs_to_target_and_references(outputs)
40
 
41
  return {
@@ -43,10 +51,12 @@ class Template(StreamInstanceOperator):
43
  "source": source,
44
  "target": target,
45
  "references": references,
 
 
46
  }
47
 
48
  @abstractmethod
49
- def inputs_to_source(self, inputs: Dict[str, object]) -> str:
50
  pass
51
 
52
  @abstractmethod
@@ -72,13 +82,17 @@ class InputOutputTemplate(Template):
72
  data = {k: ", ".join(v) if isinstance(v, list) else v for k, v in data.items()}
73
  return template.format(**data)
74
 
75
- def inputs_to_source(self, inputs: Dict[str, object]) -> str:
76
- try:
77
- return self.process_template(self.input_format, inputs)
78
- except KeyError as e:
79
- raise KeyError(
80
- f"Available inputs are {list(inputs.keys())} but input format requires a different ones: '{self.input_format}'"
81
- ) from e
 
 
 
 
82
 
83
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
84
  try:
@@ -92,6 +106,25 @@ class InputOutputTemplate(Template):
92
  return target, references
93
 
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  class MultipleChoiceTemplate(Template):
96
  """Formats the input (that specifies the question), the multiple choices to select the answer from, and specifies the field with the correct answer."""
97
 
@@ -149,19 +182,22 @@ class MultipleChoiceTemplate(Template):
149
  )
150
  return enumrated_choices
151
 
152
- def inputs_to_source(self, inputs: Dict[str, object]) -> str:
153
  choices = self.get_choices(inputs, self.source_choice_format)
154
  inputs = {
155
  "numerals": ",".join(self.get_choices(inputs, "{choice_numeral}")),
156
  **inputs,
157
  self.choices_field: self.choices_seperator.join(choices),
158
  }
159
- try:
160
- return self.input_format.format(**inputs)
161
- except KeyError as e:
162
- raise KeyError(
163
- f"Available inputs are {inputs.keys()} but input format requires a different one: {self.input_format}"
164
- ) from e
 
 
 
165
 
166
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
167
  target = outputs[self.target_field]
@@ -221,20 +257,20 @@ class YesNoTemplate(Template):
221
  label_field: str = None
222
  yes_answer: str = "Yes"
223
  no_answer: str = "No"
224
- postprocessors: List[str] = field(
225
- default_factory=lambda: ["processors.to_string_stripped"]
226
- )
227
 
228
- def inputs_to_source(self, inputs: Dict[str, object]) -> str:
229
- try:
230
- data = {
231
- k: ", ".join(v) if isinstance(v, list) else v for k, v in inputs.items()
232
- }
233
- return self.input_format.format(**data)
234
- except KeyError as e:
235
- raise RuntimeError(
236
- f"Available inputs are {list(inputs.keys())} but input format requires a different one: {self.input_format}"
237
- ) from e
 
 
 
238
 
239
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
240
  try:
@@ -266,9 +302,6 @@ class YesNoTemplate(Template):
266
  return self.yes_answer, [self.yes_answer]
267
  return self.no_answer, [self.no_answer]
268
 
269
- def get_postprocessors(self) -> List[str]:
270
- return self.postprocessors
271
-
272
 
273
  class KeyValTemplate(Template):
274
  """Generate field 'source' from fields designated as input, and fields 'target' and 'references' from fields designated as output, of the processed instance.
@@ -282,10 +315,6 @@ class KeyValTemplate(Template):
282
  outputs_key_val_seperator: str = ": "
283
  use_keys_for_outputs: bool = False
284
 
285
- postprocessors: List[str] = field(
286
- default_factory=lambda: ["processors.to_string_stripped"]
287
- )
288
-
289
  def process_dict(
290
  self, dic: Dict[str, object], key_val_sep, pairs_sep, use_keys
291
  ) -> str:
@@ -299,13 +328,14 @@ class KeyValTemplate(Template):
299
  pairs.append(key_val_sep.join(key_val))
300
  return pairs_sep.join(pairs)
301
 
302
- def inputs_to_source(self, inputs: Dict[str, object]) -> str:
303
- return self.process_dict(
304
  inputs,
305
  key_val_sep=self.key_val_seperator,
306
  pairs_sep=self.pairs_seperator,
307
  use_keys=self.use_keys_for_inputs,
308
  )
 
309
 
310
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
311
  target = self.process_dict(
@@ -316,9 +346,6 @@ class KeyValTemplate(Template):
316
  )
317
  return target, [target]
318
 
319
- def get_postprocessors(self) -> List[str]:
320
- return self.postprocessors
321
-
322
 
323
  class OutputQuantizingTemplate(InputOutputTemplate):
324
  quantum: float = 0.1
 
1
  import json
2
  from abc import abstractmethod
 
3
  from typing import Any, Dict, List, Optional, Tuple
4
 
5
  from .collections import ListCollection
 
13
  """The role of template is to take the fields of every instance and verbalize it.
14
 
15
  Meaning the template is taking the instance and generating source, target and references.
16
+
17
+ Args:
18
+ skip_rendered_instance (bool): if "source", "target", and "references" are already defined fields in the instance, skip its processing
19
+ postprocessors: a list of strings being artifact names of text processors, to be applied on the model output
20
+ instruction: a formatting string that yields an instruction with potential participation of values from the "inputs" part of the instance
21
+ target_prefix: a string to be used to format the prompt. Not a formatting string.
22
+
23
  """
24
 
25
  skip_rendered_instance: bool = NonPositionalField(default=True)
26
  postprocessors: List[str] = NonPositionalField(
27
  default_factory=lambda: ["processors.to_string_stripped"]
28
  )
29
+ instruction: str = NonPositionalField(default_factory=lambda: "")
30
+ target_prefix: str = NonPositionalField(default_factory=lambda: "")
31
 
32
  def process(
33
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
 
43
  inputs = instance.get("inputs")
44
  outputs = instance.get("outputs")
45
 
46
+ source, instruction = self.inputs_to_source(inputs)
47
  target, references = self.outputs_to_target_and_references(outputs)
48
 
49
  return {
 
51
  "source": source,
52
  "target": target,
53
  "references": references,
54
+ "instruction": instruction,
55
+ "target_prefix": self.target_prefix.format(**inputs),
56
  }
57
 
58
  @abstractmethod
59
+ def inputs_to_source(self, inputs: Dict[str, object]) -> Tuple[str, str]:
60
  pass
61
 
62
  @abstractmethod
 
82
  data = {k: ", ".join(v) if isinstance(v, list) else v for k, v in data.items()}
83
  return template.format(**data)
84
 
85
+ def inputs_to_source(self, inputs: Dict[str, object]) -> Tuple[str, str]:
86
+ formatted = []
87
+ for formatting in [self.input_format, self.instruction]:
88
+ try:
89
+ formatted.append(self.process_template(formatting, inputs))
90
+ except KeyError as e:
91
+ raise KeyError(
92
+ f"Available inputs are {list(inputs.keys())} but input format requires a different ones: '{formatting}'"
93
+ ) from e
94
+
95
+ return tuple(formatted)
96
 
97
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
98
  try:
 
106
  return target, references
107
 
108
 
109
+ class InputOutputReferenceTemplate(InputOutputTemplate):
110
+ reference: str
111
+
112
+ def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
113
+ output_fields = {}
114
+ for name, val in [
115
+ ("target", self.output_format),
116
+ ("reference", self.reference),
117
+ ]:
118
+ try:
119
+ result = self.process_template(val, outputs)
120
+ output_fields[name] = result
121
+ except KeyError as e:
122
+ raise KeyError(
123
+ f"Available outputs are {outputs.keys()} but {name} requires a different one: {val}"
124
+ ) from e
125
+ return output_fields["target"], [output_fields["reference"]]
126
+
127
+
128
  class MultipleChoiceTemplate(Template):
129
  """Formats the input (that specifies the question), the multiple choices to select the answer from, and specifies the field with the correct answer."""
130
 
 
182
  )
183
  return enumrated_choices
184
 
185
+ def inputs_to_source(self, inputs: Dict[str, object]) -> Tuple[str, str]:
186
  choices = self.get_choices(inputs, self.source_choice_format)
187
  inputs = {
188
  "numerals": ",".join(self.get_choices(inputs, "{choice_numeral}")),
189
  **inputs,
190
  self.choices_field: self.choices_seperator.join(choices),
191
  }
192
+ formatted = []
193
+ for formatting in [self.input_format, self.instruction]:
194
+ try:
195
+ formatted.append(formatting.format(**inputs))
196
+ except KeyError as e:
197
+ raise KeyError(
198
+ f"Available inputs are {inputs.keys()} but input format requires a different one: {formatting}"
199
+ ) from e
200
+ return tuple(formatted)
201
 
202
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
203
  target = outputs[self.target_field]
 
257
  label_field: str = None
258
  yes_answer: str = "Yes"
259
  no_answer: str = "No"
 
 
 
260
 
261
+ def inputs_to_source(self, inputs: Dict[str, object]) -> Tuple[str, str]:
262
+ data = {
263
+ k: ", ".join(v) if isinstance(v, list) else v for k, v in inputs.items()
264
+ }
265
+ formatted = []
266
+ for formatting in [self.input_format, self.instruction]:
267
+ try:
268
+ formatted.append(formatting.format(**data))
269
+ except KeyError as e:
270
+ raise RuntimeError(
271
+ f"Available inputs are {list(inputs.keys())} but input format requires a different one: {formatting}"
272
+ ) from e
273
+ return tuple(formatted)
274
 
275
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
276
  try:
 
302
  return self.yes_answer, [self.yes_answer]
303
  return self.no_answer, [self.no_answer]
304
 
 
 
 
305
 
306
  class KeyValTemplate(Template):
307
  """Generate field 'source' from fields designated as input, and fields 'target' and 'references' from fields designated as output, of the processed instance.
 
315
  outputs_key_val_seperator: str = ": "
316
  use_keys_for_outputs: bool = False
317
 
 
 
 
 
318
  def process_dict(
319
  self, dic: Dict[str, object], key_val_sep, pairs_sep, use_keys
320
  ) -> str:
 
328
  pairs.append(key_val_sep.join(key_val))
329
  return pairs_sep.join(pairs)
330
 
331
+ def inputs_to_source(self, inputs: Dict[str, object]) -> Tuple[str, str]:
332
+ ret = self.process_dict(
333
  inputs,
334
  key_val_sep=self.key_val_seperator,
335
  pairs_sep=self.pairs_seperator,
336
  use_keys=self.use_keys_for_inputs,
337
  )
338
+ return (ret, ret)
339
 
340
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
341
  target = self.process_dict(
 
346
  )
347
  return target, [target]
348
 
 
 
 
349
 
350
  class OutputQuantizingTemplate(InputOutputTemplate):
351
  quantum: float = 0.1