from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set
import yaml
from trlx.data.method_configs import MethodConfig, get_method
def merge(base: Dict, update: Dict, updated: Set) -> Dict:
    """
    Recursively overwrite values of ``base`` with matching values from ``update``.

    Only keys already present in ``base`` are considered; keys that exist
    solely in ``update`` are ignored. ``base`` is modified in place.

    :param base: Nested dictionary that receives the new values
    :type base: Dict

    :param update: Nested dictionary holding the replacement values
    :type update: Dict

    :param updated: Set collecting the name of every key (at any nesting
        level) whose value was taken from ``update``
    :type updated: Set

    :return: The same ``base`` object, after modification
    :rtype: Dict
    """
    for key in base:
        if key not in update:
            continue

        current = base[key]
        replacement = update[key]
        if isinstance(current, dict):
            # Descend instead of replacing the whole sub-dictionary
            replacement = merge(current, replacement, updated)

        base[key] = replacement
        updated.add(key)

    return base
def _merge_dicts(base: Dict, update: Dict) -> Dict:
"Merge two dictionaries recursively, returning a new dictionary."
base = deepcopy(base)
for k, v in update.items():
if isinstance(v, dict):
base[k] = _merge_dicts(base.get(k, {}), v)
else:
base[k] = v
return base
@dataclass
class ModelConfig:
    """
    Config for a model.

    :param model_path: Path or name of the model (local or on huggingface hub)
    :type model_path: str

    :param model_arch_type: Type of model architecture. Either "causal" or "seq2seq"
    :type model_arch_type: str

    :param num_layers_unfrozen: Number of layers to unfreeze for fine-tuning.
        -1 means all layers are unfrozen.
    :type num_layers_unfrozen: int

    :param peft_config: configuration for peft (Parameter Efficient Fine-Tuning library).
        Peft is designed to reduce the number of parameters to train and the memory footprint,
        without significant performance loss. It supports multiple techniques such as LORA
        or prefix tuning (cf. https://github.com/huggingface/peft).

        Here is an example of LORA configuration:
            {"peft_type": "LORA", "r": 8, "lora_alpha": 32, "lora_dropout": 0.1}

        (parameter-efficient fine-tuning was previously done in trlx with OpenDelta, but it is no longer supported)
    :type peft_config: Union[peft.PeftConfig, Dict[str, Any]]

    :param model_extra_configs: Extra key/value settings for the model
        (presumably forwarded when the model is loaded; confirm against the trainer)
    :type model_extra_configs: Dict[str, Any]
    """

    model_path: str
    model_arch_type: str = "causal"
    num_layers_unfrozen: int = -1
    peft_config: Any = None
    model_extra_configs: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        """
        Build a ModelConfig from a dictionary of field names to values.

        :param config: Keyword arguments for the dataclass constructor;
            unknown keys raise ``TypeError``
        :type config: Dict[str, Any]
        """
        return cls(**config)
@dataclass
class TokenizerConfig:
    """
    Config for a tokenizer.

    :param tokenizer_path: Path or name of the tokenizer (local or on huggingface hub)
    :type tokenizer_path: str

    :param padding_side: Padding side
    :type padding_side: str

    :param truncation_side: Truncation side
    :type truncation_side: str

    :param tokenizer_extra_configs: Extra key/value settings for the tokenizer
        (presumably forwarded when the tokenizer is loaded; confirm against the trainer)
    :type tokenizer_extra_configs: Dict[str, Any]
    """

    tokenizer_path: str
    padding_side: str = "left"
    truncation_side: str = "right"
    tokenizer_extra_configs: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        """
        Build a TokenizerConfig from a dictionary of field names to values.

        :param config: Keyword arguments for the dataclass constructor;
            unknown keys raise ``TypeError``
        :type config: Dict[str, Any]
        """
        return cls(**config)
@dataclass
class OptimizerConfig:
    """
    Config for an optimizer.

    :param name: Name of the optimizer
    :type name: str

    :param kwargs: Keyword arguments for the optimizer (e.g. lr, betas, eps, weight_decay)
    :type kwargs: Dict[str, Any]
    """

    name: str
    kwargs: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        """
        Build an OptimizerConfig from a dictionary of field names to values.

        :param config: Keyword arguments for the dataclass constructor;
            unknown keys raise ``TypeError``
        :type config: Dict[str, Any]
        """
        return cls(**config)
@dataclass
class SchedulerConfig:
    """
    Config for a learning rate scheduler.

    :param name: Name of the scheduler
    :type name: str

    :param kwargs: Keyword arguments for the scheduler instance (e.g. warmup_steps, T_max)
    :type kwargs: Dict[str, Any]
    """

    name: str
    kwargs: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        """
        Build a SchedulerConfig from a dictionary of field names to values.

        :param config: Keyword arguments for the dataclass constructor;
            unknown keys raise ``TypeError``
        :type config: Dict[str, Any]
        """
        return cls(**config)
@dataclass
class TrainConfig:
    """
    Config for train job on model.

    :param total_steps: Total number of training steps
    :type total_steps: int

    :param seq_length: Number of tokens to use as context (max length for tokenizer)
    :type seq_length: int

    :param epochs: Total number of passes through data
    :type epochs: int

    :param batch_size: Batch size for training
    :type batch_size: int

    :param checkpoint_interval: Save model every checkpoint_interval steps.
        Each checkpoint is stored in a sub-directory of the `TrainConfig.checkpoint_dir`
        directory in the format `checkpoint_dir/checkpoint_{step}`.
    :type checkpoint_interval: int

    :param eval_interval: Evaluate model every eval_interval steps
    :type eval_interval: int

    :param pipeline: Pipeline to use for training. One of the registered pipelines present in trlx.pipeline
    :type pipeline: str

    :param trainer: Trainer to use for training. One of the registered trainers present in trlx.trainer
    :type trainer: str

    :param trainer_kwargs: Extra keyword arguments for the trainer
    :type trainer_kwargs: Dict[str, Any]

    :param project_name: Project name for wandb
    :type project_name: str

    :param run_name: Name of the run (presumably shown by the tracker; confirm against the trainer)
    :type run_name: Optional[str]

    :param entity_name: Entity name for wandb
    :type entity_name: Optional[str]

    :param group_name: Group name for wandb (used for grouping runs)
    :type group_name: Optional[str]

    :param checkpoint_dir: Directory to save checkpoints
    :type checkpoint_dir: str

    :param rollout_logging_dir: Directory to store generated rollouts for use in Algorithm Distillation.
        Only used by AcceleratePPOTrainer.
    :type rollout_logging_dir: Optional[str]

    :param save_best: Save best model based on mean reward
    :type save_best: bool

    :param save_optimizer: Presumably whether optimizer state is saved with checkpoints
        (confirm against the trainer)
    :type save_optimizer: bool

    :param resume_from_checkpoint: Presumably a checkpoint path to resume training from
        (confirm against the trainer)
    :type resume_from_checkpoint: Optional[str]

    :param tracker: Tracker to use for logging. Default: "wandb"
    :type tracker: Optional[str]

    :param logging_dir: Presumably the directory used for tracker logs (confirm against the trainer)
    :type logging_dir: Optional[str]

    :param tags: Presumably tags attached to the tracked run (confirm against the trainer)
    :type tags: Optional[List[str]]

    :param seed: Random seed
    :type seed: int

    :param minibatch_size: Size of model input during one forward pass. Must divide batch size
    :type minibatch_size: Optional[int]
    """

    total_steps: int
    seq_length: int
    epochs: int
    batch_size: int

    checkpoint_interval: int
    eval_interval: int

    pipeline: str  # One of the pipelines in trlx.pipeline
    trainer: str  # One of the trainers in trlx.trainer
    trainer_kwargs: Dict[str, Any] = field(default_factory=dict)  # Extra keyword arguments for the trainer

    project_name: str = "trlx"
    run_name: Optional[str] = None
    entity_name: Optional[str] = None
    group_name: Optional[str] = None

    checkpoint_dir: str = "ckpts"
    rollout_logging_dir: Optional[str] = None
    save_best: bool = True
    save_optimizer: bool = True
    resume_from_checkpoint: Optional[str] = None

    tracker: Optional[str] = "wandb"
    logging_dir: Optional[str] = None
    tags: Optional[List[str]] = field(default_factory=list)

    seed: int = 1000

    minibatch_size: Optional[int] = None

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        """
        Build a TrainConfig from a dictionary of field names to values.

        :param config: Keyword arguments for the dataclass constructor;
            unknown keys or missing required fields raise ``TypeError``
        :type config: Dict[str, Any]
        """
        return cls(**config)
@dataclass
class TRLConfig:
    """
    Top level config for trlX. Loads configs and can be converted to dictionary.

    Bundles the method, model, optimizer, scheduler, tokenizer and train
    sub-configs into a single object.
    """

    method: MethodConfig
    model: ModelConfig
    optimizer: OptimizerConfig
    scheduler: SchedulerConfig
    tokenizer: TokenizerConfig
    train: TrainConfig

    @classmethod
    def load_yaml(cls, yml_fp: str):
        """
        Load yaml file as TRLConfig.

        :param yml_fp: Path to yaml file
        :type yml_fp: str
        """
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale encoding.
        with open(yml_fp, mode="r", encoding="utf-8") as file:
            config = yaml.safe_load(file)
        return cls.from_dict(config)

    def to_dict(self):
        """
        Convert TRLConfig to dictionary.

        The values are the sub-configs' own ``__dict__`` objects (not copies),
        so writing into the returned mapping also changes this config.
        """
        data = {
            "method": self.method.__dict__,
            "model": self.model.__dict__,
            "optimizer": self.optimizer.__dict__,
            "scheduler": self.scheduler.__dict__,
            "tokenizer": self.tokenizer.__dict__,
            "train": self.train.__dict__,
        }

        return data

    def evolve(self, **kwargs) -> "TRLConfig":
        """
        Evolve TRLConfig with new parameters. Can update nested parameters.
        The current config is left unchanged; a new TRLConfig is returned.

        >>> config = trlx.data.default_configs.default_ilql_config()
        >>> config = config.evolve(method=dict(gamma=0.99, gen_kwargs=dict(max_new_tokens=100)))
        >>> config.method.gamma
        0.99
        """
        return TRLConfig.from_dict(_merge_dicts(self.to_dict(), kwargs))

    @classmethod
    def from_dict(cls, config: Dict):
        """
        Convert dictionary to TRLConfig.

        :param config: Dictionary with the keys "method", "model", "tokenizer",
            "optimizer", "scheduler" and "train"; ``config["method"]["name"]``
            selects the method config class via ``get_method``
        :type config: Dict
        """
        return cls(
            method=get_method(config["method"]["name"]).from_dict(config["method"]),
            model=ModelConfig.from_dict(config["model"]),
            tokenizer=TokenizerConfig.from_dict(config["tokenizer"]),
            optimizer=OptimizerConfig.from_dict(config["optimizer"]),
            scheduler=SchedulerConfig.from_dict(config["scheduler"]),
            train=TrainConfig.from_dict(config["train"]),
        )

    @classmethod
    def update(cls, baseconfig: Dict, config: Dict):
        """
        Build a new TRLConfig from ``baseconfig`` with values overridden by ``config``.

        Keys of ``config`` may be dotted paths ("train.seed") or section names
        mapped to dictionaries ({"train": {"seed": 1}}). Non-dict values under
        a key without a dot are skipped. Only keys that already exist in
        ``baseconfig`` are overridden.

        :param baseconfig: Base configuration, as a dictionary or as an object
            providing ``to_dict()`` (e.g. a TRLConfig). It is not modified.
        :type baseconfig: Dict

        :param config: Overrides to apply
        :type config: Dict

        :raises ValueError: If a top-level section named in ``config`` was not
            matched in ``baseconfig``
        """
        update = {}
        # unflatten a string variable name into a nested dictionary
        # key1.key2.key3: value -> {key1: {key2: {key3: value}}}
        for name, value in config.items():
            if isinstance(value, dict):
                update[name] = value
            else:
                *layers, var = name.split(".")
                if layers:
                    d = update.setdefault(layers[0], {})
                    for layer in layers[1:]:
                        d = d.setdefault(layer, {})
                    d[var] = value

        if not isinstance(baseconfig, dict):
            baseconfig = baseconfig.to_dict()

        # `merge` writes into its first argument, and `to_dict()` hands out the
        # sub-configs' live `__dict__`s; work on a copy so neither the caller's
        # dictionary nor the source config object is altered.
        baseconfig = deepcopy(baseconfig)

        updates = set()
        merged = merge(baseconfig, update, updates)

        # Only the top-level names of `update` are validated here.
        for param in update:
            if param not in updates:
                raise ValueError(f"parameter {param} is not present in the config (typo or a wrong config)")

        return cls.from_dict(merged)

    def __str__(self):
        """Returns a human-readable string representation of the config."""
        import json

        return json.dumps(self.to_dict(), indent=4)