Source code for trlx.model.accelerate_ppo_model

import numpy as np
import torch

from trlx.data.configs import TRLConfig
from trlx.model import register_model
from trlx.model.accelerate_base_model import AccelerateRLModel
from trlx.model.nn.ppo_models import GPTHydraHeadWithValueModel
from trlx.pipeline.ppo_pipeline import PPORolloutStorage
from trlx.utils.modeling import clip_by_value, logprobs_from_logits, whiten


class AdaptiveKLController:
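    """Adaptive KL controller described in Ziegler et al.,
    "Fine-Tuning Language Models from Human Preferences"
    (https://arxiv.org/abs/1909.08593).
    """
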
    def __init__(self, init_kl_coef, target, horizon):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current, n_steps):
        target = self.target
        proportional_error = np.clip(current / target - 1, -0.2, 0.2)
        mult = 1 + proportional_error * n_steps / self.horizon
        self.value *= mult
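
# Worked example (hypothetical numbers) of the update above: with
# init_kl_coef=0.2, target=6 and horizon=10000, an observed KL of 12 over
# n_steps=256 gives proportional_error = clip(12/6 - 1, -0.2, 0.2) = 0.2 and
# mult = 1 + 0.2 * 256 / 10000 = 1.00512, so the coefficient grows from
# 0.2 to ~0.20102; a KL below target shrinks it symmetrically.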


class FixedKLController:
    """Fixed KL controller."""

    def __init__(self, kl_coef):
        self.value = kl_coef

    def update(self, current, n_steps):
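        # No-op: the coefficient stays constant; the signature mirrors
        # AdaptiveKLController.update so the two controllers are interchangeable.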
        pass
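
# A minimal sketch (not part of the original module) contrasting the two
# controllers on the hypothetical numbers from the comment above: one update
# with the same measured KL moves the adaptive coefficient but not the fixed one.
def _demo_kl_controllers():
    adaptive = AdaptiveKLController(init_kl_coef=0.2, target=6, horizon=10000)
    fixed = FixedKLController(kl_coef=0.2)
    for ctl in (adaptive, fixed):
        ctl.update(current=12, n_steps=256)
    return adaptive.value, fixed.value  # -> (0.201024, 0.2)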


@register_model
class AcceleratePPOModel(AccelerateRLModel):
    def __init__(self, config):
        super().__init__(config)

        self.store = PPORolloutStorage(self.tokenizer.pad_token_id)

        rollout_loader = self.store.create_loader(
            self.config.train.batch_size, shuffle=True
        )

        self.model, self.opt, self.scheduler, rollout_loader = self.accelerator.prepare(
            self.model, self.opt, self.scheduler, rollout_loader
        )

        self.store.clear_history()

        if config.method.target is not None:
            self.kl_ctl = AdaptiveKLController(
                config.method.init_kl_coef, config.method.target, config.method.horizon
            )
        else:
            self.kl_ctl = FixedKLController(config.method.init_kl_coef)

        self.generate_kwargs = dict(
            config.method.gen_kwargs,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.eos_token_id,
        )

    def get_arch(self, config: TRLConfig):
        return GPTHydraHeadWithValueModel(
            self.config.model.model_path, self.config.model.num_layers_unfrozen
        )

    def loss(self, batch):
        query_tensors = batch.query_tensors.to(self.accelerator.device)
        response_tensors = batch.response_tensors.to(self.accelerator.device)
        all_logprobs = batch.logprobs.to(self.accelerator.device)
        all_values = batch.values.to(self.accelerator.device)
        all_rewards = batch.rewards.to(self.accelerator.device)

        lastgaelam = 0
        advantages_reversed = []
        gen_len = response_tensors.shape[1]
        for t in reversed(range(gen_len)):
            nextvalues = all_values[:, t + 1] if t < gen_len - 1 else 0.0
            delta = (
                all_rewards[:, t]
                + self.config.method.gamma * nextvalues
                - all_values[:, t]
            )
            lastgaelam = (
                delta + self.config.method.gamma * self.config.method.lam * lastgaelam
            )
            advantages_reversed.append(lastgaelam)

        advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)
        returns = advantages + all_values
        advantages = whiten(advantages)
        advantages = advantages.detach()

        all_tokens = torch.cat((query_tensors, response_tensors), dim=1)
        attention_mask = (
            all_tokens.not_equal(self.tokenizer.pad_token_id)
            .long()
            .to(all_tokens.device)
        )

        # for a proper positional encoding in case of left padding
        position_ids = attention_mask.cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask.eq(0), 0)

        logits, _, vpred = self.model(
            all_tokens, attention_mask, position_ids=position_ids
        )
        logprob = logprobs_from_logits(logits[:, :-1, :], all_tokens[:, 1:])

        # only the generation part of the values/logprobs is needed
        logprob, vpred = logprob[:, -gen_len:], vpred[:, -gen_len:]

        vpredclipped = clip_by_value(
            vpred,
            all_values - self.config.method.cliprange_value,
            all_values + self.config.method.cliprange_value,
        )

        mask = attention_mask[:, -gen_len:]

        vf_losses1 = (vpred - returns) ** 2
        vf_losses2 = (vpredclipped - returns) ** 2
        vf_loss = 0.5 * torch.sum(torch.max(vf_losses1, vf_losses2) * mask) / mask.sum()

        kl = logprob - all_logprobs
        # Record mean_kl for kl coef adjustment
        self.mean_kl = torch.mean(torch.sum(kl, dim=-1)).item()

        ratio = torch.exp(kl)

        pg_losses = -advantages * ratio
        pg_losses2 = -advantages * torch.clamp(
            ratio,
            1.0 - self.config.method.cliprange,
            1.0 + self.config.method.cliprange,
        )
        pg_loss = torch.sum(torch.max(pg_losses, pg_losses2) * mask) / mask.sum()

        loss = pg_loss + self.config.method.vf_coef * vf_loss
        stats = {
            "loss": loss,
            "pg_loss": pg_loss,
            "vf_loss": vf_loss,
        }
        return loss, stats
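
    # A minimal standalone sketch (not in the original trlx source) of the GAE
    # recursion used in `loss` above, on a toy single-sequence batch. gamma and
    # lam here are hypothetical stand-ins for config.method.gamma / lam.
    @staticmethod
    def _gae_sketch(gamma=1.0, lam=0.95):
        values = torch.tensor([[0.1, 0.2, 0.3]])  # toy value estimates
        rewards = torch.tensor([[0.0, 0.0, 1.0]])  # reward only on the last token
        lastgaelam = 0.0
        advantages_reversed = []
        gen_len = values.shape[1]
        for t in reversed(range(gen_len)):
            nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
            delta = rewards[:, t] + gamma * nextvalues - values[:, t]
            lastgaelam = delta + gamma * lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)
        # advantages ≈ [[0.8268, 0.7650, 0.7000]];
        # returns = advantages + values ≈ [[0.9268, 0.9650, 1.0000]]
        return advantages, advantages + values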

    def post_epoch_callback(self):
        self.store.clear_history()
        # Collect more rollouts for training
        self.orch.make_experience(self.config.method.num_rollouts, self.iter_count)

    def post_backward_callback(self):
        # Update the KL coefficient
        self.kl_ctl.update(self.mean_kl, self.config.train.batch_size)

    def prepare_learning(self):
        eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size)
        train_dataloader = self.store.create_loader(
            self.config.train.batch_size, shuffle=True
        )

        self.train_dataloader, self.eval_dataloader = self.accelerator.prepare(
            train_dataloader, eval_dataloader
        )

        self.n_updates_per_batch = self.config.method.ppo_epochs
        self.total_steps = (
            self.config.train.epochs
            * self.n_updates_per_batch
            * len(self.train_dataloader)
        )
        self.total_steps = min(self.total_steps, self.config.train.total_steps)
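
# Illustration (hypothetical numbers) of the schedule arithmetic in
# `prepare_learning`: with train.epochs=100, method.ppo_epochs=4 and a rollout
# loader yielding 8 batches, total_steps = 100 * 4 * 8 = 3200, which is then
# capped by config.train.total_steps.
def _demo_total_steps(epochs=100, ppo_epochs=4, num_batches=8, cap=10000):
    return min(epochs * ppo_epochs * num_batches, cap)  # -> 3200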