metric / fusion.py
Elron's picture
Upload fusion.py with huggingface_hub
f747a71 verified
raw
history blame
3.61 kB
import copy
import itertools
from abc import abstractmethod
from typing import Generator, List, Optional

from .dataclass import NonPositionalField
from .operator import SourceOperator
from .random_utils import new_random_generator
from .stream import MultiStream, Stream
class BaseFusion(SourceOperator):
    """BaseFusion operator that combines multiple streams into one.

    Args:
        origins: List of SourceOperator objects. Each one is called to obtain a
            mapping from split name to stream.
        include_splits: List of splits to include. If None, all splits are included.
    """

    origins: List[SourceOperator]
    include_splits: Optional[List[str]] = NonPositionalField(default=None)

    @abstractmethod
    def fusion_generator(self, split) -> Generator:
        """Yield the fused instances of ``split``; implemented by subclasses."""
        pass

    def splits(self) -> List[str]:
        """Return the split names provided by any origin, in first-seen order.

        Each origin is called once. A split is kept only if ``include_splits`` is
        None or contains it. The result is the union over all origins, so a given
        origin may lack some of the returned splits.
        """
        splits = []
        for origin in self.origins:
            for s in origin().keys():
                is_new = s not in splits
                is_included = self.include_splits is None or s in self.include_splits
                if is_new and is_included:
                    splits.append(s)
        return splits

    def process(self) -> MultiStream:
        """Build a MultiStream holding one Stream per split.

        Each Stream wraps ``fusion_generator``, which is not invoked here. The
        split name is passed to it through ``gen_kwargs``.
        """
        return MultiStream(
            {
                split: Stream(self.fusion_generator, gen_kwargs={"split": split})
                for split in self.splits()
            }
        )
class FixedFusion(BaseFusion):
    """FixedFusion operator that concatenates multiple streams, taking a bounded number of instances from each origin.

    For every split, origins are visited in order. Each origin contributes its
    instances of that split, up to ``max_instances_per_origin`` of them. Origins
    that do not provide the split are skipped.

    Args:
        origins: List of SourceOperator objects.
        max_instances_per_origin: Maximum number of instances taken from each origin
            per split. Must be non-negative. If None, all instances are taken.
        include_splits: List of splits to include. If None, all splits are included.
    """

    max_instances_per_origin: Optional[int] = None

    def verify(self):
        super().verify()
        assert (
            self.max_instances_per_origin is None or self.max_instances_per_origin >= 0
        ), "max_instances_per_origin must be None or a non-negative integer"

    def fusion_generator(self, split) -> Generator:
        for origin in self.origins:
            multi_stream = origin()
            # The fused split list is the union over all origins, so this origin
            # may not provide the requested split. Skip it instead of raising KeyError.
            if split not in multi_stream:
                continue
            # islice with stop=None yields everything. Otherwise it pulls at most
            # max_instances_per_origin items and never consumes an extra one.
            yield from itertools.islice(
                multi_stream[split], self.max_instances_per_origin
            )
class WeightedFusion(BaseFusion):
    """Fusion operator that combines multiple streams by weighted random sampling.

    For each split, every next instance is drawn from an origin chosen at random,
    with probability proportional to its weight. The random generator is seeded
    per split. An origin is dropped once its stream is exhausted. Generation stops
    when no origins remain or when ``max_total_examples`` instances have been
    yielded.

    Args:
        origins: List of SourceOperator objects.
        weights: List of non-negative weights, one per origin. Origins with weight 0
            are never sampled.
        max_total_examples: Maximum number of examples returned per split. If None,
            all examples are returned.
    """

    origins: Optional[List[SourceOperator]] = None
    weights: Optional[List[float]] = None
    max_total_examples: Optional[int] = None

    def verify(self):
        super().verify()
        assert self.origins is not None, "origins must be specified"
        assert self.weights is not None, "weights must be specified"
        assert len(self.origins) == len(
            self.weights
        ), "origins and weights must have the same length"
        assert all(
            weight >= 0 for weight in self.weights
        ), "weights must be non-negative"

    def fusion_generator(self, split) -> Generator:
        iterators = []
        weights = []
        for origin, weight in zip(self.origins, self.weights):
            # A zero weight can never be chosen by random.choices. Dropping such
            # origins up front lets generation end cleanly, rather than raising
            # ValueError once only zero-weight origins remain.
            if weight <= 0:
                continue
            multi_stream = origin()
            # The fused split list is the union over all origins, so this origin
            # may not provide the requested split.
            if split not in multi_stream:
                continue
            iterators.append(iter(multi_stream[split]))
            weights.append(weight)

        random_generator = new_random_generator(sub_seed="weighted_fusion_" + split)
        total_examples = 0
        # Strict "<" so that at most max_total_examples instances are yielded.
        while iterators and (
            self.max_total_examples is None or total_examples < self.max_total_examples
        ):
            # Choose a position rather than the iterator object itself, so
            # removal needs no O(n) list.index lookup.
            index = random_generator.choices(range(len(iterators)), weights=weights)[0]
            try:
                instance = next(iterators[index])
            except StopIteration:
                iterators.pop(index)
                weights.pop(index)
                continue
            yield instance
            total_examples += 1