File size: 4,121 Bytes
49235ad
f90d455
 
49235ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f90d455
49235ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f90d455
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from pathlib import Path

import hydra
import torch
from omegaconf import DictConfig
from slider import Beatmap

from osudiffusion import DiT_models
from osuT5.inference import Preprocessor, Pipeline, Postprocessor, DiffisionPipeline
from osuT5.tokenizer import Tokenizer
from osuT5.utils import get_model


def get_args_from_beatmap(args: DictConfig):
    """Populate inference args in place from an existing beatmap file.

    Reads the beatmap referenced by ``args.beatmap_path`` and copies its audio
    path, output directory, timing info, and metadata into ``args``. Fields
    whose configured value is ``-1`` (``beatmap_id``, ``diffusion.style_id``,
    ``difficulty``) are treated as "unset" and filled from the beatmap.

    No-op when ``args.beatmap_path`` is ``None`` or empty.

    Raises:
        FileNotFoundError: if ``args.beatmap_path`` does not point to a file.
    """
    if args.beatmap_path is None or args.beatmap_path == "":
        return

    beatmap_path = Path(args.beatmap_path)

    if not beatmap_path.is_file():
        raise FileNotFoundError(f"Beatmap file {beatmap_path} not found.")

    beatmap = Beatmap.from_path(beatmap_path)
    # Cast paths to str: OmegaConf containers only accept primitive values,
    # so assigning a pathlib.Path directly raises UnsupportedValueType.
    args.audio_path = str(beatmap_path.parent / beatmap.audio_filename)
    args.output_path = str(beatmap_path.parent)
    args.bpm = beatmap.bpm_max()
    # Earliest timing point, converted from seconds to milliseconds.
    args.offset = min(tp.offset.total_seconds() * 1000 for tp in beatmap.timing_points)
    args.slider_multiplier = beatmap.slider_multiplier
    args.title = beatmap.title
    args.artist = beatmap.artist
    args.beatmap_id = beatmap.beatmap_id if args.beatmap_id == -1 else args.beatmap_id
    args.diffusion.style_id = beatmap.beatmap_id if args.diffusion.style_id == -1 else args.diffusion.style_id
    args.difficulty = float(beatmap.stars()) if args.difficulty == -1 else args.difficulty


def find_model(ckpt_path, args: DictConfig, device):
    """Load a DiT diffusion model checkpoint and prepare it for inference.

    Args:
        ckpt_path: Path to the DiT checkpoint file.
        args: Config; reads ``args.diffusion.model`` (architecture key into
            ``DiT_models``) and ``args.diffusion.num_classes``.
        device: Device the model is moved to.

    Returns:
        The model with the checkpoint weights loaded, in eval mode.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    # Raise instead of assert: asserts are stripped under `python -O`,
    # which would turn a missing checkpoint into an obscure torch.load error.
    if not Path(ckpt_path).exists():
        raise FileNotFoundError(f"Could not find DiT checkpoint at {ckpt_path}")
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints (weights_only=True would be safer if the format allows it).
    checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    if "ema" in checkpoint:  # supports checkpoints from train.py
        checkpoint = checkpoint["ema"]

    model = DiT_models[args.diffusion.model](
        num_classes=args.diffusion.num_classes,
        # presumably 16 scalar beatmap features + 128-dim embedding — TODO confirm
        context_size=19 - 3 + 128,
    ).to(device)
    model.load_state_dict(checkpoint)
    model.eval()  # important! disables dropout for deterministic inference
    return model


@hydra.main(config_path="configs", config_name="inference", version_base="1.1")
def main(args: DictConfig):
    """Run osuT5 inference: generate beatmap events, then (optionally) positions.

    Stage 1: a seq2seq model turns segmented audio into event sequences —
    one map, or a full spread of difficulties when ``args.full_set`` is set.
    Stage 2 (``args.generate_positions``): a DiT diffusion model assigns
    hit-object positions to each generated event sequence. The results are
    written out by the postprocessor.
    """
    get_args_from_beatmap(args)

    torch.set_grad_enabled(False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the event-generation model and its tokenizer state.
    ckpt_path = Path(args.model_path)
    model_state = torch.load(ckpt_path / "pytorch_model.bin", map_location=device)
    tokenizer_state = torch.load(ckpt_path / "custom_checkpoint_0.pkl")

    tokenizer = Tokenizer()
    tokenizer.load_state_dict(tokenizer_state)

    model = get_model(args, tokenizer)
    model.load_state_dict(model_state)
    model.eval()
    model.to(device)

    # Audio -> model-ready segment sequences.
    preprocessor = Preprocessor(args)
    audio = preprocessor.load(args.audio_path)
    sequences = preprocessor.segment(audio)
    # presumably `audio` is a 16 kHz sample array — TODO confirm sample rate
    args.total_duration_ms = len(audio) / 16000 * 1000

    generated_maps = []
    generated_positions = []

    if args.full_set:
        # Evenly space difficulties across [3, 7]. Guard the single-map case:
        # the original interpolation divides by (set_difficulties - 1) and
        # would raise ZeroDivisionError when set_difficulties == 1.
        count = args.set_difficulties
        if count > 1:
            diffs = [3 + i * (7 - 3) / (count - 1) for i in range(count)]
        else:
            diffs = [3.0]
        print(diffs)
        for diff in diffs:
            print(f"Generating difficulty {diff}")
            args.difficulty = diff
            pipeline = Pipeline(args, tokenizer)
            generated_maps.append(pipeline.generate(model, sequences))
    else:
        pipeline = Pipeline(args, tokenizer)
        generated_maps.append(pipeline.generate(model, sequences))

    if args.generate_positions:
        # Stage 2: diffusion model places hit-object positions, with an
        # optional refinement model when a refine checkpoint is configured.
        model = find_model(args.diff_ckpt, args, device)
        refine_model = find_model(args.diff_refine_ckpt, args, device) if len(args.diff_refine_ckpt) > 0 else None
        diffusion_pipeline = DiffisionPipeline(args.diffusion)
        for events in generated_maps:
            generated_positions.append(diffusion_pipeline.generate(model, events, refine_model))
    else:
        generated_positions = generated_maps

    postprocessor = Postprocessor(args)
    postprocessor.generate(generated_positions)


if __name__ == "__main__":
    # Hydra entry point: CLI overrides are parsed into the inference config.
    main()