---
license: apache-2.0
datasets:
- wikitext
- ptb_text_only
language:
- en
metrics:
- perplexity
pipeline_tag: text-generation
model-index:
- name: distilgpt2
  results:
  - task:
      type: text-generation
    dataset:
      name: penn_treebank
      type: ptb_text_only
    metrics:
    - name: perplexity@BASELINE
      type: dmx-perlexity
      value: 63.45857238769531
    - name: perplexity@FALLBACK
      type: dmx-perlexity
      value: 64.36720275878906
  - task:
      type: text-generation
    dataset:
      name: wikitext2
      type: wikitext-2-raw-v1
    metrics:
    - name: perplexity@BASELINE
      type: dmx-perlexity
      value: 46.05925369262695
    - name: perplexity@FALLBACK
      type: dmx-perlexity
      value: 46.570838928222656
---
This is a quantized version of [DistilGPT2](https://huggingface.co/distilbert/distilgpt2). We provide the following two quantization configurations:
- `BASELINE`: All layers kept in their original format; equivalent to the original model.
- `FALLBACK`: Linear and Conv1D layers quantized to BFP16, with approximation functions used for LayerNorm, GELU, and Softmax.
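BFP16 is a block floating-point format: values are grouped into blocks that share one exponent, and each value keeps only a small signed mantissa. The sketch below "fake-quantizes" an array (quantize, then dequantize back to floats) to show the idea. It is not the dmx-mltools implementation; the block size of 16 and the 8-bit signed mantissa are assumptions chosen for illustration, and `bfp_quantize` is a hypothetical helper.

```python
import numpy as np


def bfp_quantize(x, block_size=16, mantissa_bits=8):
    """Fake-quantize a 1-D array to block floating point.

    Assumed (hypothetical) format: blocks of `block_size` values share one
    exponent; each value keeps a signed integer mantissa of `mantissa_bits` bits.
    """
    x = np.asarray(x, dtype=np.float64).ravel()
    n = x.size
    pad = (-n) % block_size  # zero-pad so the length is a multiple of block_size
    blocks = np.pad(x, (0, pad)).reshape(-1, block_size)

    # Shared exponent per block: frexp gives k with max|x| < 2**k.
    max_abs = np.abs(blocks).max(axis=1, keepdims=True)
    _, exp = np.frexp(max_abs)

    # Step size chosen so the largest element of each block fits the mantissa range.
    scale = np.ldexp(1.0, exp - (mantissa_bits - 1))
    qmax = 2 ** (mantissa_bits - 1) - 1
    mantissas = np.clip(np.round(blocks / scale), -qmax, qmax)

    # Dequantize and drop the padding.
    return (mantissas * scale).reshape(-1)[:n]


x = np.array([100.0, 0.001, -3.7, 0.25])
# Small entries lose precision because they share the exponent of 100.0.
print(bfp_quantize(x, block_size=4))
```

Because the exponent is shared, the format is cheap to store and compute with, but small values sitting in the same block as a large one are rounded coarsely, which is one source of the small perplexity increase from `BASELINE` to `FALLBACK`.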
### Usage Example
Prerequisites:
- Install dmx-mltools: `pip install dmx-mltools`
- Clone this repo and `cd` into the cloned directory.
```python
>>> import os
>>> from mltools import dmx
>>> from transformers import pipeline
>>> import evaluate
>>> from datasets import load_dataset
>>> # Read the Hugging Face access token from the environment (None if unset)
>>> my_hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
>>> pipe = pipeline(
...     "text-generation",
...     model="d-matrix/distilgpt2",
...     use_auth_token=my_hf_token,
...     trust_remote_code=True,
...     # device_map="auto",  # uncomment to enable pipeline parallelism
... )
>>> # Wrap the model with dmx so its layers can be reconfigured
>>> pipe.model = dmx.Model(
...     pipe.model, monkey_patched=False, hf=True, input_names=["input_ids", "labels"]
... )
>>> # Apply the FALLBACK configuration (FALLBACK.yaml is read from the cloned repo)
>>> pipe.model.transform("FALLBACK.yaml")
>>> # Evaluate perplexity on the Penn Treebank test split
>>> perplexity = evaluate.load("d-matrix/dmx_perplexity", module_type="metric")
>>> input_texts = load_dataset("ptb_text_only", "penn_treebank", split="test")["sentence"]
>>> results = perplexity.compute(model=pipe.model.body, references=input_texts)
>>> print(results)
{'loss': 4.164604187011719, 'perplexity': 64.36720275878906}
```
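The printed pair is consistent with perplexity being the exponential of the mean per-token cross-entropy loss (in nats): exp(4.1646...) is about 64.367. The sketch below shows that relationship. It assumes `d-matrix/dmx_perplexity` follows this standard definition (its exact token weighting is not documented here), and both helper names are hypothetical.

```python
import math


def perplexity_from_loss(mean_nll):
    """Perplexity = exp(mean per-token negative log-likelihood, in nats)."""
    return math.exp(mean_nll)


def perplexity_from_token_nlls(token_nlls):
    """Average the per-token NLLs first, then exponentiate."""
    return perplexity_from_loss(sum(token_nlls) / len(token_nlls))


# Loss value printed by the FALLBACK evaluation on Penn Treebank
print(perplexity_from_loss(4.164604187011719))
```

Averaging before exponentiating matters: exponentiating each token's loss and then averaging would give a different (larger) number.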