Source code for trlx.model.nn.ppo_models

import inspect
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers.modeling_outputs import ModelOutput

from transformers import (  # isort:skip
    AutoConfig,
    AutoModelForCausalLM,
    PretrainedConfig,
    PreTrainedModel,
)


@dataclass
class CausalLMOutputWithCrossAttentions(ModelOutput):
    """
    Causal language-model output container extended with a ``value`` field.

    Every field defaults to ``None``. The models in this module return it
    from ``forward`` when ``return_dict`` is truthy, filling ``logits`` with
    the LM-head output and ``value`` with the value-head output (or ``None``
    when no value head is evaluated).
    """

    # Always set to None by the models in this module.
    loss: Optional[torch.FloatTensor] = None
    # Output of the language-model head.
    logits: torch.FloatTensor = None
    # Key/value cache as reported by the transformer (or the branch's `presents`).
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Hidden states collected along the forward pass, when requested.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Self-attention outputs, when requested.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Cross-attention outputs, when available.
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Output of the scalar value head with its trailing size-1 dim squeezed.
    value: Optional[torch.FloatTensor] = None


def make_head(n_embd: int, out: int):
    """
    Build a small MLP head.

    The head maps ``n_embd`` features to ``2 * n_embd`` features, applies a
    ReLU, and projects down to ``out`` features.

    Args:
        n_embd: Width of the incoming features.
        out: Width of the produced features.

    Returns:
        An ``nn.Sequential`` of ``Linear -> ReLU -> Linear``.
    """
    hidden_width = n_embd * 2
    # Layers are created in forward order so parameter initialisation
    # consumes the RNG in the same sequence as a direct inline construction.
    layers = [
        nn.Linear(n_embd, hidden_width),
        nn.ReLU(),
        nn.Linear(hidden_width, out),
    ]
    return nn.Sequential(*layers)


class GPTHeadWithValueModel(nn.Module):
    """
    The GPTHeadWithValueModel class implements a GPT-type language model with
    a secondary, scalar head.

    The language model is an ``AutoModelForCausalLM``; the value head is a
    small MLP (see ``make_head``) applied to the transformer's last hidden
    state.
    """

    def __init__(self, config: Union[PretrainedConfig, str]):
        """
        Args:
            config: Either a ``PretrainedConfig`` (the model is built from
                the config via ``from_config``) or a string handed to
                ``AutoModelForCausalLM.from_pretrained``.
        """
        super().__init__()
        if isinstance(config, PretrainedConfig):
            self.gpt = AutoModelForCausalLM.from_config(config)
        else:
            self.gpt = AutoModelForCausalLM.from_pretrained(config)

        # Configs name the embedding width either `hidden_size` or `n_embd`.
        if hasattr(self.gpt.config, "hidden_size"):
            self.n_embd = self.gpt.config.hidden_size
        else:
            self.n_embd = self.gpt.config.n_embd
        self.v_head = make_head(self.n_embd, 1)

    def generate(self, input_ids, **x):
        """Delegate generation to the wrapped language model."""
        return self.gpt.generate(input_ids, **x)

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        lm_labels=None,
        mc_labels=None,
        return_dict=False,
        output_attentions=False,
        output_hidden_states=False,
    ):
        """
        Run the transformer, then the LM head and the value head.

        ``mc_token_ids``, ``lm_labels`` and ``mc_labels`` are accepted for
        signature compatibility and are not used.

        Returns:
            If ``return_dict`` is falsy, the tuple
            ``(lm_logits, *transformer_outputs[1:], value)``. Otherwise a
            ``CausalLMOutputWithCrossAttentions`` whose ``loss`` is ``None``.
        """
        transformer_outputs = self.gpt.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            # Honour the caller's request; a falsy flag is sent as None so the
            # backbone keeps falling back to its own config default.
            output_attentions=output_attentions or None,
            output_hidden_states=output_hidden_states or None,
        )

        hidden_states = transformer_outputs[0]
        lm_logits = self.gpt.lm_head(hidden_states)
        # Drop the trailing size-1 dimension produced by the scalar head.
        value = self.v_head(hidden_states).squeeze(-1)

        if not return_dict:
            return (lm_logits,) + transformer_outputs[1:] + (value,)

        return CausalLMOutputWithCrossAttentions(
            loss=None,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            # Not every backbone output type has a `cross_attentions` field;
            # report None instead of raising AttributeError.
            cross_attentions=getattr(transformer_outputs, "cross_attentions", None),
            value=value,
        )


class ModelBranch(PreTrainedModel):
    """
    ModelBranch implements the frozen upper trunk of the reference model
    used when computing the PPO KL-divergence penalty. Expects a list of
    transformer blocks, a final layer norm and an lm_head from the base
    model; all of them are deep-copied and the copies are frozen.
    """

    def __init__(self, config, transformer_blocks, ln_f, lm_head):
        """
        Args:
            config: Model config; must expose ``n_embd``.
            transformer_blocks: Blocks of the base model to copy.
            ln_f: Final layer norm of the base model to copy.
            lm_head: Language-model head of the base model to copy.
        """
        super().__init__(config)

        # Defined by the main trunk
        self.n_embd = config.n_embd
        self.h = deepcopy(nn.ModuleList(transformer_blocks))
        self.ln_f = deepcopy(ln_f)
        self.lm_head = deepcopy(lm_head)

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

        # Turning off grad saves memory. Freeze the branch's own copies
        # (blocks, ln_f and lm_head) and leave the modules passed in by the
        # caller untouched, so the trainable model keeps its gradients.
        for parameter in self.parameters():
            parameter.requires_grad = False

    def forward(
        self,
        hidden_states: torch.Tensor,  # Takes as input hidden_states instead of input_ids
        output_shape: torch.Tensor,  # output_size given by main trunk
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        """
        Run the copied blocks, final layer norm and lm_head on hidden states.

        Args:
            hidden_states: Activations fed to the first copied block.
            output_shape: Shape the normalised hidden states are viewed as
                before the lm_head.
            attention_mask: Optional mask; it is reshaped to
                ``[batch_size, 1, 1, -1]`` and turned into an additive mask.
            token_type_ids, position_ids, inputs_embeds: Accepted for
                signature compatibility; not used here.

        Returns:
            ``(lm_logits, None, None)`` when ``return_dict`` is falsy,
            otherwise a ``CausalLMOutputWithCrossAttentions`` with
            ``loss=None`` and ``value=None``.

        Raises:
            ValueError: If a mask is given and the batch size is not positive.
        """
        batch_size = hidden_states.size()[0]

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        device = hidden_states.device

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))

        # GPT2Attention mask.
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            attention_mask = attention_mask.view(batch_size, -1)
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask[:, None, None, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and the dtype's smallest value for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.add_cross_attention and encoder_hidden_states is not None:
            (
                encoder_batch_size,
                encoder_sequence_length,
                _,
            ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (
            () if output_attentions and self.config.add_cross_attention else None
        )
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            # Model parallel
            # NOTE(review): `model_parallel` is set to False in __init__ and
            # never changed in this class, so these branches do not run as
            # written.
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure layer_past is on same device as hidden_states (might not be correct)
                if layer_past is not None:
                    layer_past = tuple(
                        past_state.to(hidden_states.device)
                        for past_state in layer_past
                    )
                # Ensure that attention_mask is always on the same device as hidden_states
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # Assumes we are never training the branch
            # Only pass the cross-attention inputs to blocks whose forward
            # signature declares them.
            block_params = inspect.getfullargspec(block.forward).args
            if "encoder_hidden_states" in block_params:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (
                    outputs[2 if use_cache else 1],
                )
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (
                        outputs[3 if use_cache else 2],
                    )

            # Model Parallel: If it's the last layer for that device, put things on the next device
            # NOTE(review): `self.last_device` is not assigned in this
            # class's __init__ -- confirm before enabling model parallelism.
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # START OF CAUSAL HEAD #
        # hidden_states = hidden_states.to(torch.float32) Present for gptj

        # NOTE(review): `self.transformer` is not assigned in this class's
        # __init__ -- confirm before enabling model parallelism.
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            outputs = (lm_logits,) + (None,) + (None,)
            return outputs

        return CausalLMOutputWithCrossAttentions(
            loss=None,
            logits=lm_logits,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
            value=None,
        )


class GPTHydraHeadWithValueModel(nn.Module):
    """
    The GPTHydraHeadWithValueModel class implements a GPT-type language model
    with a secondary, scalar head and, when ``num_layers_unfrozen > 0``, a
    frozen ``ModelBranch`` copy of the last ``num_layers_unfrozen``
    transformer blocks, the final layer norm and the lm_head.
    """

    def __init__(
        self, config: Union[PretrainedConfig, str], num_layers_unfrozen: int = -1
    ):
        """
        Args:
            config: Either a ``PretrainedConfig`` (the model is built from
                the config via ``from_config``) or a string handed to
                ``AutoModelForCausalLM.from_pretrained``.
            num_layers_unfrozen: Number of top transformer blocks copied
                into the frozen branch. No branch is built when it is
                not positive.
        """
        super().__init__()

        if isinstance(config, PretrainedConfig):
            self.gpt = AutoModelForCausalLM.from_config(config)
        else:
            self.gpt = AutoModelForCausalLM.from_pretrained(config)

        # Configs name the embedding width either `hidden_size` or `n_embd`;
        # mirror it onto `n_embd` so downstream code can rely on that name.
        if hasattr(self.gpt.config, "hidden_size"):
            self.n_embd = self.gpt.config.hidden_size
            self.gpt.config.n_embd = self.n_embd
        else:
            self.n_embd = self.gpt.config.n_embd
        self.v_head = make_head(self.n_embd, 1)

        self.num_layers_unfrozen = num_layers_unfrozen
        if num_layers_unfrozen > 0:
            transformer_blocks = list(self.gpt.transformer.h)[-num_layers_unfrozen:]
            # Give the branch its own copy of the loaded model's config.
            # This works whether `config` was a PretrainedConfig instance or
            # a checkpoint name, and avoids re-reading the config.
            hf_config = deepcopy(self.gpt.config)
            hf_config.n_embd = self.n_embd
            self.frozen_head = ModelBranch(
                hf_config,
                transformer_blocks,
                self.gpt.transformer.ln_f,
                self.gpt.lm_head,
            )

    def generate(self, input_ids, **x):
        """Delegate generation to the wrapped language model."""
        return self.gpt.generate(input_ids, **x)

    def forward_hydra(self, input_ids, **x):
        """
        Run the trunk, then feed the frozen branch from the trunk's
        intermediate hidden state.

        ``x`` is forwarded to both ``forward`` and the frozen branch, with
        ``return_dict`` and ``output_hidden_states`` forced to ``True``.
        Requires the frozen branch built when ``num_layers_unfrozen > 0``.

        Returns:
            The branch's output object, or only its ``logits`` when the
            caller passed a falsy, non-``None`` ``return_dict``.
        """
        return_dict = x.get("return_dict")
        if return_dict is None:
            return_dict = True
        # Both calls below need attribute access and all hidden states.
        x["return_dict"] = True
        x["output_hidden_states"] = True
        output = self.forward(input_ids, **x)
        all_hidden_states = output.hidden_states
        # Get output of last frozen hidden layer
        # Select hidden state before first layer of branch.
        input_hidden_state = all_hidden_states[-(self.num_layers_unfrozen + 1)]
        # Get size of last hidden state
        output_shape = all_hidden_states[-1].size()
        outputs = self.frozen_head(input_hidden_state, output_shape, **x)
        if not return_dict:
            return outputs.logits
        return outputs

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        lm_labels=None,
        mc_labels=None,
        return_dict=False,
        output_attentions=False,
        output_hidden_states=False,
    ):
        """
        Run the transformer, then the LM head and the value head.

        ``mc_token_ids``, ``lm_labels`` and ``mc_labels`` are accepted for
        signature compatibility and are not used.

        Returns:
            If ``return_dict`` is falsy, the tuple
            ``(lm_logits, *transformer_outputs[1:], value)``. Otherwise a
            ``CausalLMOutputWithCrossAttentions`` whose ``loss`` and
            ``cross_attentions`` are ``None``.
        """
        transformer_outputs = self.gpt.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            # Honour the caller's request; a falsy flag is sent as None so the
            # backbone keeps falling back to its own config default.
            output_attentions=output_attentions or None,
            output_hidden_states=output_hidden_states,
        )

        hidden_states = transformer_outputs[0]
        lm_logits = self.gpt.lm_head(hidden_states)
        # Drop the trailing size-1 dimension produced by the scalar head.
        value = self.v_head(hidden_states).squeeze(-1)

        if not return_dict:
            return (lm_logits,) + transformer_outputs[1:] + (value,)

        return CausalLMOutputWithCrossAttentions(
            loss=None,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=None,
            value=value,
        )