tiny_clip / src /vision_model.py
sachin's picture
Refactoring to distangle modules
24d96ab
raw
history blame contribute delete
633 Bytes
import timm
from timm import data
import torch.nn as nn
from torchvision import transforms
from src.config import TinyCLIPVisionConfig
def get_vision_base(
    config: TinyCLIPVisionConfig,
) -> tuple[nn.Module, int]:
    """Build the timm vision backbone named by ``config.vision_model``.

    The model is created through ``timm.create_model`` with pretrained
    weights requested and ``num_classes=0`` (in timm this is the documented
    way to obtain a model without its classification head).

    Args:
        config: Vision configuration; only ``config.vision_model`` is read.

    Returns:
        A ``(backbone, num_features)`` pair, where ``num_features`` is the
        backbone's own ``num_features`` attribute.
    """
    backbone = timm.create_model(
        config.vision_model,
        pretrained=True,
        num_classes=0,
    )
    return backbone, backbone.num_features
def get_vision_transform(config: TinyCLIPVisionConfig) -> transforms.Compose:
    """Create the image preprocessing transform for the vision model.

    A data configuration is resolved with timm and then expanded as keyword
    arguments into timm's transform factory.

    Args:
        config: Vision configuration; only ``config.vision_model`` is read.

    Returns:
        The transform object produced by timm's ``create_transform``.
    """
    # NOTE(review): `model=` receives the model *name* (a string) here, not a
    # model instance. Presumably timm only reads a pretrained config off a
    # model object, in which case this resolves to timm's generic defaults
    # rather than model-specific ones -- TODO confirm against the installed
    # timm version before changing, since it affects preprocessing.
    resolved_cfg = data.resolve_data_config({}, model=config.vision_model)
    build_transform = data.transforms_factory.create_transform
    return build_transform(**resolved_cfg)  # type: ignore