Source code for trlx.model.accelerate_base_model

import importlib
import os
from abc import abstractmethod
from time import time
from typing import Any, Dict, Iterable, Tuple

import torch
import torch.nn.functional as F
from accelerate import Accelerator
from transformers import AutoTokenizer

import wandb
from trlx.data.configs import TRLConfig
from trlx.model import BaseRLModel, register_model

if importlib.util.find_spec("rich") is not None:
    from tqdm.rich import tqdm
else:
    from tqdm import tqdm


@register_model
class AccelerateRLModel(BaseRLModel):
    """
    RL Model that uses accelerate for training
    """

    def __init__(self, config, train_mode=True):
        super().__init__(config, train_mode)

        self.accelerator = Accelerator(log_with="wandb")

        if int(os.environ.get("WORLD_SIZE", 1)) > 1:
            torch.distributed.barrier(device_ids=[int(os.environ.get("LOCAL_RANK", 0))])
        else:
            torch.random.manual_seed(config.train.seed)

        # Retrieves model equipped for ppo, ilql, etc
        self.model = self.get_arch(self.config)
        self.max_length = config.train.seq_length

        if config.model.tokenizer_path:
            self.tokenizer = AutoTokenizer.from_pretrained(config.model.tokenizer_path)
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.padding_side = "left"
        else:
            self.tokenizer = None

        if hasattr(self.model.gpt, "gpt_neox"):
            gpt_blocks = self.model.gpt.gpt_neox.layers
        else:
            gpt_blocks = self.model.gpt.transformer.h

        # freeze transformer's bottom layers if num_layers_unfrozen >= 0
        num_layers_unfrozen = self.config.model.num_layers_unfrozen
        if num_layers_unfrozen == 0:
            gpt_blocks_to_freeze = list(gpt_blocks)
        elif num_layers_unfrozen > 0:
            gpt_blocks_to_freeze = list(gpt_blocks)[:-num_layers_unfrozen]
        else:
            gpt_blocks_to_freeze = []
        for m in gpt_blocks_to_freeze:
            m.requires_grad_(False)

        if self.accelerator.is_main_process:
            self.accelerator.init_trackers(
                project_name=self.config.train.project_name,
                config=self.config.to_dict(),
                init_kwargs={
                    "wandb": {
                        "name": f"{config.model.model_path}",
                        "mode": "disabled" if os.environ.get("debug", False) else "online",
                    }
                },
            )

        self.opt = torch.optim.AdamW(
            self.model.parameters(),
            lr=float(self.config.train.learning_rate_init),
            betas=self.config.train.opt_betas,
        )

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.opt,
            self.config.train.total_steps,
            eta_min=float(self.config.train.learning_rate_target),
        )
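
    # Example of the freezing semantics above, with a plain list standing in
    # for the transformer blocks (illustrative only):
    #
    #     blocks = list(range(12))   # a 12-block model
    #     blocks[:-2]                # num_layers_unfrozen == 2 -> blocks 0..9 frozen
    #     list(blocks)               # num_layers_unfrozen == 0 -> every block frozen
    #     []                         # num_layers_unfrozen < 0  -> nothing frozen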

    def tokenize(self, text: Iterable[str]):
        """
        Tokenize a batch of text after adding the bos token to each of the samples
        """
        text = [self.tokenizer.bos_token + txt for txt in text]
        return self.tokenizer(
            text,
            truncation=True,
            max_length=self.config.train.seq_length,
            return_tensors="pt",
        )
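
    # Usage sketch (illustrative; assumes `model` is an instance of a concrete
    # subclass with a loaded tokenizer):
    #
    #     batch = model.tokenize(["a cat sat on a mat"])
    #     batch["input_ids"]   # bos-prefixed token ids, truncated to seq_length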

    def generate(self, input_ids, attention_mask=None, **kwargs):
        """Wraps Hugging Face's `generate`, adding method-specific defaults"""
        input_ids = input_ids.to(self.accelerator.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.accelerator.device)

        kwargs = dict(self.generate_kwargs, **kwargs)

        with torch.no_grad():
            return self.accelerator.unwrap_model(self.model).generate(
                input_ids=input_ids, attention_mask=attention_mask, **kwargs
            )
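
    # Usage sketch (illustrative): per-call keyword arguments override the
    # defaults collected in `self.generate_kwargs`:
    #
    #     samples = model.generate(batch["input_ids"], max_new_tokens=32)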

    def get_components(self) -> Dict[str, Any]:
        components = (
            {"model": self.model, "opt": self.opt, "scheduler": self.scheduler}
            if self.train_mode
            else {"model": self.model}
        )
        return components

    def save(self, directory=None):
        """Creates a checkpoint of the optimizer, scheduler and model"""
        self.accelerator.save_state(directory or self.config.train.checkpoint_dir)
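
    # Usage sketch (illustrative; the path is a placeholder):
    # `Accelerator.save_state` writes model, optimizer and scheduler state to
    # the directory, and the matching `Accelerator.load_state` restores it:
    #
    #     model.save("ckpts/step_1000")
    #     model.accelerator.load_state("ckpts/step_1000")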

    def add_eval_pipeline(self, eval_pipeline):
        """Adds a pipeline with validation prompts"""
        self.eval_pipeline = eval_pipeline

    def evaluate(self):
        """Samples model on `eval_prompts`, logs stats with `reward_fn` or `metric_fn` if provided"""
        stats = {}
        all_samples = []
        generate_time = time()
        for prompts in self.eval_dataloader:
            if isinstance(prompts, torch.Tensor):
                samples = self.generate(prompts)
            else:
                samples = self.generate(**prompts)

            if isinstance(samples, tuple):
                samples, *_ = samples

            pad_token = self.tokenizer.eos_token_id if self.tokenizer else 0
            all_samples.append(
                F.pad(
                    samples,
                    (0, self.max_length - samples.shape[1]),
                    value=pad_token,
                )
            )
        stats["generate_time"] = time() - generate_time

        samples = self.accelerator.gather(torch.vstack(all_samples))

        if self.accelerator.is_main_process:
            if self.tokenizer:
                samples = self.tokenizer.batch_decode(samples, skip_special_tokens=True)

            if isinstance(samples[0], str):
                columns_data = [samples]
            else:
                columns_data = [samples.tolist()]
            columns = ["samples"]

            # in online setting, compute the reward for validation
            if self.reward_fn:
                rewards = torch.as_tensor(self.reward_fn(samples), dtype=torch.float)
                mean_reward = rewards.mean()
                columns.append("reward")
                columns_data.append(rewards)
                stats["mean_reward"] = mean_reward
                print(f"{mean_reward=}")

            # additionally log any other metrics
            if self.metric_fn:
                metric_time = time()
                metrics = self.metric_fn(samples)
                stats["metric_time"] = time() - metric_time

                mean_metrics = {
                    f"metrics/{k}": torch.as_tensor(xs).mean(-1)
                    for k, xs in metrics.items()
                }
                stats.update(mean_metrics)

                for metric, values in metrics.items():
                    columns.append(metric)
                    columns_data.append(values)

            rows = list(zip(*columns_data))
            stats["samples"] = wandb.Table(columns=columns, rows=rows)
            print(rows[0])

        return stats
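
    # Illustrative sketch of the contract `evaluate` assumes for `reward_fn`:
    # one scalar per decoded sample (a hypothetical toy reward):
    #
    #     def reward_fn(samples):
    #         # toy reward: prefer longer generations
    #         return [float(len(s)) for s in samples]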

    def learn(self):
        """
        Samples batches from `self.store`, updates model and periodically evaluates it on `self.eval_dataloader`
        """
        self.prepare_learning()

        tbar = tqdm(
            total=self.total_steps,
            disable=not self.accelerator.is_local_main_process,
        )

        self.iter_count = 0

        for _ in range(self.config.train.epochs):
            for batch in self.train_dataloader:
                for _ in range(self.n_updates_per_batch):
                    forward_time = time()
                    loss, stats = self.loss(batch)
                    forward_time = time() - forward_time

                    backward_time = time()
                    self.accelerator.backward(loss)
                    backward_time = time() - backward_time

                    self.opt.step()
                    self.opt.zero_grad()
                    self.scheduler.step()
                    self.iter_count += 1

                    if self.iter_count % self.config.train.checkpoint_interval == 0:
                        self.save()

                    if self.iter_count % self.config.train.eval_interval == 0:
                        results = self.evaluate()
                        results.update(stats)
                        results.update(
                            {
                                "forward_time": forward_time,
                                "backward_time": backward_time,
                            }
                        )
                        self.accelerator.log(results)

                    desc = ", ".join(f"{k}: {v:.2f}" for k, v in stats.items())
                    tbar.set_description(desc)
                    tbar.update()

                    if self.iter_count >= self.total_steps:
                        self.save()
                        return self.evaluate()

                self.post_backward_callback()

            self.post_epoch_callback()
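
    # Illustrative note: each batch yields `n_updates_per_batch` optimizer
    # steps, so one epoch performs
    #
    #     len(train_dataloader) * n_updates_per_batch
    #
    # updates (e.g. 250 batches * 4 updates/batch = 1000 updates), and the
    # checkpoint/eval intervals count these inner updates, not batches.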

    @abstractmethod
    def get_arch(self, config: TRLConfig):
        """Returns a specific wrapper of the decoder architecture"""
        pass

    @abstractmethod
    def loss(self, batch) -> Tuple[float, Dict]:
        """Computes loss on a batch from `store` and returns some statistics"""
        pass

    @abstractmethod
    def post_backward_callback(self):
        """Do something after each model update"""
        pass

    @abstractmethod
    def post_epoch_callback(self):
        """Do something after a full pass over `self.store`"""
        pass
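
# Minimal subclass sketch (hypothetical; `MyRLModel`, `MyArch` and the loss
# body are placeholders, not part of trlx):
#
#     @register_model
#     class MyRLModel(AccelerateRLModel):
#         def get_arch(self, config: TRLConfig):
#             return MyArch.from_pretrained(config.model.model_path)
#
#         def loss(self, batch) -> Tuple[float, Dict]:
#             loss = ...  # method-specific loss computed on `batch`
#             return loss, {"loss": float(loss)}
#
#         def post_backward_callback(self):
#             pass
#
#         def post_epoch_callback(self):
#             pass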