RL Models

RL models are what you train with trlX. Currently, we support PPO and ILQL. Note that new models must be registered with trlx.model.register_model; a sketch follows.
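
A minimal registration sketch; the subclass name and its body are hypothetical, and only BaseRLModel and the register_model decorator are taken from this page:

    from trlx.model import BaseRLModel, register_model

    @register_model
    class MyRLModel(BaseRLModel):
        """Hypothetical model; registering it lets trlX look it up by name."""
        ...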

General

class trlx.model.BaseRLModel(config: trlx.data.configs.TRLConfig, train_mode=False)[source]
abstract act(data: trlx.data.RLElement) → trlx.data.RLElement[source]

Given an RLElement with a state, produce an action and add it to the RLElement. The orchestrator should call this, compute the reward, and push the resulting RLElement to the RolloutStore.

abstract get_components() → Dict[str, any][source]

Get pytorch components (mainly for saving/loading)

intervals(steps: int) → Dict[str, bool][source]

Using the config and the current step number, returns a dict of flags indicating which periodic actions should be performed at this step.
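
A usage sketch, assuming an already constructed model; the key name below is an assumption:

    flags = model.intervals(steps=100)   # dict of booleans gating periodic actions
    if flags.get("do_log"):              # "do_log" is a hypothetical key name
        ...                              # perform the corresponding periodic action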

abstract learn(log_fn: Optional[Callable] = None, save_fn: Optional[Callable] = None, eval_fn: Optional[Callable] = None)[source]

Use the experiences in the RolloutStore to update the model. A callback sketch follows the parameter list.

Parameters
  • log_fn (Callable[Dict[str, any]]) – Optional function called when logging; it is passed a dict of logging-relevant values

  • save_fn (Callable[Dict[str, any]]) – Optional function called after saving; it is passed the model's components

  • eval_fn (Callable[BaseRLModel]) – Optional function called during evaluation; evaluation is a no-op unless this is provided
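
A callback sketch, assuming an already constructed model; the callback bodies are placeholders:

    def log_fn(stats):
        # stats: dict of logging-relevant values
        print(stats)

    def save_fn(components):
        # components: the dict returned by get_components()
        ...

    model.learn(log_fn=log_fn, save_fn=save_fn)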

load(fp: str, title: str = 'OUT')[source]

Attempt to load all components from the specified path, under a folder with the given title.

abstract sample(prompts: Iterable[str], length: int, n_samples: int) → Iterable[str][source]

Sample from the language model. Takes prompts and the maximum number of new tokens to generate; a usage sketch follows the parameter list.

Parameters
  • prompts – List of prompts to tokenize and use as context

  • length (int) – How many new tokens to generate for each prompt

  • n_samples – Number of samples to generate; defaults to the number of prompts
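
A usage sketch, assuming a concrete model instance:

    samples = model.sample(
        prompts=["I thought the movie was"],
        length=32,     # up to 32 new tokens per prompt
        n_samples=1,   # typically the number of prompts
    )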

save(fp: str, title: str = 'OUT')[source]

Attempt to save all components to the specified path, under a folder with the given title.
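
A save/load round trip sketch; the path and title are arbitrary examples:

    model.save("checkpoints", title="run1")   # components grouped in a folder named by title
    model.load("checkpoints", title="run1")   # restores the same components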

class trlx.model.accelerate_base_model.AccelerateRLModel(config, train_mode=True)[source]

RL model that uses Hugging Face Accelerate for training

add_eval_pipeline(eval_pipeline)[source]

Adds a pipeline with validation prompts for use during evaluation

evaluate()[source]

Samples the model on eval_prompts and logs statistics computed with reward_fn or metric_fn, if provided

generate(input_ids, attention_mask=None, **kwargs)[source]

Wraps Hugging Face's generate, adding method-specific defaults
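
A usage sketch, assuming the model exposes its tokenizer as model.tokenizer (the attribute name is an assumption):

    batch = model.tokenizer("I thought the movie was", return_tensors="pt")
    outputs = model.generate(batch.input_ids, attention_mask=batch.attention_mask)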

abstract get_arch(config: trlx.data.configs.TRLConfig)[source]

Returns a specific wrapper of the decoder architecture

get_components() → Dict[str, any][source]

Get pytorch components (mainly for saving/loading)

learn()[source]

Samples batches from self.store, updates the model, and periodically evaluates it on self.eval_dataloader

abstract loss(batch) → Tuple[float, Dict][source]

Compute loss on a batch from store and return some statistics

abstract post_backward_callback()[source]

Hook called after each model update

abstract post_epoch_callback()[source]

Hook called after a single pass over (exhausting) self.store

save(directory=None)[source]

Creates a checkpoint of the optimizer, scheduler, and model

tokenize(text: Iterable[str])[source]

Tokenize a batch of text after prepending the BOS token to each sample

PPO

class trlx.model.accelerate_ppo_model.AcceleratePPOModel(config)[source]
get_arch(config: trlx.data.configs.TRLConfig)[source]

Returns a specific wrapper of the decoder architecture

loss(batch)[source]

Compute loss on a batch from store and return some statistics

post_backward_callback()[source]

Hook called after each model update

post_epoch_callback()[source]

Hook called after a single pass over (exhausting) self.store
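
A construction sketch, assuming the configuration is loaded from a YAML file via TRLConfig.load_yaml; the path is an example:

    from trlx.data.configs import TRLConfig
    from trlx.model.accelerate_ppo_model import AcceleratePPOModel

    config = TRLConfig.load_yaml("configs/ppo_config.yml")   # example path
    model = AcceleratePPOModel(config)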

ILQL

class trlx.model.accelerate_ilql_model.AccelerateILQLModel(config, logit_mask=None, metric_fn=None, train_mode=True)[source]
get_arch(config)[source]

Returns a specific wrapper of the decoder architecture

loss(batch)[source]

Compute loss on a batch from store and return some statistics

post_backward_callback()[source]

Hook called after each model update

tokenize(texts: Union[Iterable[str], Iterable[torch.LongTensor]])[source]

Tokenize a batch of text after prepending the BOS token to each sample
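
A construction sketch mirroring the PPO one above; the metric_fn signature shown (a list of sampled strings mapped to a dict of metric lists) is an assumption:

    from trlx.data.configs import TRLConfig
    from trlx.model.accelerate_ilql_model import AccelerateILQLModel

    config = TRLConfig.load_yaml("configs/ilql_config.yml")   # example path
    model = AccelerateILQLModel(
        config,
        metric_fn=lambda samples: {"length": [len(s) for s in samples]},
    )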

class trlx.model.nn.ilql_models.CausalLMWithValueHeads(config: Union[transformers.configuration_utils.PretrainedConfig, str], params, num_layers_unfrozen=-1)[source]

This is a wrapper around Hugging Face's AutoModelForCausalLM with two additional scalar heads

forward(input_ids, attention_mask=None, position_ids=None, past_key_values=None, actions_ixs=None, states_ixs=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

generate(input_ids, attention_mask=None, position_ids=None, past_key_values=None, beta=1, max_length=32, temperature=1, top_k=20, logit_mask=None, logs=True, pad_token_id=50256, eos_token_id=50256)[source]

Generates samples akin to Hugging Face's .generate, but with custom logp preprocessing: token probabilities are reweighted by how advantageous the value functions estimate them to be.
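
A sampling sketch, assuming model is an already constructed CausalLMWithValueHeads and tokenizer matches its base architecture:

    input_ids = tokenizer("I thought the movie was", return_tensors="pt").input_ids
    samples = model.generate(
        input_ids,
        beta=1,           # scales how strongly the value heads' advantage estimates reshape the logits
        max_length=32,
        temperature=1,
        top_k=20,
    )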