"""Stream documents from the Pile's zstd-compressed JSON-lines shards over HTTP."""

import codecs
import json
import random
import time
from typing import Iterator, Literal

import requests
import zstandard as zstd
from torch.utils.data import IterableDataset, get_worker_info

Subset = Literal["train", "val", "test"]

_PILE_ROOT = "https://the-eye.eu/public/AI/pile"

# Shard URLs per subset; the training split is 30 shards named 00..29.
URLs = {
    "val": [f"{_PILE_ROOT}/val.jsonl.zst"],
    "test": [f"{_PILE_ROOT}/test.jsonl.zst"],
    "train": [f"{_PILE_ROOT}/train/{i:02d}.jsonl.zst" for i in range(30)],
}


def _read_line_from_stream(reader, initial_line="", buffer_size=4096):
    """Read from *reader* until a newline is buffered, then split there.

    Args:
        reader: binary file-like object exposing ``read(size)``.
        initial_line: already-decoded text left over from a previous call.
        buffer_size: number of bytes requested per ``read`` call.

    Returns:
        A two-element list ``[line, rest]``: the text before the first newline
        and everything after it (which may itself hold further complete lines).

    Raises:
        StopIteration: if the stream ends before a newline is seen; any partial
            text read during this call is discarded.

    NOTE: each chunk is decoded independently, so a multi-byte UTF-8 character
    straddling a chunk boundary raises ``UnicodeDecodeError``. ``_line_streamer``
    does not use this helper for that reason; it is kept for backward
    compatibility only.
    """
    line = initial_line
    # Consume an already-buffered line before reading more; otherwise lines
    # accumulate in ``rest`` and are lost once the stream ends.
    while "\n" not in line:
        chunk = reader.read(buffer_size)
        if not chunk:
            raise StopIteration
        line += chunk.decode("utf-8")
    return line.split("\n", 1)


def _line_streamer(reader, buffer_size=4096) -> Iterator[str]:
    """Yield the newline-separated lines of UTF-8 text read from *reader*.

    Lines are yielded without their terminating newline; an empty line between
    two consecutive newlines is yielded as ``""``. A final line that is not
    terminated by a newline is still yielded.

    Args:
        reader: binary file-like object exposing ``read(size)``; an empty
            ``bytes`` result is treated as end of stream.
        buffer_size: number of bytes requested per ``read`` call.

    Raises:
        UnicodeDecodeError: if the data is not valid UTF-8 (including a
            multi-byte sequence truncated at the end of the stream).
    """
    # The incremental decoder carries partial multi-byte sequences across
    # chunk boundaries instead of failing on them.
    decoder = codecs.getincrementaldecoder("utf-8")()
    parts = []  # pieces of the current line, which has no newline yet
    while True:
        chunk = reader.read(buffer_size)
        at_eof = not chunk
        text = decoder.decode(chunk, final=at_eof)
        *complete, tail = text.split("\n")
        if complete:
            # The first complete piece terminates the buffered partial line.
            parts.append(complete[0])
            yield "".join(parts)
            yield from complete[1:]
            parts = []
        parts.append(tail)
        if at_eof:
            break
    last = "".join(parts)
    if last:
        yield last


class ThePile(IterableDataset):
    """Infinite iterable dataset over one subset of the Pile.

    Iterating streams the subset's shards (in shuffled order) over HTTP,
    decompresses them on the fly with zstd and yields one ``json.loads``
    result per non-blank line. After the last shard it reshuffles and starts
    over, so iteration never terminates on its own.

    NOTE(review): shards are not partitioned between DataLoader workers. Every
    worker streams every shard, only in a different order (the shuffle is
    seeded with the worker id), so with ``num_workers > 1`` the same documents
    are produced by several workers; for the single-shard subsets ("val",
    "test") all workers yield the identical sequence.
    """

    TEXT_BUFFER_SIZE = 4096
    # Timeout (seconds) passed to ``requests.get`` so a stalled connection or
    # read raises instead of hanging forever.
    REQUEST_TIMEOUT = 60

    def __init__(self, subset: Subset):
        self.subset = subset

    def __iter__(self):
        urls = URLs[self.subset].copy()
        worker = get_worker_info()
        # Seed per worker so workers visit shards in different orders; outside
        # a worker process the seed is None, i.e. seeded from the OS.
        rnd = random.Random(worker.id if worker is not None else None)
        while True:
            # One RNG for all passes gives a fresh shard order on every pass.
            rnd.shuffle(urls)
            for url in urls:
                yield from self._iter_shard(url)

    def _iter_shard(self, url):
        """Yield the parsed JSON value of each non-blank line of one shard."""
        # ``with`` closes the HTTP response even if the consumer stops early.
        with requests.get(url, stream=True, timeout=self.REQUEST_TIMEOUT) as response:
            # Fail loudly on HTTP errors instead of feeding an error page to zstd.
            response.raise_for_status()
            with zstd.ZstdDecompressor().stream_reader(response.raw) as reader:
                for line in _line_streamer(reader, self.TEXT_BUFFER_SIZE):
                    # Skip blank lines, which json.loads would reject.
                    if line.strip():
                        yield json.loads(line)


if __name__ == "__main__":
    from tqdm import tqdm

    dataset = ThePile("train")
    for data in tqdm(dataset, smoothing=0.01):
        pass

    # Average: ~2000 samples/sec/worker