File size: 633 Bytes
3b13f40
 
24d96ab
 
3b13f40
24d96ab
3b13f40
 
24d96ab
 
 
 
 
 
 
 
 
 
3b13f40
24d96ab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import timm
from timm import data
import torch.nn as nn
from torchvision import transforms

from src.config import TinyCLIPVisionConfig


def get_vision_base(
    config: TinyCLIPVisionConfig,
) -> tuple[nn.Module, int]:
    """Create the pretrained timm image backbone named in the config.

    The model is created with ``num_classes=0``, which tells timm to drop the
    classification head, and with ``pretrained=True`` so weights are loaded.

    Args:
        config: Vision settings; only ``config.vision_model`` (the timm model
            name) is read.

    Returns:
        A ``(backbone, feature_dim)`` pair, where ``feature_dim`` is taken from
        the backbone's ``num_features`` attribute.
    """
    backbone = timm.create_model(
        config.vision_model,
        pretrained=True,
        num_classes=0,
    )
    return backbone, backbone.num_features


def get_vision_transform(
    config: TinyCLIPVisionConfig,
    model: "nn.Module | None" = None,
) -> transforms.Compose:
    """Build the eval-time preprocessing pipeline matching the vision backbone.

    ``timm.data.resolve_data_config`` reads input size, interpolation,
    mean/std and crop fraction from a *model instance* (its
    ``pretrained_cfg`` / ``default_cfg`` attribute). Passing the model *name*
    string instead silently yields timm's generic defaults (224px, ImageNet
    mean/std). Those defaults can mismatch the actual backbone, so the config
    is always resolved from a model object here.

    Args:
        config: Vision settings; ``config.vision_model`` names the timm model.
        model: Optional, already-constructed backbone (e.g. the one returned
            by ``get_vision_base``). Its data config is used directly. When
            omitted, an instance of ``config.vision_model`` is created with
            ``pretrained=False`` solely to read its data config.

    Returns:
        The transform produced by ``timm.data.create_transform`` for the
        resolved config.
    """
    if model is None:
        # Weights are irrelevant for reading the data config, so skip the
        # download; timm attaches the model's pretrained config at
        # construction time regardless of the ``pretrained`` flag.
        model = timm.create_model(config.vision_model, pretrained=False)
    timm_config = data.resolve_data_config({}, model=model)
    transform = data.transforms_factory.create_transform(**timm_config)
    return transform  # type: ignore