File size: 4,761 Bytes
6faeba1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import numpy as np
import pandas as pd
import json
import argparse
from tqdm import tqdm
import os
from Utility.storage_config import MODELS_DIR

def _infer_distance_layout(columns):
    """Infer how many closest languages a distance dataset holds and which distance it uses.

    Each closest language ``i`` contributes a ``closest_lang_<i>`` column plus one
    ``<type>_dist_<i>`` column. Datasets with ``combined`` distances may additionally keep
    the individual ``map``/``asp``/``tree`` distances per closest language. The single
    ``target_lang`` column is discarded by the floor division.

    Args:
        columns: Iterable of column names (e.g. ``df.columns``).

    Returns:
        Tuple ``(n_closest, distance_type)``. ``distance_type`` falls back to ``"random"``
        when no known ``<type>_dist_0`` column is present.
    """
    columns = list(columns)
    present = set(columns)
    features_per_closest_lang = 2
    if "combined_dist_0" in present:
        distance_type = "combined"
        # each individual distance kept alongside the combined one is one more column per closest lang
        features_per_closest_lang += sum(f"{t}_dist_0" in present for t in ("map", "asp", "tree"))
    else:
        distance_type = next(
            (t for t in ("map", "tree", "asp", "learned") if f"{t}_dist_0" in present),
            "random",
        )
    n_closest = len(columns) // features_per_closest_lang
    return n_closest, distance_type


def _select_closest_langs(row, lang_columns, dist_columns, min_n_langs, threshold):
    """Return the closest languages of ``row`` that take part in the average.

    The first ``min_n_langs`` languages are always kept. Any further language is kept only
    if its distance is strictly below ``threshold``. Languages and distances are filtered
    pairwise, so the result does not rely on the distances being sorted.
    """
    return [
        getattr(row, lang_col)
        for i, (lang_col, dist_col) in enumerate(zip(lang_columns, dist_columns))
        if i < min_n_langs or getattr(row, dist_col) < threshold
    ]


def _zeroshot_model_path(model_path):
    """Return ``model_path`` with its extension replaced by ``_zeroshot_lang_embs.pt``."""
    # splitext only strips the final extension, so dots in directory names are preserved
    return os.path.splitext(model_path)[0] + "_zeroshot_lang_embs.pt"


def approximate_and_inject_language_embeddings(model_path, df, iso_lookup, min_n_langs=5, max_n_langs=25, threshold_percentile=50):
    """Replace each target language's embedding by the mean embedding of its closest languages.

    The checkpoint at ``model_path`` is loaded on CPU. For every row of ``df``, the row of
    ``model["model"]["encoder.language_embedding.weight"]`` belonging to ``row.target_lang``
    is overwritten with the mean of the *pretrained* embeddings of the selected closest
    languages. The modified checkpoint is saved next to the original as
    ``<model_path without extension>_zeroshot_lang_embs.pt``.

    Selection: at most ``max_n_langs`` closest languages are considered. The first
    ``min_n_langs`` are always used. Further ones are used only if their distance is below a
    threshold. That threshold is the ``threshold_percentile``-th percentile of the distance
    of the furthest considered language, taken over all rows.

    Args:
        model_path: Path of the checkpoint to modify.
        df: Distance dataset with ``closest_lang_<i>``, ``<type>_dist_<i>`` and
            ``target_lang`` columns.
        iso_lookup: Loaded ISO lookup; ``iso_lookup[-1]`` maps ISO code strings to
            embedding row indices.
        min_n_langs: Minimum number of closest languages averaged per target.
        max_n_langs: Maximum number of closest languages considered per target.
        threshold_percentile: Percentile (0-100) used to derive the distance cutoff.

    Raises:
        ValueError: If ``df`` has no closest-language columns, or the furthest considered
            distance column contains missing values.
        KeyError: If an ISO code is missing from ``iso_lookup[-1]``.
    """
    # load pretrained language_embeddings
    model = torch.load(model_path, map_location="cpu")
    lang_embs = model["model"]["encoder.language_embedding.weight"]
    # Averages are read from an untouched copy. Otherwise an approximation written for one
    # target would leak into a later row that lists that target among its neighbours,
    # making the result depend on row order.
    pretrained_embs = lang_embs.clone()
    iso2id = iso_lookup[-1]

    n_closest, distance_type = _infer_distance_layout(df.columns)
    closest_lang_columns = [f"closest_lang_{i}" for i in range(n_closest)][:max_n_langs]
    closest_dist_columns = [f"{distance_type}_dist_{i}" for i in range(n_closest)][:max_n_langs]
    if not closest_dist_columns:
        raise ValueError("Distance dataset contains no closest-language columns to average over.")
    furthest_col = closest_dist_columns[-1]
    furthest_dists = df[furthest_col]
    if furthest_dists.isna().any():
        raise ValueError(f"Column '{furthest_col}' contains missing distances.")

    # get threshold based on distance of a certain percentile of the furthest language across all samples
    threshold = np.percentile(furthest_dists, threshold_percentile)
    print(f"threshold: {threshold:.4f}")
    n_skipped = 0
    for row in tqdm(df.itertuples(), total=df.shape[0], desc="Approximating language embeddings"):
        langs = _select_closest_langs(row, closest_lang_columns, closest_dist_columns, min_n_langs, threshold)
        if not langs:
            # only reachable with min_n_langs <= 0; averaging nothing would write NaNs
            n_skipped += 1
            continue
        ids = [iso2id[str(lang)] for lang in langs]
        # average in float32 regardless of the stored dtype; assignment casts back
        avg_emb = pretrained_embs[ids].to(torch.float32).mean(dim=0)
        lang_embs[iso2id[str(row.target_lang)]] = avg_emb
    if n_skipped:
        print(f"Kept original embeddings for {n_skipped} target languages without any close language below the threshold.")

    # inject language embeddings into Toucan model and save
    model["model"]["encoder.language_embedding.weight"] = lang_embs
    modified_model_path = _zeroshot_model_path(model_path)
    torch.save(model, modified_model_path)
    print(f"Replaced unsupervised language embeddings with zero-shot approximations.\nSaved modified model to {modified_model_path}")


if __name__ == "__main__":
    # Command-line entry point: load the ISO lookup and distance dataset, then write a
    # copy of the checkpoint with zero-shot language embeddings.
    default_model_path = os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt")  # MODELS_DIR must be absolute path, the relative path will fail at this location
    default_csv_path = "distance_datasets/dataset_learned_top30.csv"
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default=default_model_path,
                        help="path of the model for which the language embeddings should be modified")
    parser.add_argument("--dataset_path", type=str, default=default_csv_path,
                        help="path to distance dataset CSV")
    parser.add_argument("--iso_lookup_path", type=str, default="iso_lookup.json",
                        help="path to the ISO lookup JSON (its last entry maps ISO codes to embedding ids)")
    parser.add_argument("--min_n_langs", type=int, default=5,
                        help="minimum number of languages used for averaging")
    parser.add_argument("--max_n_langs", type=int, default=25,
                        help="maximum number of languages used for averaging")
    parser.add_argument("--threshold_percentile", type=float, default=50,
                        help="percentile of the distances of the furthest used languages, used as cutoff threshold "
                             "(no langs >= the threshold are used for averaging)")
    args = parser.parse_args()
    with open(args.iso_lookup_path, "r", encoding="utf-8") as f:
        iso_lookup = json.load(f)  # iso_lookup[-1] = iso2id mapping
    # load language distance dataset
    distance_df = pd.read_csv(args.dataset_path, sep="|")
    approximate_and_inject_language_embeddings(model_path=args.model_path,
                                               df=distance_df,
                                               iso_lookup=iso_lookup,
                                               min_n_langs=args.min_n_langs,
                                               max_n_langs=args.max_n_langs,
                                               threshold_percentile=args.threshold_percentile)