Configs

Training a model with trlX requires you to set several configs: ModelConfig, which contains general information about the model being trained; TrainConfig, which contains training hyperparameters such as the learning rate and batch size; and MethodConfig, which contains hyperparameters and settings for the specific RL method being used (e.g. ILQL or PPO).
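These configs are typically supplied together in one YAML file (loaded via TRLConfig.load_yaml), with one top-level section per config. Below is a sketch of the model and train sections using the field names documented on this page; all values are illustrative, and the method section depends on the RL method chosen (see the PPO and ILQL sections for those fields):

```yaml
model:
  model_path: "gpt2"                 # local path or Hugging Face hub name
  tokenizer_path: "gpt2"
  model_type: "AcceleratePPOModel"   # one of the registered RL models in trlx.model
  num_layers_unfrozen: -1

train:
  total_steps: 10000
  seq_length: 48
  epochs: 100
  batch_size: 128
  lr_ramp_steps: 100
  lr_decay_steps: 79000
  weight_decay: 1.0e-6
  learning_rate_init: 1.412e-4
  learning_rate_target: 1.412e-4
  opt_betas: [0.9, 0.95]
  checkpoint_interval: 10000
  eval_interval: 100
  pipeline: "PPOPipeline"            # registered in trlx.pipeline
  orchestrator: "PPOOrchestrator"    # registered in trlx.orchestrator
```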

General

class trlx.data.configs.TRLConfig(model: trlx.data.configs.ModelConfig, train: trlx.data.configs.TrainConfig, method: trlx.data.method_configs.MethodConfig)[source]

Top level config for trlX. Loads configs and can be converted to dictionary.

classmethod load_yaml(yml_fp: str)[source]

Load yaml file as TRLConfig.

Parameters

yml_fp (str) – Path to yaml file

to_dict()[source]

Convert TRLConfig to dictionary.
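To illustrate how TRLConfig nests the three sub-configs and what to_dict produces, here is a minimal stdlib-only stand-in; the field names mirror the signatures on this page (TrainConfig is trimmed for brevity), the values are illustrative, and in practice you should use the real classes from trlx.data.configs:

```python
from dataclasses import dataclass, asdict

# Minimal stand-ins mirroring the documented signatures
# (illustrative only; the real classes live in trlx.data.configs).
@dataclass
class ModelConfig:
    model_path: str
    tokenizer_path: str
    model_type: str
    num_layers_unfrozen: int = -1

@dataclass
class TrainConfig:
    # Trimmed to two of the documented fields for brevity
    total_steps: int
    batch_size: int

@dataclass
class MethodConfig:
    name: str

@dataclass
class TRLConfig:
    model: ModelConfig
    train: TrainConfig
    method: MethodConfig

    def to_dict(self) -> dict:
        # asdict recurses into the nested dataclass fields,
        # producing a plain nested dictionary
        return asdict(self)

config = TRLConfig(
    model=ModelConfig(model_path="gpt2", tokenizer_path="gpt2",
                      model_type="AcceleratePPOModel"),
    train=TrainConfig(total_steps=10000, batch_size=128),
    method=MethodConfig(name="ppoconfig"),
)
print(config.to_dict()["model"])
```

load_yaml performs the inverse direction: it reads a YAML file with model, train, and method sections and builds the same nested structure.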

class trlx.data.configs.ModelConfig(model_path: str, tokenizer_path: str, model_type: str, num_layers_unfrozen: int = -1)[source]

Config for a model.

Parameters
  • model_path (str) – Path to the model (local or on huggingface hub)

  • tokenizer_path (str) – Path to the tokenizer (local or on huggingface hub)

  • model_type (str) – One of the registered RL models present in trlx.model

  • num_layers_unfrozen (int) – Number of layers to unfreeze for fine-tuning (-1 unfreezes all layers)

class trlx.data.configs.TrainConfig(total_steps: int, seq_length: int, epochs: int, batch_size: int, lr_ramp_steps: int, lr_decay_steps: int, weight_decay: float, learning_rate_init: float, learning_rate_target: float, opt_betas: Tuple[float], checkpoint_interval: int, eval_interval: int, pipeline: str, orchestrator: str, checkpoint_dir: str = 'ckpts', project_name: str = 'trlx', seed: int = 1000)[source]

Config for train job on model.

Parameters
  • total_steps (int) – Total number of training steps

  • seq_length (int) – Number of tokens to use as context (max length for tokenizer)

  • epochs (int) – Total number of passes through data

  • batch_size (int) – Batch size for training

  • lr_ramp_steps (int) – Number of steps before learning rate reaches learning_rate_init

  • lr_decay_steps (int) – Number of steps after ramp-up before the learning rate decays to learning_rate_target

  • weight_decay (float) – Weight decay for optimizer

  • learning_rate_init (float) – Initial learning rate after ramp up

  • learning_rate_target (float) – Target learning rate after decay

  • opt_betas (Tuple[float]) – Beta parameters for the optimizer

  • checkpoint_interval (int) – Save model every checkpoint_interval steps

  • eval_interval (int) – Evaluate model every eval_interval steps

  • pipeline (str) – Pipeline to use for training. One of the registered pipelines present in trlx.pipeline

  • orchestrator (str) – Orchestrator to use for training. One of the registered orchestrators present in trlx.orchestrator

  • project_name (str) – Project name for wandb

  • checkpoint_dir (str) – Directory in which to save checkpoints

  • seed (int) – Random seed

class trlx.data.method_configs.MethodConfig(name: str)[source]

Config for a certain RL method.

Parameters

name (str) – Name of the method

PPO

class trlx.data.method_configs.PPOConfig(name: str, ppo_epochs: int, num_rollouts: int, chunk_size: int, init_kl_coef: float, target: float, horizon: int, gamma: float, lam: float, cliprange: float, cliprange_value: float, vf_coef: float, gen_kwargs: dict)[source]

Parameters
  • ppo_epochs (int) – Number of PPO optimization epochs per batch of experience

  • num_rollouts (int) – Number of rollouts to collect per round of experience generation

  • chunk_size (int) – Number of prompts to generate rollouts for at a time

  • init_kl_coef (float) – Initial coefficient for the KL penalty against the reference model

  • target (float) – Target KL value for the adaptive KL controller

  • horizon (int) – Number of steps over which the adaptive KL controller moves toward target

  • gamma (float) – Discount factor

  • lam (float) – GAE lambda

  • cliprange (float) – Clipping range for the PPO policy loss ratio

  • cliprange_value (float) – Clipping range for the value function loss

  • vf_coef (float) – Scale of the value function loss relative to the policy loss

  • gen_kwargs (dict) – Keyword arguments passed to the model's generate method
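When training with PPO, these fields make up the method section of the config YAML. A sketch with illustrative values (the generation kwargs shown are standard Hugging Face generate arguments, chosen here as an example):

```yaml
method:
  name: "ppoconfig"
  ppo_epochs: 4
  num_rollouts: 128
  chunk_size: 128
  init_kl_coef: 0.2
  target: 6
  horizon: 10000
  gamma: 1.0
  lam: 0.95
  cliprange: 0.2
  cliprange_value: 0.2
  vf_coef: 1.0
  gen_kwargs:
    max_length: 48
    top_k: 0
    top_p: 1.0
    do_sample: true
```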

ILQL

class trlx.data.method_configs.ILQLConfig(name: str, tau: float, gamma: float, cql_scale: float, awac_scale: float, alpha: float, steps_for_target_q_sync: int, betas: List[float], two_qs: bool)[source]

Parameters
  • tau (float) – Expectile used in the Q-value loss

  • gamma (float) – Discount factor

  • cql_scale (float) – Scale of the conservative Q-learning (CQL) loss

  • awac_scale (float) – Scale of the advantage-weighted (AWAC) loss

  • alpha (float) – Polyak averaging coefficient for target Q-head updates

  • steps_for_target_q_sync (int) – Number of steps between target Q-head synchronizations

  • betas (List[float]) – Beta temperatures used to scale advantages when sampling

  • two_qs (bool) – Whether to use two Q-heads (clipped double Q-learning)
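When training with ILQL, the method section of the config YAML uses these fields. A sketch with illustrative values:

```yaml
method:
  name: "ilqlconfig"
  tau: 0.7
  gamma: 0.99
  cql_scale: 0.1
  awac_scale: 1.0
  alpha: 0.005
  steps_for_target_q_sync: 5
  betas: [4]
  two_qs: true
```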