m3hrdadfi committed on
Commit
4744005
1 Parent(s): bb3b0e4

Update runner

Browse files
Files changed (1) hide show
  1. src/run_wav2vec2_pretrain_flax.py +14 -3
src/run_wav2vec2_pretrain_flax.py CHANGED
@@ -200,11 +200,23 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
200
  )
201
  mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
202
 
 
 
 
 
 
 
 
 
 
 
 
203
  # sample randomly masked indices
204
  batch["mask_time_indices"] = _compute_mask_indices(
205
- (batch["input_values"].shape[0], mask_indices_seq_length),
206
  self.model.config.mask_time_prob,
207
  self.model.config.mask_time_length,
 
208
  min_masks=2,
209
  )
210
 
@@ -216,7 +228,6 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
216
 
217
  return batch
218
 
219
-
220
  def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
221
  logging.basicConfig(
222
  format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -348,7 +359,7 @@ def main():
348
  do_normalize=True
349
  )
350
 
351
- target_sampling_rate = 16_000
352
  def prepare_dataset(batch):
353
  # check that all files have the correct sampling rate
354
  # batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate)
 
200
  )
201
  mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
202
 
203
+ batch_size = batch["input_values"].shape[0]
204
+
205
+ if batch["attention_mask"] is not None:
206
+ output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
207
+ attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
208
+
209
+ # these two operations make sure that all values
210
+ # before the output lengths indices are attended to
211
+ attention_mask[(np.arange(attention_mask.shape[0]), output_lengths - 1)] = 1
212
+ attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
213
+
214
  # sample randomly masked indices
215
  batch["mask_time_indices"] = _compute_mask_indices(
216
+ (batch_size, mask_indices_seq_length),
217
  self.model.config.mask_time_prob,
218
  self.model.config.mask_time_length,
219
+ attention_mask=attention_mask,
220
  min_masks=2,
221
  )
222
 
 
228
 
229
  return batch
230
 
 
231
  def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
232
  logging.basicConfig(
233
  format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
 
359
  do_normalize=True
360
  )
361
 
362
+ target_sampling_rate = feature_extractor.sampling_rate
363
  def prepare_dataset(batch):
364
  # check that all files have the correct sampling rate
365
  # batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate)