File size: 1,427 Bytes
e170a8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import argparse
from torch.utils.data import DataLoader
import lightning as L

from datasets import dataset_dict
from model import PL_RelPose, keypoint_dict
from configs.default import get_cfg_defaults


def main(args):
    """Evaluate a checkpointed ``PL_RelPose`` model on the test split.

    Steps, in order:
      1. Load the default config and merge the YAML file at ``args.config``.
      2. Look up the dataset builder in ``dataset_dict`` by the configured
         task and data source, build the ``'test'`` split and wrap it in a
         ``DataLoader``.
      3. Restore the model from ``args.ckpt_path`` and replace its
         ``extractor`` with a fresh one configured for test time.
      4. Run ``Trainer.test`` on the loader.

    Args:
        args: Parsed command-line namespace; must expose ``config`` (path to
            the YAML config) and ``ckpt_path`` (path to the checkpoint).
    """
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config)

    # Pull every setting needed up front so a missing config key fails
    # before any dataset or checkpoint is touched.
    data_cfg = cfg.DATASET
    task_name, source_name = data_cfg.TASK, data_cfg.DATA_SOURCE

    trainer_cfg = cfg.TRAINER
    loader_options = {
        'batch_size': trainer_cfg.BATCH_SIZE,
        'num_workers': trainer_cfg.NUM_WORKERS,
        'pin_memory': trainer_cfg.PIN_MEMORY,
    }

    max_keypoints = cfg.MODEL.TEST_NUM_KEYPOINTS

    make_dataset = dataset_dict[task_name][source_name]
    test_loader = DataLoader(make_dataset('test', cfg), **loader_options)

    model = PL_RelPose.load_from_checkpoint(args.ckpt_path)

    # The extractor class is chosen by the 'features' hyperparameter stored
    # with the checkpoint; it is rebuilt here with the test-time keypoint
    # budget and a detection threshold of 0.0, then switched to eval mode.
    extractor_cls = keypoint_dict[model.hparams['features']]
    model.extractor = extractor_cls(
        max_num_keypoints=max_keypoints,
        detection_threshold=0.0,
    ).eval()

    # devices=[0] is passed through to the Trainer unchanged.
    runner = L.Trainer(devices=[0])
    runner.test(model, dataloaders=test_loader)


def get_parser():
    """Create the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: A parser with two required positional
        string arguments, ``config`` and ``ckpt_path``, in that order.
    """
    cli = argparse.ArgumentParser()

    # (name, add_argument keyword options), registered in positional order.
    positionals = (
        ('config', {'type': str, 'help': '.yaml configure file path'}),
        ('ckpt_path', {'type': str}),
    )
    for arg_name, arg_options in positionals:
        cli.add_argument(arg_name, **arg_options)

    return cli


# Script entry point: parse the command line and run the evaluation.
if __name__ == "__main__":
    main(get_parser().parse_args())