tiny_clip / src /download.py
sachin's picture
Working download script for COCO dataset
5e1c8df
raw
history blame
1.54 kB
import pathlib
import uuid
from functools import partial
from io import BytesIO
from typing import Any

import datasets
import requests
from loguru import logger
from PIL import Image
from tqdm.auto import tqdm

from src import config
def _save_resized_image(
    example: dict[str, Any],
    size: tuple[int, int],
    path: pathlib.Path,
    timeout: float = 30.0,
) -> None:
    """Download ``example["url"]``, resize it to ``size`` and save it under ``path``.

    The file name is the last ``/``-separated segment of the URL. Images that
    already exist on disk are skipped, so re-running resumes an interrupted
    download.

    Args:
        example: A dataset row; only its ``"url"`` key is read.
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``; the
            image is stretched to exactly this size (aspect ratio not kept).
        path: Directory the resized image is written into.
        timeout: Seconds to wait on the HTTP server before giving up.

    Download, decode and save failures are logged as warnings and the image is
    skipped, so one dead URL does not abort the whole parallel ``map``.
    """
    image_url = example["url"]
    image_path = path / image_url.rsplit("/", 1)[-1]
    if image_path.exists():
        return

    try:
        # Without a timeout a stalled server would block this worker forever.
        response = requests.get(image_url, timeout=timeout)
        # Fail on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as image:
            image_resized = image.resize(size)
    except (requests.RequestException, OSError) as err:
        logger.warning(f"Failed to fetch {image_url}: {err}")
        return

    # Write to a uniquely named temporary file, then atomically rename it into
    # place. Workers run in parallel and rows may share a URL (hence a target
    # file). Writing the final path directly could interleave two writers, or
    # leave a truncated file that the exists() check above would then skip on
    # every later run. The suffix is kept so PIL infers the same output format.
    tmp_path = path / f".tmp-{uuid.uuid4().hex}{image_path.suffix}"
    try:
        image_resized.save(tmp_path)
        tmp_path.replace(image_path)
    except (OSError, ValueError) as err:
        logger.warning(f"Failed to save {image_path}: {err}")
    finally:
        # No-op after a successful rename; removes partial output otherwise.
        tmp_path.unlink(missing_ok=True)
def _get_images(
    dataset: datasets.Dataset,
    path: pathlib.Path,
    size: tuple[int, int] = (256, 256),
    num_proc: int = 128,
) -> None:
    """Download and resize every image referenced by the dataset's ``"url"`` column.

    Each row is handed to ``_save_resized_image`` through ``Dataset.map``,
    which is used here only for its parallel iteration; the rows themselves
    are not modified by this function.

    Args:
        dataset: Dataset whose rows carry a ``"url"`` key.
        path: Directory the resized images are written into.
        size: Target ``(width, height)`` for every saved image.
        num_proc: Number of worker processes used by ``Dataset.map``.
    """
    save_resized_image = partial(_save_resized_image, path=path, size=size)
    dataset.map(save_resized_image, num_proc=num_proc, desc="Downloading images")
def _check_corrupt_images(image_file: pathlib.Path) -> bool:
    """Log an error if ``image_file`` cannot be read as a valid image.

    ``Image.verify`` only checks file structure without decoding pixel data,
    so e.g. a JPEG truncated mid-download passes it. The file is therefore
    reopened (Pillow requires this after ``verify``) and fully decoded with
    ``load``, which raises on truncated data unless
    ``ImageFile.LOAD_TRUNCATED_IMAGES`` has been enabled.

    Args:
        image_file: Path of the image to check.

    Returns:
        ``True`` if the image passed both checks, ``False`` if it was logged
        as corrupt.
    """
    try:
        with Image.open(image_file) as img:
            img.verify()
        with Image.open(image_file) as img:
            img.load()
    except (OSError, SyntaxError):
        # IOError is an alias of OSError; SyntaxError is kept from the
        # original handler since some Pillow format parsers raise it.
        logger.error(f"Corrupt image: {image_file}")
        return False
    return True
if __name__ == "__main__":
    # Load the training split of the dataset named by the trainer config.
    trainer_config = config.TrainerConfig()
    train_split = datasets.load_dataset(
        trainer_config._data_config.dataset,
        split="train",
    )

    # Create the destination directory before workers start writing into it.
    image_dir = config.IMAGE_DOWNLOAD_PATH
    image_dir.mkdir(parents=True, exist_ok=True)
    _get_images(train_split, image_dir)  # type: ignore

    # Second pass over everything on disk: report unreadable files.
    for image_file in tqdm(image_dir.iterdir()):
        _check_corrupt_images(image_file)