import os
import sys

import librosa
import numpy as np
import torch

import audio_to_text.captioning.models
import audio_to_text.captioning.models.encoder
import audio_to_text.captioning.models.decoder
import audio_to_text.captioning.utils.train_util as train_util


def load_model(config, checkpoint):
    """Build the captioning model described in `config` and load weights from `checkpoint`."""
    ckpt = torch.load(checkpoint, map_location="cpu")

    # Build the audio encoder, optionally initialising it from a pretrained checkpoint.
    encoder_cfg = config["model"]["encoder"]
    encoder = train_util.init_obj(
        audio_to_text.captioning.models.encoder,
        encoder_cfg
    )
    if "pretrained" in encoder_cfg:
        pretrained = encoder_cfg["pretrained"]
        train_util.load_pretrained_model(encoder,
                                         pretrained,
                                         sys.stdout.write)

    # Build the text decoder; the vocabulary size falls back to the checkpoint vocabulary.
    decoder_cfg = config["model"]["decoder"]
    if "vocab_size" not in decoder_cfg["args"]:
        decoder_cfg["args"]["vocab_size"] = len(ckpt["vocabulary"])
    decoder = train_util.init_obj(
        audio_to_text.captioning.models.decoder,
        decoder_cfg
    )
    if "word_embedding" in decoder_cfg:
        decoder.load_word_embedding(**decoder_cfg["word_embedding"])
    if "pretrained" in decoder_cfg:
        pretrained = decoder_cfg["pretrained"]
        train_util.load_pretrained_model(decoder,
                                         pretrained,
                                         sys.stdout.write)

    # Assemble the full encoder-decoder model, load the trained weights and switch to eval mode.
    model = train_util.init_obj(audio_to_text.captioning.models, config["model"],
                                encoder=encoder, decoder=decoder)
    train_util.load_pretrained_model(model, ckpt)
    model.eval()
    return {
        "model": model,
        "vocabulary": ckpt["vocabulary"]
    }


def decode_caption(word_ids, vocabulary):
    """Convert a sequence of word ids into a caption string, stopping at the <end> token."""
    candidate = []
    for word_id in word_ids:
        word = vocabulary[word_id]
        if word == "<end>":
            break
        elif word == "<start>":
            continue
        candidate.append(word)
    return " ".join(candidate)


class AudioCapModel(object):
    def __init__(self, weight_dir, device="cpu"):
        # The weight directory is expected to contain the training config
        # ("config.yaml") and the SWA-averaged checkpoint ("swa.pth").
        config = os.path.join(weight_dir, "config.yaml")
        self.config = train_util.parse_config_or_kwargs(config)
        checkpoint = os.path.join(weight_dir, "swa.pth")
        resumed = load_model(self.config, checkpoint)
        self.vocabulary = resumed["vocabulary"]
        self.model = resumed["model"].to(device)
        self.device = device

    def caption(self, audio_list):
        # Accept a single waveform, a single audio file path, or a list of waveforms.
        if isinstance(audio_list, np.ndarray):
            audio_list = [audio_list]
        elif isinstance(audio_list, str):
            audio_list = [librosa.load(audio_list, sr=32000)[0]]

        captions = []
        for wav in audio_list:
            inputwav = torch.as_tensor(wav).float().unsqueeze(0).to(self.device)
            wav_len = torch.LongTensor([len(wav)])
            input_dict = {
                "mode": "inference",
                "wav": inputwav,
                "wav_len": wav_len,
                "specaug": False,
                "sample_method": "beam",
            }
            # Run beam-search inference and decode each sequence of word ids.
            out_dict = self.model(input_dict)
            caption_batch = [decode_caption(seq, self.vocabulary)
                             for seq in out_dict["seq"].cpu().numpy()]
            captions.extend(caption_batch)
        return captions

    def __call__(self, audio_list):
        return self.caption(audio_list)
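

# Example usage (a minimal sketch, not part of the original module): the weight
# directory and audio file below are placeholder paths, assumed to contain the
# "config.yaml" / "swa.pth" files that AudioCapModel expects.
if __name__ == "__main__":
    cap_model = AudioCapModel("path/to/weight_dir")  # hypothetical checkpoint directory
    wav, _ = librosa.load("path/to/audio.wav", sr=32000)  # hypothetical audio clip
    print(cap_model(wav))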