TRL documentation

Iterative Trainer


Iterative fine-tuning is a training method that lets you perform custom actions (generation and filtering, for example) between optimization steps. TRL provides an easy-to-use API for fine-tuning your models iteratively in just a few lines of code.

Usage

To get started quickly, instantiate a model and a tokenizer, then pass them to IterativeSFTTrainer.


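from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import IterativeSFTTrainer

# model_name is the checkpoint name or local path of a causal language model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the end-of-sequence token for padding when the tokenizer defines no pad token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Pass the tokenizer by keyword: the second positional parameter is args
trainer = IterativeSFTTrainer(
    model,
    tokenizer=tokenizer
)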

You can pass either a list of strings or a list of tensors to the step function.

Using a list of tensors as input:


inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask
}

trainer.step(**inputs)

Using a list of strings as input:


inputs = {
    "texts": texts
}

trainer.step(**inputs)

For causal language models, labels are created automatically from input_ids or from texts. When using sequence-to-sequence models, you must provide your own labels or texts_labels.
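
The defaulting rule described above can be sketched in plain Python. This is a minimal illustration and not TRL's implementation: `resolve_labels` and its `is_encoder_decoder` flag are hypothetical names, and for simplicity the sketch accepts only one kind of input (tensors or strings) per call.

```python
def resolve_labels(input_ids=None, labels=None, texts=None, texts_labels=None,
                   is_encoder_decoder=False):
    """Return the (inputs, labels) pair a step would train on (hypothetical helper)."""
    # Simplification made for this sketch: expect exactly one kind of input.
    if (input_ids is None) == (texts is None):
        raise ValueError("Provide exactly one of input_ids or texts.")
    if input_ids is not None:
        inputs, given = input_ids, labels
    else:
        inputs, given = texts, texts_labels
    if given is None:
        if is_encoder_decoder:
            # Sequence-to-sequence models need explicit targets.
            raise ValueError("Seq2seq models need labels or texts_labels.")
        # Causal language model: the labels are a copy of the inputs.
        given = list(inputs)
    return inputs, given
```

With a causal model, `resolve_labels(texts=["a b"])` returns the same strings as both inputs and labels, while the same call with `is_encoder_decoder=True` raises an error because no targets were supplied.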

IterativeTrainer

class trl.IterativeSFTTrainer

( model: Optional = None, args: Optional = None, tokenizer: Optional = None, optimizers: Tuple = (None, None), data_collator: Optional = None, eval_dataset: Union = None, max_length: Optional = None, truncation_mode: Optional = 'keep_end', preprocess_logits_for_metrics: Optional = None, compute_metrics: Optional = None, optimize_device_cache: Optional = False )

Parameters

  • model (PreTrainedModel) — Model to be optimized, either an ‘AutoModelForCausalLM’ or an ‘AutoModelForSeq2SeqLM’. Check the documentation of PreTrainedModel for more details.
  • args (transformers.TrainingArguments) — The arguments to use for training.
  • tokenizer (PreTrainedTokenizerBase) — Tokenizer to be used for encoding the data. Check the documentation of transformers.PreTrainedTokenizer and transformers.PreTrainedTokenizerFast for more details.
  • optimizers (Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]) — The optimizer and scheduler to use for training.
  • data_collator (Union[DataCollatorForLanguageModeling, DataCollatorForSeq2Seq], optional) — Data collator to be used for training and passed along the dataloader.
  • eval_dataset (datasets.Dataset) — The dataset to use for evaluation.
  • max_length (int, defaults to None) — The maximum length of the input.
  • truncation_mode (str, defaults to keep_end) — The truncation mode to use, either keep_end or keep_start.
  • preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) — The function to use to preprocess the logits before computing the metrics.
  • compute_metrics (Callable[[EvalPrediction], Dict], optional) — The function to use to compute the metrics. Must take an EvalPrediction and return a dictionary mapping strings to metric values.
  • optimize_device_cache (bool, optional, defaults to False) — Optimize CUDA cache for slightly more memory-efficient training.
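
To make the interaction between max_length and truncation_mode concrete, here is a minimal sketch on a plain list of token ids. It assumes, based only on the option names, that keep_end retains the last max_length tokens and keep_start retains the first; `truncate` is a hypothetical helper, not a TRL function.

```python
def truncate(token_ids, max_length=None, truncation_mode="keep_end"):
    """Shorten a token sequence to at most max_length items (hypothetical helper)."""
    # No limit configured: return the sequence unchanged.
    if max_length is None:
        return list(token_ids)
    if max_length <= 0:
        raise ValueError("max_length must be a positive integer.")
    if truncation_mode == "keep_start":
        # Keep the first max_length tokens.
        return list(token_ids[:max_length])
    if truncation_mode == "keep_end":
        # Keep the last max_length tokens.
        return list(token_ids[-max_length:])
    raise ValueError(f"Unknown truncation_mode: {truncation_mode!r}")


ids = [101, 7, 8, 9, 102]
print(truncate(ids, max_length=3, truncation_mode="keep_start"))  # [101, 7, 8]
print(truncate(ids, max_length=3, truncation_mode="keep_end"))    # [8, 9, 102]
```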

The IterativeSFTTrainer can be used to fine-tune models with methods that require custom steps between optimization steps.
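
A minimal sketch of one such round, assuming a generate-then-filter workflow. The names `run_round`, `generate`, and `keep` are hypothetical and supplied by the caller; the only trainer API used is the step(texts=...) call documented on this page.

```python
def run_round(trainer, prompts, generate, keep):
    """One iteration: generate candidates, filter them, then take one optimization step."""
    # Custom action 1: produce one candidate text per prompt.
    candidates = [generate(prompt) for prompt in prompts]
    # Custom action 2: keep only the candidates that pass the filter.
    selected = [text for text in candidates if keep(text)]
    if not selected:
        # Nothing passed the filter, so skip the optimization step for this round.
        return None
    # Optimization step on the surviving texts; returns the training statistics.
    return trainer.step(texts=selected)
```

Calling `run_round` repeatedly, with `generate` sampling from the current model and `keep` applying a quality check, gives the generate, filter, and optimize cycle that iterative fine-tuning is meant to support.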

step

( input_ids: Optional = None, attention_mask: Optional = None, labels: Optional = None, texts: Optional = None, texts_labels: Optional = None ) → dict[str, Any]

Parameters

  • input_ids (List[torch.LongTensor]) — List of tensors containing the input_ids (if not provided, texts will be used)
  • attention_mask (List[torch.LongTensor], optional) — List of tensors containing the attention_mask
  • labels (List[torch.FloatTensor], optional) — List of tensors containing the labels (if set to None, will default to input_ids)
  • texts (List[str], optional) — List of strings containing the text input (if not provided, input_ids will directly be used)
  • texts_labels (List[str], optional) — List of strings containing the text labels (if set to None, will default to texts)

Returns

dict[str, Any]

A summary of the training statistics

Run an optimization step given either lists of input_ids, attention_mask, and labels, or lists of texts and texts_labels.
