Trainers

Abstract Trainers

class trlx.trainer.BaseRLTrainer(config: trlx.data.configs.TRLConfig, reward_fn=None, metric_fn=None, logit_mask=None, stop_sequences=None, train_mode=False)[source]
abstract learn()[source]

Use data in the rollout store to update the model

push_to_store(data)[source]

Append new data to the rollout store

class trlx.trainer.accelerate_base_trainer.AccelerateRLTrainer(config, **kwargs)[source]

Abstract trainer that uses the accelerate backend

add_eval_pipeline(eval_pipeline)[source]

Adds an evaluation pipeline with validation prompts

abstract create_train_dataloader()[source]

Returns a new dataloader for training.

decode(prompts: List[torch.LongTensor], samples: List[torch.LongTensor], prompt_sizes: Optional[torch.LongTensor] = None, append_eos_token: bool = False) → Tuple[List[str], List[str], List[str]][source]

Decodes tensor generations into lists of strings (samples: List[str], prompts: List[str], outputs: List[str])

evaluate()[source]

Samples model using eval_prompts, computes statistics with reward_fn and metric_fn

generate(input_ids, attention_mask=None, **kwargs)[source]

Generate samples for the experience buffer using the method-specific self.generate_experience_kwargs

generate_eval(input_ids, attention_mask=None, **kwargs)[source]

Generate samples for evaluation using self.generate_kwargs
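
For illustration, a hedged usage sketch of generate_eval together with decode (assumes `trainer` is a concrete trainer instance with a Hugging Face tokenizer attached; the prompt text is a placeholder):

    # Tokenize a prompt and sample from the model with eval-time settings
    inputs = trainer.tokenizer("Hello, my name is", return_tensors="pt")
    sample_ids = trainer.generate_eval(
        input_ids=inputs.input_ids, attention_mask=inputs.attention_mask
    )
    # decode() recovers parallel lists of full samples, prompts, and outputs
    samples, prompts, outputs = trainer.decode(
        prompts=list(inputs.input_ids), samples=list(sample_ids)
    )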

abstract get_arch(config: trlx.data.configs.TRLConfig)[source]

Returns a specific wrapper given a model’s architecture

learn()[source]

Samples batches from self.store, updates model and periodically evaluates it on self.eval_dataloader
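
learn() is normally driven for you by the high-level trlx.train() entry point; its stopping point and evaluation cadence come from the training config. A minimal sketch (the field names are assumed from TRLConfig's train section):

    from trlx.data.default_configs import default_ppo_config

    config = default_ppo_config()
    config.train.total_steps = 1000    # stop after this many updates (assumed field)
    config.train.eval_interval = 100   # run evaluate() every 100 updates (assumed field)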

load(directory: Optional[str] = None, **kwargs)[source]

Loads the checkpoint of the optimizer, scheduler and the model

abstract loss(batch) → Tuple[float, Dict][source]

Computes loss on a batch of data and returns statistics

abstract post_backward_callback()[source]

Hook invoked after each model update

abstract post_epoch_callback()[source]

Hook invoked after a single pass over data from self.store

abstract prepare_learning()[source]

Hook invoked before the start of training

save(directory: Optional[str] = None, **kwargs)[source]

Creates a checkpoint for the optimizer, scheduler and the model
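
A minimal checkpointing sketch using save() and load() (the path is a placeholder; assumes `trainer` is an instantiated trainer):

    # Write optimizer, scheduler, and model state, then restore it later
    trainer.save("checkpoints/step_1000")
    trainer.load("checkpoints/step_1000")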

save_pretrained(directory: Optional[str] = None, **kwargs)[source]

Save the underlying model, tokenizer, and configuration files to a directory

Args:
    directory (str, optional): The directory to save the trainer files to.
        NOTE: If not specified, the model will be saved to a directory named hf_model in the checkpoint directory specified by the Trainer’s config.
    **kwargs: Additional keyword arguments passed to the underlying Hugging Face model’s save_pretrained method.
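
For illustration, a hedged sketch of exporting and reloading the trained policy (assumes the underlying model is a causal LM; the directory name is a placeholder):

    trainer.save_pretrained("ppo_model")

    # The export is a standard Hugging Face checkpoint
    from transformers import AutoModelForCausalLM, AutoTokenizer
    model = AutoModelForCausalLM.from_pretrained("ppo_model")
    tokenizer = AutoTokenizer.from_pretrained("ppo_model")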

setup_model()[source]

Returns a model derived from an instance’s TRLConfig

setup_optimizer()[source]

Returns an optimizer derived from an instance’s TRLConfig

setup_scheduler()[source]

Returns a learning rate scheduler derived from an instance’s TRLConfig
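
All three setup_* methods read their settings from the instance’s TRLConfig. A hedged sketch of adjusting the optimizer and scheduler sections (default_ppo_config and the field names are assumptions drawn from trlx’s default configs):

    from trlx.data.default_configs import default_ppo_config

    config = default_ppo_config()
    config.optimizer.kwargs["lr"] = 1.0e-5       # consumed by setup_optimizer()
    config.scheduler.name = "cosine_annealing"   # consumed by setup_scheduler()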

Accelerate Trainers

class trlx.trainer.accelerate_ppo_trainer.AcceleratePPOTrainer(config: trlx.data.configs.TRLConfig, **kwargs)[source]

PPO Accelerate Trainer

add_prompt_pipeline(pipeline: trlx.pipeline.offline_pipeline.PromptPipeline)[source]

Add a prompt pipeline dataloader to a trainer instance for the make_experience stage
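
A hedged sketch of building and attaching a prompt pipeline (the prompt and length are placeholders, and the PromptPipeline argument names are assumed):

    from trlx.pipeline.offline_pipeline import PromptPipeline

    pipeline = PromptPipeline(
        ["Once upon a time"], max_prompt_length=64, tokenizer=trainer.tokenizer
    )
    trainer.add_prompt_pipeline(pipeline)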

create_train_dataloader()[source]

Returns a new dataloader for training.

get_arch(config: trlx.data.configs.TRLConfig)[source]

Returns a specific wrapper given a model’s architecture

loss(batch: trlx.data.ppo_types.PPORLBatch) → Tuple[float, Dict[str, Any]][source]

Computes loss on a batch of data and returns statistics

Args:
    batch (PPORLBatch): Previous batch of episodes

Returns:
    loss (float): Loss value
    stats (Dict[str, Any]): PPO statistics values

make_experience(num_rollouts: int = 1024, iter_count: int = 0)[source]

Takes chunk_size prompts at a time from prompt_iterator, samples from the model, and computes the KL divergence against a reference model. Finally, it appends the resulting PPORLElements to the trainer’s store.

Args:
    num_rollouts: Number of rollouts to generate
    iter_count: Total number of updates for all batches & epochs
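
make_experience() is driven by learn() during training, but can also be invoked directly to pre-fill the rollout store; a minimal sketch (the numbers are illustrative):

    # Generate 128 rollouts before the first update
    trainer.make_experience(num_rollouts=128, iter_count=0)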

post_backward_callback()[source]

Hook invoked after each model update

post_epoch_callback()[source]

Clears the rollout store and creates num_rollouts new samples

prepare_learning()[source]

Hook invoked before the start of training

save_pretrained(directory: Optional[str] = None, **kwargs)[source]

Save the underlying model, tokenizer, and configuration files to a directory

Args:
    directory (str, optional): The directory to save the trainer files to.
        NOTE: If not specified, the model will be saved to a directory named hf_model in the checkpoint directory specified by the Trainer’s config.
    **kwargs: Additional keyword arguments passed to the underlying Hugging Face model’s save_pretrained method.

setup_rollout_logging(config)[source]

Creates a directory for logging rollouts
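
Putting it together, a hedged end-to-end sketch of online PPO training through the high-level trlx.train() entry point (the reward function and prompts are toy placeholders):

    import trlx
    from trlx.data.default_configs import default_ppo_config

    def reward_fn(samples, prompts, outputs, **kwargs):
        # Toy reward: favor longer generations
        return [float(len(output)) for output in outputs]

    trainer = trlx.train(
        reward_fn=reward_fn,
        prompts=["The weather today is"] * 64,
        config=default_ppo_config(),
    )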

class trlx.trainer.accelerate_ilql_trainer.AccelerateILQLTrainer(config: trlx.data.configs.TRLConfig, **kwargs)[source]
create_train_dataloader()[source]

Returns a new dataloader for training.

get_arch(config)[source]

Returns a specific wrapper given a model’s architecture

loss(batch: Union[trlx.data.ilql_types.ILQLBatch, trlx.data.ilql_types.ILQLSeq2SeqBatch])[source]

Computes loss on a batch of data and returns statistics

make_experience(samples, rewards, max_length=2048)[source]

Tokenizes samples, shapes rewards into proper tensors, and inserts the resulting dataset into the trainer

make_experience_seq2seq(samples, rewards, max_length=2048)[source]

Tokenizes samples, shapes rewards into proper tensors, and inserts the resulting dataset into the trainer, for sequence-to-sequence models

post_backward_callback()[source]

Hook invoked after each model update

prepare_learning()[source]

Hook invoked before the start of training
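
A hedged sketch of offline ILQL training from logged (sample, reward) pairs via the high-level trlx.train() entry point (the data is illustrative):

    import trlx
    from trlx.data.default_configs import default_ilql_config

    trainer = trlx.train(
        samples=["2 + 2 = 4", "2 + 2 = 5"],
        rewards=[1.0, 0.0],
        config=default_ilql_config(),
    )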

class trlx.trainer.accelerate_sft_trainer.AccelerateSFTTrainer(config: trlx.data.configs.TRLConfig, **kwargs)[source]
create_train_dataloader()[source]

Returns a new dataloader for training.

get_arch(config)[source]

Returns a specific wrapper given a model’s architecture

loss(batch)[source]

Computes loss on a batch of data and returns statistics

prepare_learning()[source]

Hook invoked before the start of training
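
A hedged sketch of plain supervised fine-tuning via the high-level trlx.train() entry point (the samples are placeholders):

    import trlx
    from trlx.data.default_configs import default_sft_config

    trainer = trlx.train(
        samples=["Question: What is 2 + 2?\nAnswer: 4"] * 8,
        config=default_sft_config(),
    )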

NeMo Trainers