Source code for

from dataclasses import dataclass
from typing import Any, Dict, Tuple

import yaml

from import MethodConfig, get_method

@dataclass()
class ModelConfig:
    """
    Settings describing which model and tokenizer to load.

    :param model_path: Location of the model weights, either a local path or a huggingface hub id
    :type model_path: str

    :param tokenizer_path: Location of the tokenizer, either a local path or a huggingface hub id
    :type tokenizer_path: str

    :param model_type: Name of one of the RL models registered in trlx.model
    :type model_type: str

    :param num_layers_unfrozen: Number of layers left trainable; defaults to -1
        (NOTE(review): the meaning of -1 is decided by the model code, not here -- confirm)
    :type num_layers_unfrozen: int
    """

    # Required fields, in positional-constructor order.
    model_path: str
    tokenizer_path: str
    model_type: str  # a registered architecture name from trlx.model

    # Optional field.
    num_layers_unfrozen: int = -1

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        """
        Build a ModelConfig from a mapping of field name to value.

        Every key is forwarded as a keyword argument, so missing required
        fields or unknown keys raise ``TypeError`` from the generated ``__init__``.

        :param config: Mapping of field names to values
        :type config: Dict[str, Any]
        :return: The constructed config
        :rtype: ModelConfig
        """
        field_values = config
        return cls(**field_values)
@dataclass
class TrainConfig:
    """
    Config for train job on model.

    :param total_steps: Total number of training steps
    :type total_steps: int

    :param seq_length: Number of tokens to use as context (max length for tokenizer)
    :type seq_length: int

    :param epochs: Total number of passes through data
    :type epochs: int

    :param batch_size: Batch size for training
    :type batch_size: int

    :param lr_ramp_steps: Number of steps before learning rate reaches learning_rate_init
    :type lr_ramp_steps: int

    :param lr_decay_steps: Number of steps after ramp-up before learning rate decays to learning_rate_target
    :type lr_decay_steps: int

    :param weight_decay: Weight decay for optimizer
    :type weight_decay: float

    :param learning_rate_init: Initial learning rate after ramp up
    :type learning_rate_init: float

    :param learning_rate_target: Target learning rate after decay
    :type learning_rate_target: float

    :param opt_betas: Pair of beta coefficients passed to the optimizer
        (presumably Adam-style ``(beta1, beta2)`` -- confirm against the trainer).
        When loaded via ``from_dict`` from parsed YAML this value is whatever
        the YAML sequence produced (a list), not converted to a tuple.
    :type opt_betas: Tuple[float, float]

    :param checkpoint_interval: Save model every checkpoint_interval steps
    :type checkpoint_interval: int

    :param eval_interval: Evaluate model every eval_interval steps
    :type eval_interval: int

    :param pipeline: Pipeline to use for training. One of the registered pipelines present in trlx.pipeline
    :type pipeline: str

    :param orchestrator: Orchestrator to use for training. One of the registered orchestrators present in trlx.orchestrator
    :type orchestrator: str

    :param checkpoint_dir: Directory where checkpoints are written (default ``"ckpts"``)
    :type checkpoint_dir: str

    :param project_name: Project name for wandb
    :type project_name: str

    :param seed: Random seed (default 1000); presumably used to seed RNGs -- confirm in trainer
    :type seed: int
    """

    # Field order is part of the public positional-constructor interface; keep it.
    total_steps: int
    seq_length: int
    epochs: int
    batch_size: int

    lr_ramp_steps: int
    lr_decay_steps: int
    weight_decay: float
    learning_rate_init: float
    learning_rate_target: float
    # Two-element pair; ``Tuple[float]`` would declare a 1-tuple.
    opt_betas: Tuple[float, float]

    checkpoint_interval: int
    eval_interval: int

    pipeline: str  # One of the registered pipelines in trlx.pipeline
    orchestrator: str  # One of the registered orchestrators in trlx.orchestrator

    checkpoint_dir: str = "ckpts"
    project_name: str = "trlx"
    seed: int = 1000

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        """
        Build a TrainConfig from a mapping of field name to value.

        Keys are forwarded unchanged as keyword arguments, so missing required
        fields or unknown keys raise ``TypeError``.

        :param config: Mapping of field names to values
        :type config: Dict[str, Any]
        :return: The constructed config
        :rtype: TrainConfig
        """
        return cls(**config)
@dataclass
class TRLConfig:
    """
    Top level config for trlX. Loads configs and can be converted to dictionary.

    :param model: Model/tokenizer settings
    :type model: ModelConfig

    :param train: Training-job settings
    :type train: TrainConfig

    :param method: Settings for the RL method, built by the class returned from ``get_method``
    :type method: MethodConfig
    """

    model: ModelConfig
    train: TrainConfig
    method: MethodConfig

    @classmethod
    def load_yaml(cls, yml_fp: str):
        """
        Load yaml file as TRLConfig.

        The file must contain top-level ``model``, ``train`` and ``method``
        mappings. ``method`` must also have a ``name`` key, which selects the
        method config class via ``get_method``. A missing key raises ``KeyError``.

        :param yml_fp: Path to yaml file
        :type yml_fp: str
        :return: The assembled config
        :rtype: TRLConfig
        """
        # Open in binary mode so PyYAML detects the encoding itself (UTF-8 by
        # default, UTF-16 via BOM). Text mode would decode with the platform's
        # locale encoding and can mangle non-ASCII UTF-8 content.
        # safe_load (not load) keeps arbitrary-object construction disabled.
        with open(yml_fp, mode="rb") as file:
            config = yaml.safe_load(file)

        return cls(
            ModelConfig.from_dict(config["model"]),
            TrainConfig.from_dict(config["train"]),
            get_method(config["method"]["name"]).from_dict(config["method"]),
        )

    def to_dict(self):
        """
        Convert TRLConfig to dictionary.

        Returns a new flat dict merging the instance attributes of ``model``,
        ``train`` and ``method``, in that order. When a key appears in more
        than one section, the later section wins (method > train > model).
        Values are not copied, so mutable values are shared with the configs.

        :return: Flat mapping of setting name to value
        :rtype: dict
        """
        return {**self.model.__dict__, **self.train.__dict__, **self.method.__dict__}