Fix train script for NPSC
Browse files
run_speech_recognition_ctc.py
CHANGED
@@ -391,6 +391,23 @@ def main():
|
|
391 |
# Set seed before initializing model.
|
392 |
set_seed(training_args.seed)
|
393 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
394 |
# 1. First, let's load the dataset
|
395 |
raw_datasets = DatasetDict()
|
396 |
|
@@ -401,6 +418,8 @@ def main():
|
|
401 |
split=data_args.train_split_name,
|
402 |
use_auth_token=data_args.use_auth_token,
|
403 |
)
|
|
|
|
|
404 |
|
405 |
if data_args.audio_column_name not in raw_datasets["train"].column_names:
|
406 |
raise ValueError(
|
@@ -426,6 +445,8 @@ def main():
|
|
426 |
split=data_args.eval_split_name,
|
427 |
use_auth_token=data_args.use_auth_token,
|
428 |
)
|
|
|
|
|
429 |
|
430 |
if data_args.max_eval_samples is not None:
|
431 |
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
|
|
|
391 |
# Set seed before initializing model.
|
392 |
set_seed(training_args.seed)
|
393 |
|
394 |
+
# Pre-processing dataset
|
395 |
+
def preprocess_dataset(entry):
    """Filter predicate: keep an entry only if its transcript is usable.

    An entry passes when its text carries no ``<INAUDIBLE>`` marker and its
    ``sentence_language_code`` is Norwegian Bokmål (``nb-no``, compared
    case-insensitively). Intended for ``Dataset.filter``.
    """
    text_is_clean = "<INAUDIBLE>" not in entry["text"]
    is_bokmal = entry["sentence_language_code"].lower() == "nb-no"
    return text_is_clean and is_bokmal
|
400 |
+
|
401 |
+
def map_dataset(entry):
    """Normalize a transcript for CTC training (for ``Dataset.map``).

    Lowercases the text, then expands hesitation tags (``<ee>`` → ``eee``,
    ``<mm>`` → ``mmm``, ``<qq>`` → ``qqq``) and folds the accented vowels
    ``ó``/``é`` to their plain forms, in that order. Returns a dict with the
    rewritten ``text`` column.
    """
    substitutions = (
        ("<ee>", "eee"),
        ("<mm>", "mmm"),
        ("<qq>", "qqq"),
        ("ó", "o"),
        ("é", "e"),
    )
    normalized = entry["text"].lower()
    for pattern, replacement in substitutions:
        normalized = normalized.replace(pattern, replacement)
    return {"text": normalized}
|
410 |
+
|
411 |
# 1. First, let's load the dataset
|
412 |
raw_datasets = DatasetDict()
|
413 |
|
|
|
418 |
split=data_args.train_split_name,
|
419 |
use_auth_token=data_args.use_auth_token,
|
420 |
)
|
421 |
+
raw_datasets["train"] = raw_datasets["train"].filter(preprocess_dataset)
|
422 |
+
raw_datasets["train"] = raw_datasets["train"].map(map_dataset)
|
423 |
|
424 |
if data_args.audio_column_name not in raw_datasets["train"].column_names:
|
425 |
raise ValueError(
|
|
|
445 |
split=data_args.eval_split_name,
|
446 |
use_auth_token=data_args.use_auth_token,
|
447 |
)
|
448 |
+
raw_datasets["eval"] = raw_datasets["eval"].filter(preprocess_dataset)
|
449 |
+
raw_datasets["eval"] = raw_datasets["eval"].map(map_dataset)
|
450 |
|
451 |
if data_args.max_eval_samples is not None:
|
452 |
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
|