sachin committed
Commit 5e1c8df
Parent(s): 35352c6

Working download script for COCO dataset

Files changed (2)
  1. src/config.py +11 -5
  2. src/download.py +54 -0
src/config.py CHANGED
@@ -1,10 +1,16 @@
+import pathlib
+
 import pydantic
 
+MAX_DOWNLOAD_TIME = 0.2
+
+IMAGE_DOWNLOAD_PATH = pathlib.Path("/tmp/images")
+
 
 class DataConfig(pydantic.BaseModel):
     buffer_size: int = 1000
-    data_len: int = 100000
-    train_len: int = 90000
+    data_len: int = 100
+    train_len: int = 90
     small_dataset: str = "laion/220k-gpt4vision-captions-from-livis"
     large_dataset: str = "laion/laion400m"
     dataset: str = small_dataset
@@ -16,7 +22,7 @@ class ModelConfig(pydantic.BaseModel):
     projection_layers: int = 3
     embed_dim: int = 256
     transformer_embed_dim: int = 768
-    max_len: int = 77  # maximum length of text in CLIP
+    max_len: int = 128  # 77
     cls_type: bool = True
     freeze_vision_base: bool = False
     freeze_text_base: bool = False
@@ -36,5 +42,5 @@ class TrainerConfig(pydantic.BaseModel):
 
     run_openai_clip: bool = False
 
-    model_config: ModelConfig = ModelConfig()
-    data_config: DataConfig = DataConfig()
+    _model_config: ModelConfig = ModelConfig()
+    _data_config: DataConfig = DataConfig()
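
A note on the last hunk: in pydantic v2, model_config is a reserved class attribute (it holds the model's ConfigDict), so a field named model_config collides with it, and leading-underscore attributes become private attributes that are set on instances but excluded from validation and serialization. That is presumably why this commit renames both fields. A minimal sketch of the behaviour, assuming pydantic v2 (not part of the commit):

import pydantic


class ModelConfig(pydantic.BaseModel):
    embed_dim: int = 256


class TrainerConfig(pydantic.BaseModel):
    run_openai_clip: bool = False
    # Leading underscore makes this a private attribute, not a model field.
    _model_config: ModelConfig = ModelConfig()


trainer = TrainerConfig()
print(trainer._model_config.embed_dim)  # 256 -- still readable on the instance
print(trainer.model_dump())             # {'run_openai_clip': False} -- private attrs excluded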
src/download.py ADDED
@@ -0,0 +1,54 @@
+from io import BytesIO
+import pathlib
+from functools import partial
+from typing import Any
+
+import datasets
+from PIL import Image
+from loguru import logger
+import requests
+from tqdm.auto import tqdm
+
+from src import config
+
+
+def _save_resized_image(example: dict[str, Any], size: tuple[int, int], path: pathlib.Path):
+    # Download the image
+    image_url = example["url"]
+    image_path = path / image_url.rsplit("/", 1)[-1]
+    if image_path.exists():
+        return
+
+    response = requests.get(image_url)
+    image = Image.open(BytesIO(response.content))
+    # Resize the image
+    image_resized = image.resize(size)
+    image_resized.save(image_path)
+
+
+def _get_images(dataset: datasets.Dataset, path: pathlib.Path):
+    save_resized_image = partial(_save_resized_image, path=path, size=(256, 256))
+    dataset.map(save_resized_image, num_proc=128)
+
+
+def _check_corrupt_images(image_file: pathlib.Path):
+    try:
+        with Image.open(image_file) as img:
+            img.verify()  # Verify the integrity of the image
+    except (IOError, SyntaxError) as e:
+        logger.error(f"Corrupt image: {image_file}")
+
+
+if __name__ == "__main__":
+    hyper_parameters = config.TrainerConfig()
+
+    dataset = datasets.load_dataset(
+        hyper_parameters._data_config.dataset,
+        split="train",
+    )
+
+    config.IMAGE_DOWNLOAD_PATH.mkdir(parents=True, exist_ok=True)
+    _get_images(dataset, config.IMAGE_DOWNLOAD_PATH)  # type: ignore
+
+    for image in tqdm(config.IMAGE_DOWNLOAD_PATH.iterdir()):
+        _check_corrupt_images(image)
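
Two things worth flagging in the new script: MAX_DOWNLOAD_TIME is added to config.py but never referenced in download.py, and requests.get is called without a timeout, so a single stalled host can hang a dataset.map worker indefinitely. A hedged sketch of one plausible wiring, using the new constant as a per-request timeout and logging failures instead of crashing the worker (the _save_resized_image_safe name is hypothetical, not from the commit):

import pathlib
from io import BytesIO
from typing import Any

import requests
from PIL import Image
from loguru import logger

from src import config  # assumes the same repo layout as the commit


def _save_resized_image_safe(example: dict[str, Any], size: tuple[int, int], path: pathlib.Path):
    image_url = example["url"]
    image_path = path / image_url.rsplit("/", 1)[-1]
    if image_path.exists():
        return
    try:
        # 0.2 s (config.MAX_DOWNLOAD_TIME) is an aggressive timeout; slow hosts get skipped.
        response = requests.get(image_url, timeout=config.MAX_DOWNLOAD_TIME)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        image.resize(size).save(image_path)
    except (requests.RequestException, OSError) as exc:
        logger.warning(f"Skipping {image_url}: {exc}")

Since download.py imports from src, the script is presumably meant to be run as a module from the repository root, e.g. python -m src.download, so that "from src import config" resolves.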