Trainers
Abstract Trainers
- class trlx.trainer.BaseRLTrainer(config: trlx.data.configs.TRLConfig, reward_fn=None, metric_fn=None, logit_mask=None, stop_sequences=None, train_mode=False)[source]
- class trlx.trainer.accelerate_base_trainer.AccelerateRLTrainer(config, **kwargs)[source]
Abstract trainer that uses the accelerate backend
- decode(prompts: List[torch.LongTensor], samples: List[torch.LongTensor], prompt_sizes: Optional[torch.LongTensor] = None, append_eos_token: bool = False) → Tuple[List[str], List[str], List[str]][source]
Decodes tensor generations into lists of strings (samples: List[str], prompts: List[str], outputs: List[str])
- evaluate()[source]
Samples the model using eval_prompts and computes statistics with reward_fn and metric_fn
- generate(input_ids, attention_mask=None, **kwargs)[source]
Generate samples for the experience buffer using the method-specific self.generate_experience_kwargs
- generate_eval(input_ids, attention_mask=None, **kwargs)[source]
Generate samples for evaluation using self.generate_kwargs
- abstract get_arch(config: trlx.data.configs.TRLConfig)[source]
Returns a specific wrapper given a model’s architecture
- learn()[source]
Samples batches from self.store, updates the model, and periodically evaluates it on self.eval_dataloader
- load(directory: Optional[str] = None, **kwargs)[source]
Loads the checkpoint of the optimizer, scheduler and the model
- abstract loss(batch) → Tuple[float, Dict][source]
Computes loss on a batch of data and returns statistics
- save(directory: Optional[str] = None, **kwargs)[source]
Creates a checkpoint for the optimizer, scheduler and the model
- save_pretrained(directory: Optional[str] = None, **kwargs)[source]
Save the underlying model, tokenizer, and configuration files to a directory
- Args:
- directory (str, optional): The directory to save the trainer files to.
NOTE: If not specified, the model will be saved to a directory named hf_model in the checkpoint directory as specified by the Trainer’s config.
- **kwargs: Additional keyword arguments passed to the underlying Hugging Face model’s save_pretrained method.
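The division of labour in the abstract trainer described above (a concrete subclass supplies loss, while learn iterates over the store and periodically evaluates) can be illustrated with a small stand-alone sketch. This is not trlx code: ToyTrainer, MeanSquareTrainer, the history list, and the eval_interval argument are hypothetical names chosen for illustration, and the real trainer also handles optimizers, schedulers, and checkpoints.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple


class ToyTrainer(ABC):
    """Hypothetical stand-in for an abstract trainer (not the trlx class)."""

    def __init__(self, store: List[List[float]], eval_interval: int = 2):
        self.store = store                  # batches to train on
        self.eval_interval = eval_interval  # assumed: evaluate every N steps
        self.history: List[str] = []        # log of what happened, for inspection

    @abstractmethod
    def loss(self, batch: List[float]) -> Tuple[float, Dict[str, float]]:
        """Compute the loss on one batch and return (loss, statistics)."""

    def evaluate(self) -> None:
        # A real trainer would sample the model and score the samples here.
        self.history.append("eval")

    def learn(self) -> None:
        # Sample batches from the store, compute the loss, evaluate periodically.
        for step, batch in enumerate(self.store, start=1):
            loss, _stats = self.loss(batch)
            self.history.append(f"step {step}: loss={loss:.2f}")
            if step % self.eval_interval == 0:
                self.evaluate()


class MeanSquareTrainer(ToyTrainer):
    """Concrete subclass: only the method-specific loss is implemented."""

    def loss(self, batch: List[float]) -> Tuple[float, Dict[str, float]]:
        value = sum(x * x for x in batch) / len(batch)
        return value, {"loss": value}


trainer = MeanSquareTrainer(store=[[1.0, 2.0], [0.0, 2.0], [3.0]])
trainer.learn()
print(trainer.history)
```

Instantiating ToyTrainer directly raises TypeError because loss is abstract, which mirrors why the method-specific trainers in the next section must provide their own loss.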
Accelerate Trainers
- class trlx.trainer.accelerate_ppo_trainer.AcceleratePPOTrainer(config: trlx.data.configs.TRLConfig, **kwargs)[source]
PPO Accelerate Trainer
- add_prompt_pipeline(pipeline: trlx.pipeline.offline_pipeline.PromptPipeline)[source]
Add a prompt pipeline dataloader to a trainer instance for the make_experience stage
- get_arch(config: trlx.data.configs.TRLConfig)[source]
Returns a specific wrapper given a model’s architecture
- loss(batch: trlx.data.ppo_types.PPORLBatch) → Tuple[float, Dict[str, Any]][source]
Computes loss on a batch of data and returns statistics
- Args:
batch (PPORLBatch): Previous batch of episodes
- Returns:
loss (float): Loss value
stats (Dict[str, Any]): PPO statistics values
- make_experience(num_rollouts: int = 1024, iter_count: int = 0)[source]
Takes chunk_size prompts from prompt_iterator, samples from the model, and computes the KL divergence against a reference model. It then appends PPOElements to the trainer’s store.
- Args:
num_rollouts: Number of rollouts to generate
iter_count: Total number of updates for all batches & epochs
- save_pretrained(directory: Optional[str] = None, **kwargs)[source]
- Args:
- directory (str, optional): The directory to save the trainer files to.
NOTE: If not specified, the model will be saved to a directory named hf_model in the checkpoint directory as specified by the Trainer’s config.
- **kwargs: Additional keyword arguments passed to the underlying Hugging Face model’s save_pretrained method.
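The rollout step that make_experience describes for the PPO trainer can be sketched in plain Python. This is a simplified illustration under stated assumptions, not the trlx implementation: ToyPPOElement, kl_penalized_rewards, toy_make_experience, and kl_coef are hypothetical names, the policy and reference models are stand-in callables that return per-token log-probabilities, and the per-token log-probability difference is used as a simple KL estimate with the scalar score added on the final token.

```python
from typing import Callable, List, NamedTuple


class ToyPPOElement(NamedTuple):
    """Hypothetical rollout record; the real PPOElement holds tensors."""
    prompt: str
    logprobs: List[float]
    rewards: List[float]


def kl_penalized_rewards(logprobs: List[float], ref_logprobs: List[float],
                         score: float, kl_coef: float) -> List[float]:
    # Penalize each token by how far the policy drifted from the reference,
    # then add the scalar score to the last token (assumed shaping).
    rewards = [-kl_coef * (lp - ref) for lp, ref in zip(logprobs, ref_logprobs)]
    rewards[-1] += score
    return rewards


def toy_make_experience(prompts: List[str],
                        policy: Callable[[str], List[float]],
                        reference: Callable[[str], List[float]],
                        reward_fn: Callable[[str], float],
                        num_rollouts: int, chunk_size: int,
                        kl_coef: float = 0.5) -> List[ToyPPOElement]:
    store: List[ToyPPOElement] = []
    start = 0
    while len(store) < num_rollouts:
        # Take chunk_size prompts, cycling through the prompt list.
        chunk = [prompts[(start + i) % len(prompts)] for i in range(chunk_size)]
        start += chunk_size
        for prompt in chunk:
            logprobs = policy(prompt)
            ref_logprobs = reference(prompt)
            rewards = kl_penalized_rewards(
                logprobs, ref_logprobs, reward_fn(prompt), kl_coef)
            store.append(ToyPPOElement(prompt, logprobs, rewards))
    return store[:num_rollouts]


store = toy_make_experience(
    prompts=["a", "b", "c"],
    policy=lambda p: [-1.0, -0.5],     # stand-in policy log-probabilities
    reference=lambda p: [-1.5, -1.0],  # stand-in reference log-probabilities
    reward_fn=lambda p: 1.0,           # stand-in scalar reward
    num_rollouts=4,
    chunk_size=2,
)
print([e.prompt for e in store], store[0].rewards)
```

With these stand-in values each token receives a penalty of -0.25 and the final token additionally receives the score of 1.0, so every element carries the rewards [-0.25, 0.75].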
- class trlx.trainer.accelerate_ilql_trainer.AccelerateILQLTrainer(config: trlx.data.configs.TRLConfig, **kwargs)[source]
- loss(batch: Union[trlx.data.ilql_types.ILQLBatch, trlx.data.ilql_types.ILQLSeq2SeqBatch])[source]
Computes loss on a batch of data and returns statistics
- make_experience(samples, rewards, max_length=2048)[source]
Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the trainer
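As a rough illustration of what "tokenizes samples and shapes rewards" can mean for an offline method, the sketch below turns (sample, reward) pairs into per-token reward sequences. It rests on assumptions that the source does not confirm: whitespace splitting stands in for a real tokenizer, plain lists stand in for tensors, the scalar reward is placed on the final token with zeros elsewhere, and toy_make_experience is a hypothetical name rather than the trainer method itself.

```python
from typing import Dict, List, Union

Record = Dict[str, Union[List[str], List[float]]]


def toy_make_experience(samples: List[str], rewards: List[float],
                        max_length: int = 2048) -> List[Record]:
    dataset: List[Record] = []
    for text, reward in zip(samples, rewards):
        # Placeholder tokenizer: split on whitespace, truncate to max_length.
        tokens = text.split()[:max_length]
        # Assumed shaping: zero reward everywhere except the final token.
        shaped = [0.0] * len(tokens)
        if shaped:
            shaped[-1] = float(reward)
        dataset.append({"tokens": tokens, "rewards": shaped})
    return dataset


dataset = toy_make_experience(
    ["good movie overall", "bad"], [1.0, -1.0], max_length=2)
print(dataset)
```

Note that truncation happens before the reward is placed, so a truncated sample still ends with its reward on the last kept token.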
- class trlx.trainer.accelerate_sft_trainer.AccelerateSFTTrainer(config: trlx.data.configs.TRLConfig, **kwargs)[source]