Source code for trlx.data.configs

from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set

import yaml

from trlx.data.method_configs import MethodConfig, get_method


def merge(base: Dict, update: Dict, updated: Set) -> Dict:
    "Recursively updates a nested dictionary with new values"
    for k, v in base.items():
        if k in update and isinstance(v, dict):
            base[k] = merge(v, update[k], updated)
            updated.add(k)
        elif k in update:
            base[k] = update[k]
            updated.add(k)

    return base
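
# Illustrative sketch (not part of the original module): `merge` only touches keys that
# already exist in `base`; keys found solely in `update` are ignored, and `updated`
# collects every key that was overwritten, at any nesting level. For example:
#
# >>> touched = set()
# >>> merge({"train": {"seed": 0}, "method": {"gamma": 1.0}}, {"train": {"seed": 42}}, touched)
# {'train': {'seed': 42}, 'method': {'gamma': 1.0}}
# >>> sorted(touched)
# ['seed', 'train']
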


def _merge_dicts(base: Dict, update: Dict) -> Dict:
    "Merge two dictionaries recursively, returning a new dictionary."

    base = deepcopy(base)

    for k, v in update.items():
        if isinstance(v, dict):
            base[k] = _merge_dicts(base.get(k, {}), v)
        else:
            base[k] = v

    return base
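
# Illustrative sketch (not part of the original module): unlike `merge`, `_merge_dicts`
# leaves its arguments untouched and also accepts keys that are new to `base`, which is
# what lets `TRLConfig.evolve` introduce nested overrides. For example:
#
# >>> _merge_dicts({"train": {"seed": 0}}, {"train": {"seed": 42}, "tags": ["demo"]})
# {'train': {'seed': 42}, 'tags': ['demo']}
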



@dataclass
class ModelConfig:
    """
    Config for a model.

    :param model_path: Path or name of the model (local or on huggingface hub)
    :type model_path: str

    :param model_arch_type: Type of model architecture. Either "causal" or "seq2seq"
    :type model_arch_type: str

    :param num_layers_unfrozen: Number of layers to unfreeze for fine-tuning.
        -1 means all layers are unfrozen.
    :type num_layers_unfrozen: int

    :param peft_config: configuration for peft (Parameter Efficient Fine-Tuning library).
        Peft is designed to reduce the number of parameters to train and the memory footprint,
        without significant performance loss. It supports multiple techniques such as LORA
        or prefix tuning (cf. https://github.com/huggingface/peft).

        Here is an example of LORA configuration:
            {"peft_type": "LORA", "r": 8, "lora_alpha": 32, "lora_dropout": 0.1}

        (parameter-efficient fine-tuning was previously done in trlx with OpenDelta, but it is no longer supported)
    :type peft_config: Union[peft.PeftConfig, Dict[str, Any]]
    """

    model_path: str
    model_arch_type: str = "causal"
    num_layers_unfrozen: int = -1
    peft_config: Any = None
    model_extra_configs: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        return cls(**config)
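
# Illustrative sketch (not part of the original module): a ModelConfig for a causal LM
# with a LoRA peft configuration passed as a plain dict. The model name and LoRA
# hyperparameters below are arbitrary placeholders taken from the docstring above.
#
# >>> model_config = ModelConfig(
# ...     model_path="gpt2",
# ...     model_arch_type="causal",
# ...     num_layers_unfrozen=-1,
# ...     peft_config={"peft_type": "LORA", "r": 8, "lora_alpha": 32, "lora_dropout": 0.1},
# ... )
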

@dataclass
class TokenizerConfig:
    """
    Config for a tokenizer.

    :param tokenizer_path: Path or name of the tokenizer (local or on huggingface hub)
    :type tokenizer_path: str

    :param padding_side: Padding side
    :type padding_side: str

    :param truncation_side: Truncation side
    :type truncation_side: str
    """

    tokenizer_path: str
    padding_side: str = "left"
    truncation_side: str = "right"
    tokenizer_extra_configs: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        return cls(**config)

@dataclass
class OptimizerConfig:
    """
    Config for an optimizer.

    :param name: Name of the optimizer
    :type name: str

    :param kwargs: Keyword arguments for the optimizer (e.g. lr, betas, eps, weight_decay)
    :type kwargs: Dict[str, Any]
    """

    name: str
    kwargs: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        return cls(**config)

@dataclass
class SchedulerConfig:
    """
    Config for a learning rate scheduler.

    :param name: Name of the scheduler
    :type name: str

    :param kwargs: Keyword arguments for the scheduler instance (e.g. warmup_steps, T_max)
    :type kwargs: Dict[str, Any]
    """

    name: str
    kwargs: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        return cls(**config)
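
# Illustrative sketch (not part of the original module): optimizer and scheduler configs
# are a registered name plus keyword arguments forwarded to the underlying class. The
# names and kwargs below ("adamw", "cosine_annealing") follow trlx's default configs,
# but treat them as examples rather than an exhaustive list.
#
# >>> optimizer = OptimizerConfig(name="adamw", kwargs=dict(lr=1e-5, betas=(0.9, 0.95), eps=1e-8, weight_decay=1e-6))
# >>> scheduler = SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1e-6))
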

@dataclass
class TrainConfig:
    """
    Config for train job on model.

    :param total_steps: Total number of training steps
    :type total_steps: int

    :param seq_length: Number of tokens to use as context (max length for tokenizer)
    :type seq_length: int

    :param epochs: Total number of passes through data
    :type epochs: int

    :param batch_size: Batch size for training
    :type batch_size: int

    :param tracker: Tracker to use for logging. Default: "wandb"
    :type tracker: str

    :param checkpoint_interval: Save model every checkpoint_interval steps.
        Each checkpoint is stored in a sub-directory of the `TrainConfig.checkpoint_dir`
        directory in the format `checkpoint_dir/checkpoint_{step}`.
    :type checkpoint_interval: int

    :param eval_interval: Evaluate model every eval_interval steps
    :type eval_interval: int

    :param pipeline: Pipeline to use for training. One of the registered pipelines present in trlx.pipeline
    :type pipeline: str

    :param trainer: Trainer to use for training. One of the registered trainers present in trlx.trainer
    :type trainer: str

    :param trainer_kwargs: Extra keyword arguments for the trainer
    :type trainer_kwargs: Dict[str, Any]

    :param project_name: Project name for wandb
    :type project_name: str

    :param entity_name: Entity name for wandb
    :type entity_name: str

    :param group_name: Group name for wandb (used for grouping runs)
    :type group_name: str

    :param checkpoint_dir: Directory to save checkpoints
    :type checkpoint_dir: str

    :param rollout_logging_dir: Directory to store generated rollouts for use in Algorithm Distillation.
        Only used by AcceleratePPOTrainer.
    :type rollout_logging_dir: Optional[str]

    :param save_best: Save best model based on mean reward
    :type save_best: bool

    :param seed: Random seed
    :type seed: int

    :param minibatch_size: Size of model input during one forward pass. Must divide batch size
    :type minibatch_size: int
    """

    total_steps: int
    seq_length: int
    epochs: int
    batch_size: int

    checkpoint_interval: int
    eval_interval: int

    pipeline: str  # One of the pipelines in framework.pipeline
    trainer: str  # One of the trainers
    trainer_kwargs: Dict[str, Any] = field(default_factory=dict)  # Extra keyword arguments for the trainer

    project_name: str = "trlx"
    run_name: Optional[str] = None
    entity_name: Optional[str] = None
    group_name: Optional[str] = None

    checkpoint_dir: str = "ckpts"
    rollout_logging_dir: Optional[str] = None
    save_best: bool = True
    save_optimizer: bool = True
    resume_from_checkpoint: Optional[str] = None

    tracker: Optional[str] = "wandb"
    logging_dir: Optional[str] = None
    tags: Optional[List[str]] = field(default_factory=list)

    seed: int = 1000

    minibatch_size: Optional[int] = None

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        return cls(**config)
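
# Illustrative sketch (not part of the original module): a minimal TrainConfig. With the
# default checkpoint_dir="ckpts" and checkpoint_interval=1000, checkpoints land in
# "ckpts/checkpoint_1000", "ckpts/checkpoint_2000", and so on; minibatch_size must divide
# batch_size (here 32 is split into forward passes of 8). The pipeline and trainer names
# below are examples of registered trlx components, not the only options.
#
# >>> train_config = TrainConfig(
# ...     total_steps=10000,
# ...     seq_length=512,
# ...     epochs=100,
# ...     batch_size=32,
# ...     checkpoint_interval=1000,
# ...     eval_interval=100,
# ...     pipeline="PromptPipeline",
# ...     trainer="AcceleratePPOTrainer",
# ...     minibatch_size=8,
# ... )
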

@dataclass
class TRLConfig:
    """
    Top level config for trlX. Loads configs and can be converted to dictionary.
    """

    method: MethodConfig
    model: ModelConfig
    optimizer: OptimizerConfig
    scheduler: SchedulerConfig
    tokenizer: TokenizerConfig
    train: TrainConfig

    @classmethod
    def load_yaml(cls, yml_fp: str):
        """
        Load yaml file as TRLConfig.

        :param yml_fp: Path to yaml file
        :type yml_fp: str
        """
        with open(yml_fp, mode="r") as file:
            config = yaml.safe_load(file)
        return cls.from_dict(config)
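
    # Illustrative sketch (not part of the original module): `load_yaml` expects one
    # section per TRLConfig field, matching what `from_dict` reads below. A skeletal
    # file might look like the following, where every value and the file path are
    # placeholders rather than a complete, working config:
    #
    #     train:
    #       total_steps: 10000
    #       seq_length: 512
    #       epochs: 100
    #       batch_size: 32
    #       checkpoint_interval: 1000
    #       eval_interval: 100
    #       pipeline: "PromptPipeline"
    #       trainer: "AcceleratePPOTrainer"
    #     model:
    #       model_path: "gpt2"
    #     tokenizer:
    #       tokenizer_path: "gpt2"
    #     optimizer:
    #       name: "adamw"
    #     scheduler:
    #       name: "cosine_annealing"
    #     method:
    #       name: "ppoconfig"
    #
    # >>> config = TRLConfig.load_yaml("configs/ppo_config.yml")
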

    def to_dict(self):
        """
        Convert TRLConfig to dictionary.
        """
        data = {
            "method": self.method.__dict__,
            "model": self.model.__dict__,
            "optimizer": self.optimizer.__dict__,
            "scheduler": self.scheduler.__dict__,
            "tokenizer": self.tokenizer.__dict__,
            "train": self.train.__dict__,
        }

        return data

    def evolve(self, **kwargs) -> "TRLConfig":
        """
        Evolve TRLConfig with new parameters. Can update nested parameters.

        >>> config = trlx.data.default_configs.default_ilql_config()
        >>> config = config.evolve(method=dict(gamma=0.99, gen_kwargs=dict(max_new_tokens=100)))
        >>> config.method.gamma
        0.99
        """
        return TRLConfig.from_dict(_merge_dicts(self.to_dict(), kwargs))

    @classmethod
    def from_dict(cls, config: Dict):
        """
        Convert dictionary to TRLConfig.
        """
        return cls(
            method=get_method(config["method"]["name"]).from_dict(config["method"]),
            model=ModelConfig.from_dict(config["model"]),
            tokenizer=TokenizerConfig.from_dict(config["tokenizer"]),
            optimizer=OptimizerConfig.from_dict(config["optimizer"]),
            scheduler=SchedulerConfig.from_dict(config["scheduler"]),
            train=TrainConfig.from_dict(config["train"]),
        )

    @classmethod
    def update(cls, baseconfig: Dict, config: Dict):
        update = {}
        # unflatten a string variable name into a nested dictionary
        # key1.key2.key3: value -> {key1: {key2: {key3: value}}}
        for name, value in config.items():
            if isinstance(value, dict):
                update[name] = value
            else:
                *layers, var = name.split(".")
                if layers:
                    d = update.setdefault(layers[0], {})
                    for layer in layers[1:]:
                        d = d.setdefault(layer, {})
                    d[var] = value

        if not isinstance(baseconfig, Dict):
            baseconfig = baseconfig.to_dict()

        updates = set()
        merged = merge(baseconfig, update, updates)

        for param in update:
            if param not in updates:
                raise ValueError(f"parameter {param} is not present in the config (typo or a wrong config)")

        return cls.from_dict(merged)

    def __str__(self):
        """Returns a human-readable string representation of the config."""
        import json

        return json.dumps(self.to_dict(), indent=4)
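
# Illustrative sketch (not part of the original module): `TRLConfig.update` accepts
# overrides with dot-separated keys and unflattens them before merging, raising a
# ValueError if a top-level key is absent from the base config. It assumes
# trlx.data.default_configs.default_ppo_config is available, analogous to the
# default_ilql_config referenced in the evolve docstring above.
#
# >>> base = default_ppo_config().to_dict()
# >>> config = TRLConfig.update(base, {"train.seed": 42, "optimizer.kwargs.lr": 1e-6})
# >>> config.train.seed
# 42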