RL Models
RL models are what you train with trlX. Currently, we support PPO and ILQL.
Note that new models must be registered with trlx.model.register_model.
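As a rough sketch, a new model would be registered like the built-in ones. The subclass name and its body below are hypothetical; only the decorator and the base class come from this page.

```python
# Hypothetical sketch of registering a custom model with trlx.model.register_model.
from trlx.model import register_model
from trlx.model.accelerate_base_model import AccelerateRLModel


@register_model
class AccelerateMyRLModel(AccelerateRLModel):
    def get_arch(self, config):
        # Return the decoder wrapper for this method (see get_arch below); body omitted
        ...
```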
General
- class trlx.model.BaseRLModel(config: trlx.data.configs.TRLConfig, train_mode=False)[source]
- abstract act(data: trlx.data.RLElement) → trlx.data.RLElement[source]
Given an RLElement with a state, produce an action and add it to the RLElement. The orchestrator should call this, compute the reward, and push the resulting RLElement to the RolloutStore
- abstract get_components() → Dict[str, any][source]
Get PyTorch components (mainly for saving/loading)
- intervals(steps: int) → Dict[str, bool][source]
Using the config and the current step number, returns a dict indicating whether each periodic action (e.g. logging, saving, evaluation) should happen at this step
- abstract learn(log_fn: Optional[Callable] = None, save_fn: Optional[Callable] = None, eval_fn: Optional[Callable] = None)[source]
Use the experiences in the RolloutStore to learn
- Parameters
log_fn (Callable[Dict[str, any]]) – Optional function called when logging; it is passed a dict of logging-relevant values.
save_fn (Callable[Dict[str, any]]) – Optional function called after saving; it is passed the components.
eval_fn (Callable[BaseRLModel]) – Optional function called during evaluation. Evaluation does nothing without it.
- load(fp: str, title: str = 'OUT')[source]
Try to load all components from the specified path, under a folder with the given title
- abstract sample(prompts: Iterable[str], length: int, n_samples: int) → Iterable[str][source]
Sample from the language model. Takes prompts and a maximum number of new tokens to generate (see the usage sketch below).
- Parameters
prompts – List of prompts to tokenize and use as context
length (int) – How many new tokens to generate for each prompt
n_samples – Number of samples to generate; by default, the number of prompts is used
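A minimal usage sketch for sample(), assuming `model` is an already-constructed concrete subclass; the prompts and counts are illustrative.

```python
# `model` is assumed to be a concrete BaseRLModel subclass built elsewhere.
samples = model.sample(
    prompts=["The capital of France is", "Once upon a time"],
    length=32,     # generate up to 32 new tokens per prompt
    n_samples=2,   # by default this matches the number of prompts
)
for text in samples:
    print(text)
```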
- class trlx.model.accelerate_base_model.AccelerateRLModel(config, train_mode=True)[source]
RL model that uses Hugging Face Accelerate for training
- evaluate()[source]
Samples the model on eval_prompts and logs statistics computed with reward_fn or metric_fn, if provided
- generate(input_ids, attention_mask=None, **kwargs)[source]
Wraps Hugging Face's generate, adding the method-specific generation defaults
- abstract get_arch(config: trlx.data.configs.TRLConfig)[source]
Returns a specific wrapper of the decoder architecture
- learn()[source]
Samples batches from self.store, updates the model, and periodically evaluates it on self.eval_dataloader (see the sketch below)
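Putting the pieces together, a rough training sketch using the PPO subclass documented below might look as follows. The config path and the orchestration step are assumptions, not part of this API.

```python
from trlx.data.configs import TRLConfig
from trlx.model.accelerate_ppo_model import AcceleratePPOModel

# Load a training config (the path is an assumption) and build the model
config = TRLConfig.load_yaml("configs/ppo_config.yml")
model = AcceleratePPOModel(config)

# ... an orchestrator generates rollouts from prompts, scores them with a
# reward function, and pushes them into model.store ...

model.learn()  # samples batches from model.store and periodically evaluates
```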
PPO
- class trlx.model.accelerate_ppo_model.AcceleratePPOModel(config)[source]
- get_arch(config: trlx.data.configs.TRLConfig)[source]
Returns a specific wrapper of the decoder architecture
- class trlx.model.nn.ppo_models.GPTHeadWithValueModel(config: Union[transformers.configuration_utils.PretrainedConfig, str])[source]
The GPTHeadWithValueModel class implements a GPT-type language model with a secondary, scalar value head.
- forward(input_ids=None, past_key_values=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, mc_token_ids=None, lm_labels=None, mc_labels=None, return_dict=False, output_attentions=False, output_hidden_states=False)[source]
Defines the computation performed at every call. Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
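A hypothetical instantiation sketch. The exact structure of the forward output is an assumption; the point is that one pass yields both language-model logits and per-token scalar value estimates.

```python
import torch
from transformers import AutoTokenizer
from trlx.model.nn.ppo_models import GPTHeadWithValueModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPTHeadWithValueModel("gpt2")  # accepts a model name or a PretrainedConfig

batch = tokenizer(["trlX adds a value head"], return_tensors="pt")
with torch.no_grad():
    outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
# `outputs` carries both the LM logits (used for sampling / PPO ratios) and
# the value head's per-token scalar estimates (used for advantages).
```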
- class trlx.model.nn.ppo_models.ModelBranch(config, transformer_blocks, ln_f, lm_head)[source]
ModelBranch implements the frozen upper trunk of the reference model used when computing the PPO KL-divergence penalty. Expects a list of frozen transformer blocks and an lm_head from the base model.
- forward(hidden_states: torch.Tensor, output_shape: torch.Tensor, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = False) → Union[Tuple, trlx.model.nn.ppo_models.CausalLMOutputWithCrossAttentions][source]
Defines the computation performed at every call. Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
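To illustrate what the frozen branch is for, here is a minimal sketch (not trlX's exact code; all names are illustrative) of a per-token KL penalty computed between the trained policy's logits and the frozen reference logits.

```python
import torch
import torch.nn.functional as F

def kl_penalty(policy_logits, ref_logits, response_ids, kl_coef=0.1):
    # Log-probabilities under the trained policy and the frozen reference
    policy_logprobs = F.log_softmax(policy_logits, dim=-1)
    ref_logprobs = F.log_softmax(ref_logits, dim=-1)
    # Pick out the log-probabilities of the tokens actually generated
    lp = policy_logprobs.gather(-1, response_ids.unsqueeze(-1)).squeeze(-1)
    ref_lp = ref_logprobs.gather(-1, response_ids.unsqueeze(-1)).squeeze(-1)
    # Penalize drifting away from the reference model, per token
    return -kl_coef * (lp - ref_lp)
```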
- class trlx.model.nn.ppo_models.GPTHydraHeadWithValueModel(config: Union[transformers.configuration_utils.PretrainedConfig, str], num_layers_unfrozen: int = - 1)[source]
The GPTHydraHeadWithValueModel class implements a GPT-type language model with a secondary, scalar value head, together with a frozen ModelBranch serving as the reference model for the KL penalty.
- forward(input_ids=None, attention_mask=None, past_key_values=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, mc_token_ids=None, lm_labels=None, mc_labels=None, return_dict=False, output_attentions=False, output_hidden_states=False)[source]
Defines the computation performed at every call. Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
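A hypothetical instantiation. The interpretation of num_layers_unfrozen follows the ModelBranch description above; that the default of -1 leaves every layer trainable is an assumption.

```python
from trlx.model.nn.ppo_models import GPTHydraHeadWithValueModel

# Only the top two transformer blocks are trained; a frozen copy of those
# blocks (the ModelBranch above) provides the reference logits for the KL
# penalty, while the lower layers are shared between policy and reference.
model = GPTHydraHeadWithValueModel("gpt2", num_layers_unfrozen=2)
```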
ILQL
- class trlx.model.accelerate_ilql_model.AccelerateILQLModel(config, logit_mask=None, metric_fn=None, train_mode=True)[source]
- class trlx.model.nn.ilql_models.CausalLMWithValueHeads(config: Union[transformers.configuration_utils.PretrainedConfig, str], params, num_layers_unfrozen=- 1)[source]
This is a wrapper around Hugging Face's AutoModelForCausalLM with two additional scalar value heads.
- forward(input_ids, attention_mask=None, position_ids=None, past_key_values=None, actions_ixs=None, states_ixs=None)[source]
Defines the computation performed at every call. Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- generate(input_ids, attention_mask=None, position_ids=None, past_key_values=None, beta=1, max_length=32, temperature=1, top_k=20, logit_mask=None, logs=True, pad_token_id=50256, eos_token_id=50256)[source]
Generates samples akin to Hugging Face's .generate, but with custom log-probability preprocessing: token probabilities are adjusted according to how advantageous the value functions estimate them to be.
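A conceptual sketch (not trlX's exact code; function and variable names are illustrative) of that preprocessing, with beta scaling how strongly the value heads' advantage estimates shift the sampling distribution.

```python
def ilql_adjusted_logits(lm_logits, q_values, v_values, beta=1.0):
    # Advantage: how much better each candidate token looks than the state value
    advantage = q_values - v_values
    # beta=0 recovers the plain language model; larger beta leans harder on the
    # value functions' estimates
    return lm_logits + beta * advantage

# Sampling (temperature, top-k, etc.) then proceeds on the adjusted logits.
```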