import timm
from timm import data
import torch.nn as nn
from torchvision import transforms

from src.config import TinyCLIPVisionConfig


def get_vision_base(
    config: TinyCLIPVisionConfig,
) -> tuple[nn.Module, int]:
    """Create the pretrained timm vision backbone named by the config.

    The model is created with ``num_classes=0`` and ``pretrained=True``.

    Args:
        config: Vision configuration; ``config.vision_model`` is the timm
            model name passed to ``timm.create_model``.

    Returns:
        A ``(model, num_features)`` pair, where ``num_features`` is read from
        the created model's ``num_features`` attribute.
    """
    base = timm.create_model(config.vision_model, num_classes=0, pretrained=True)
    num_features = base.num_features
    return base, num_features


def get_vision_transform(config: TinyCLIPVisionConfig) -> transforms.Compose:
    """Build the image preprocessing transform for the configured backbone.

    Args:
        config: Vision configuration; ``config.vision_model`` is the timm
            model name whose data config is resolved.

    Returns:
        The transform produced by timm's ``create_transform`` for the
        resolved data config.
    """
    # resolve_data_config looks up the preprocessing settings on a model
    # *instance* (its pretrained/default cfg attribute). Handing it the bare
    # name string gives it nothing to read, so it would fall back to generic
    # defaults instead of this backbone's own settings. Build a weight-less
    # instance purely to read that config; pretrained=False avoids loading
    # the checkpoint a second time.
    reference_model = timm.create_model(
        config.vision_model, num_classes=0, pretrained=False
    )
    timm_config = data.resolve_data_config({}, model=reference_model)
    transform = data.transforms_factory.create_transform(**timm_config)
    return transform  # type: ignore