File size: 6,789 Bytes
cc5f321
7aa5a5e
058c80a
3c36ff5
 
6502654
cc5f321
3c36ff5
d08fbc6
a4795aa
d08fbc6
058c80a
3c36ff5
 
 
 
cc5f321
3c36ff5
a4795aa
 
3c36ff5
 
 
 
 
d08fbc6
3c36ff5
d08fbc6
 
 
 
 
3c36ff5
 
d08fbc6
058c80a
 
 
 
 
 
d08fbc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
058c80a
 
cc5f321
 
 
058c80a
 
 
 
 
 
 
d08fbc6
058c80a
 
cc5f321
058c80a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d08fbc6
058c80a
cc5f321
 
 
d08fbc6
058c80a
 
3c36ff5
 
7aa5a5e
 
058c80a
d08fbc6
058c80a
 
7aa5a5e
d08fbc6
 
7aa5a5e
 
d08fbc6
7aa5a5e
 
 
d08fbc6
7aa5a5e
 
 
058c80a
 
d08fbc6
 
cc5f321
d08fbc6
cc5f321
 
 
d08fbc6
 
 
058c80a
cc5f321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d08fbc6
 
cc5f321
 
d08fbc6
cc5f321
 
 
d08fbc6
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import json
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

from .artifact import fetch_artifact
from .dataset_utils import get_dataset_artifact
from .inference import InferenceEngine, LogProbInferenceEngine
from .logging_utils import get_logger
from .metric_utils import _compute, _inference_post_process
from .operator import SourceOperator
from .schema import UNITXT_DATASET_SCHEMA
from .standard import StandardRecipe

logger = get_logger()


def load(source: Union[SourceOperator, str]):
    """Materialize a source operator into a dataset.

    Args:
        source: Either a ``SourceOperator`` instance, or a string that is
            resolved to one through ``fetch_artifact``.

    Returns:
        The result of ``source().to_dataset()``.

    Raises:
        AssertionError: if ``source`` is neither a ``SourceOperator`` nor a str.
    """
    assert isinstance(
        source, (SourceOperator, str)
    ), "source must be a SourceOperator or a string"

    if isinstance(source, str):
        # fetch_artifact returns a pair; only the artifact itself is needed.
        operator, _ = fetch_artifact(source)
    else:
        operator = source

    stream = operator()
    return stream.to_dataset()


def _get_recipe_from_query(dataset_query: str) -> StandardRecipe:
    """Resolve a query string into a recipe.

    The query is first looked up as a catalog artifact via ``fetch_artifact``.
    If that lookup raises, the query is handed to ``get_dataset_artifact``
    instead.

    Args:
        dataset_query: A catalog artifact name, or a ``key=value,...`` query
            such as ``"card=cards.wnli,template=..."``.

    Returns:
        The resolved recipe / dataset stream.
    """
    # Legacy alias: "sys_prompt" is rewritten to "instruction".
    # NOTE(review): this is a plain substring replace, so it also rewrites any
    # occurrence inside values — confirm that is intended.
    dataset_query = dataset_query.replace("sys_prompt", "instruction")
    try:
        dataset_stream, _ = fetch_artifact(dataset_query)
    except Exception:
        # Not resolvable as a single catalog artifact: fall back to treating
        # the string as a dataset query. Catching Exception (rather than a
        # bare except) lets KeyboardInterrupt / SystemExit propagate.
        dataset_stream = get_dataset_artifact(dataset_query)
    return dataset_stream


def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> StandardRecipe:
    """Build a ``StandardRecipe`` from explicit keyword parameters.

    Every key of ``dataset_params`` must be one of the field names listed in
    ``StandardRecipe.__fields__``. The first unknown key fails an assertion
    whose message lists the accepted names.

    Args:
        dataset_params: Mapping of recipe field names to values.

    Returns:
        ``StandardRecipe(**dataset_params)``.
    """
    field_names = list(StandardRecipe.__dict__["__fields__"].keys())
    for name in dataset_params:
        assert name in field_names, (
            f"The parameter '{name}' is not an attribute of the 'StandardRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{field_names}'."
        )
    recipe = StandardRecipe(**dataset_params)
    return recipe


def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:
        raise ValueError(
            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
            "If you want to load dataset from a card in local catalog, use query only. "
            "Otherwise, use key-worded arguments only to specify properties of dataset."
        )

    if dataset_query:
        if not isinstance(dataset_query, str):
            raise ValueError(
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )

    if not dataset_query and not dataset_args:
        raise ValueError(
            "Either 'dataset_query' or key-worded arguments must be provided."
        )


def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> StandardRecipe:
    """Obtain a recipe from a ready instance, a catalog query, or kwargs.

    Args:
        dataset_query: A catalog query string. A ``StandardRecipe`` instance
            is also accepted here and returned unchanged.
        **kwargs: Explicit ``StandardRecipe`` fields, used when no query is given.

    Returns:
        The resolved or constructed recipe.

    Raises:
        ValueError: from ``_verify_dataset_args`` when both or neither sources
            are given, or the query is not a string.
    """
    if isinstance(dataset_query, StandardRecipe):
        return dataset_query

    _verify_dataset_args(dataset_query, kwargs)

    # After validation exactly one of the two sources is truthy.
    if kwargs:
        return _get_recipe_from_dict(kwargs)
    return _get_recipe_from_query(dataset_query)


def load_dataset(
    dataset_query: Optional[str] = None, streaming: bool = False, **kwargs
):
    """Load a dataset, either from a catalog query or from explicit arguments.

    When 'dataset_query' is given, the recipe is resolved from the local
    catalog using the parameters in the query, or by the name of a recipe or
    benchmark in the catalog. Otherwise, the recipe is built from the keyword
    arguments, e.g. a card that is not present in the local catalog.

    Args:
        dataset_query (str, optional): Catalog query, or the name of a recipe
            or benchmark in the catalog. For example:
            "card=cards.wnli,template=templates.classification.multi_class.relation.default".
        streaming (bool, optional): When True, return the Unitxt streams
            dictionary produced by the recipe instead of a dataset.
            Defaults to False.
        **kwargs: Arguments used to build the recipe from a card that is not
            present in the local catalog.

    Returns:
        DatasetDict with ``UNITXT_DATASET_SCHEMA`` features, or the Unitxt
        streams dictionary when ``streaming`` is True.

    Examples:
        dataset = load_dataset(
            dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
        )  # card must be present in local catalog

        card = TaskCard(...)
        template = Template(...)
        loader_limit = 10
        dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)
    """
    recipe = load_recipe(dataset_query, **kwargs)
    streams = recipe()
    if streaming:
        return streams
    return streams.to_dataset(features=UNITXT_DATASET_SCHEMA)


def evaluate(predictions, data) -> List[Dict[str, Any]]:
    """Score ``predictions`` against ``data``.

    Thin wrapper over ``metric_utils._compute``; ``data`` is passed as its
    ``references`` argument.
    """
    scored = _compute(predictions=predictions, references=data)
    return scored


def post_process(predictions, data) -> List[Dict[str, Any]]:
    """Apply the inference post-processing to raw ``predictions``.

    Thin wrapper over ``metric_utils._inference_post_process``; ``data`` is
    passed as its ``references`` argument.
    """
    processed = _inference_post_process(predictions=predictions, references=data)
    return processed


@lru_cache(maxsize=128)
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
    """Return the ``produce`` method of the recipe for these arguments, memoized.

    ``lru_cache`` keys on all arguments, so they must be hashable. Up to 128
    distinct argument combinations are cached (the ``lru_cache`` default).
    """
    recipe = load_recipe(dataset_query, **kwargs)
    return recipe.produce


def produce(instance_or_instances, dataset_query: Optional[str] = None, **kwargs):
    """Run the (cached) recipe's ``produce`` over one or many instances.

    A list input yields the list result. Any other input is wrapped in a
    one-element list, and the single produced item is returned.
    """
    produce_fn = _get_produce_with_cache(dataset_query, **kwargs)
    if isinstance(instance_or_instances, list):
        return produce_fn(instance_or_instances)
    return produce_fn([instance_or_instances])[0]


def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    **kwargs,
):
    """Run an inference engine over instances prepared by a recipe.

    Args:
        instance_or_instances: A single instance or a list of instances,
            forwarded to ``produce``.
        engine: An ``InferenceEngine``, or anything ``fetch_artifact`` resolves
            to one.
        dataset_query: Recipe query forwarded to ``produce``.
        return_data: When True, return the produced instances annotated with
            ``prediction`` and ``raw_prediction`` (plus ``infer_meta_data`` when
            ``return_meta_data`` is set) instead of the bare predictions.
        return_log_probs: Use ``engine.infer_log_probs``. Each raw prediction
            is JSON-serialized before post-processing.
        return_meta_data: Request engine outputs carrying metadata. Each
            output's ``prediction`` attribute is used as the raw prediction.
        **kwargs: Recipe arguments forwarded to ``produce``.

    Returns:
        The post-processed predictions, or the annotated instances when
        ``return_data`` is True.

    Raises:
        NotImplementedError: if ``return_log_probs`` is True but the engine is
            not a ``LogProbInferenceEngine``.
    """
    # NOTE(review): for a single (non-list) instance, produce() returns one
    # instance rather than a list, and that is what reaches the engine —
    # confirm engines accept it.
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    engine, _ = fetch_artifact(engine)

    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):
            raise NotImplementedError(
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)

    # With metadata, the engine returns output objects; otherwise the outputs
    # are the raw predictions themselves.
    raw_predictions = (
        [output.prediction for output in infer_outputs]
        if return_meta_data
        else infer_outputs
    )
    if return_log_probs:
        # Serialize each log-prob result to a JSON string before post-processing.
        raw_predictions = [
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]

    predictions = post_process(raw_predictions, dataset)
    if not return_data:
        return predictions

    for prediction, raw_prediction, instance, infer_output in zip(
        predictions, raw_predictions, dataset, infer_outputs
    ):
        if return_meta_data:
            # Build a filtered copy rather than aliasing infer_output.__dict__:
            # deleting "prediction" from that alias would strip the attribute
            # from the engine's output object itself.
            instance["infer_meta_data"] = {
                key: value
                for key, value in infer_output.__dict__.items()
                if key != "prediction"
            }
        instance["prediction"] = prediction
        instance["raw_prediction"] = raw_prediction
    return dataset