# SRPose / eval.py -- uploaded by FrickYinn ("Upload 53 files", commit e170a8e).
# Evaluation entry point: python eval.py <config.yaml> <ckpt_path>
import argparse
from torch.utils.data import DataLoader
import lightning as L
from datasets import dataset_dict
from model import PL_RelPose, keypoint_dict
from configs.default import get_cfg_defaults
def main(args):
    """Evaluate a PL_RelPose checkpoint on the test split of the configured dataset.

    Args:
        args: Parsed CLI namespace. ``args.config`` is a YAML file merged on
            top of ``get_cfg_defaults()``; ``args.ckpt_path`` is the checkpoint
            handed to ``PL_RelPose.load_from_checkpoint``.
    """
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config)

    # Read every setting before the dataset builder is given the config.
    task_name = cfg.DATASET.TASK
    source_name = cfg.DATASET.DATA_SOURCE
    loader_kwargs = dict(
        batch_size=cfg.TRAINER.BATCH_SIZE,
        num_workers=cfg.TRAINER.NUM_WORKERS,
        pin_memory=cfg.TRAINER.PIN_MEMORY,
    )
    keypoint_budget = cfg.MODEL.TEST_NUM_KEYPOINTS

    # dataset_dict is indexed by task, then by data source; the selected
    # builder is called with the split name and the merged config.
    build_split = dataset_dict[task_name][source_name]
    test_loader = DataLoader(build_split('test', cfg), **loader_kwargs)

    model = PL_RelPose.load_from_checkpoint(args.ckpt_path)

    # Attach a freshly built keypoint extractor, chosen by the 'features'
    # hyperparameter stored with the model, limited to the test-time
    # keypoint count, with detection_threshold=0.0, and switched to eval mode.
    extractor_cls = keypoint_dict[model.hparams['features']]
    model.extractor = extractor_cls(
        max_num_keypoints=keypoint_budget,
        detection_threshold=0.0,
    ).eval()

    # Run on device index 0 only (presumably the first GPU -- depends on the
    # accelerator Lightning selects).
    L.Trainer(devices=[0]).test(model, dataloaders=test_loader)
def get_parser():
    """Build the command-line parser for checkpoint evaluation.

    Returns:
        argparse.ArgumentParser with two required positional arguments:
        ``config`` (path to a .yaml configuration file) and ``ckpt_path``
        (path to the checkpoint to evaluate).
    """
    parser = argparse.ArgumentParser(
        description='Evaluate a PL_RelPose checkpoint on the test split '
                    'of the configured dataset.',
    )
    parser.add_argument('config', type=str,
                        help='path to the .yaml configuration file')
    # Previously undocumented: tell users what the second positional expects.
    parser.add_argument('ckpt_path', type=str,
                        help='path to the model checkpoint to evaluate')
    return parser
if __name__ == "__main__":
    # Script entry: parse the CLI arguments and run the evaluation.
    main(get_parser().parse_args())