import gc
import inspect
import re
from copy import deepcopy
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import deepspeed
import numpy as np
import torch
import transformers
from torchtyping import TensorType
from transformers.modeling_outputs import ModelOutput
from transformers.models.bloom import modeling_bloom
from transformers.models.opt import modeling_opt
from trlx.data.method_configs import MethodConfig, register_method
from trlx.models.modeling_base import PreTrainedModelWrapper
from trlx.utils.modeling import (
flatten_dict,
get_tensor_stats,
hf_get_decoder,
hf_get_decoder_blocks,
hf_get_decoder_final_norm,
hf_get_hidden_size,
hf_get_lm_head,
hf_get_num_hidden_layers,
make_head,
whiten,
)
# KL Controllers
class AdaptiveKLController:
"""Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences"
Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2
Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py
"""
def __init__(self, init_kl_coef: float, target: float, horizon: int):
self.value = init_kl_coef
self.target = target
self.horizon = horizon
def update(self, current: float, n_steps: int):
"""Returns adaptively updated KL coefficient, βₜ₊₁.
Arguments:
current: The current KL value between the newest policy and the initial policy.
"""
proportional_error = np.clip(current / self.target - 1, -0.2, 0.2) # ϵₜ
mult = 1 + proportional_error * n_steps / self.horizon
self.value *= mult # βₜ₊₁
class FixedKLController:
"""Fixed KL controller."""
def __init__(self, kl_coef):
self.value = kl_coef
def update(self, current: float, n_steps: int):
"""Returns updated KL coefficient, βₜ₊₁.
Arguments:
current: The current KL value between the newest policy and the initial policy.
"""
pass
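# Illustrative usage of the controllers (toy numbers, kept in a comment so the module has no
# import-time side effects): the adaptive controller nudges `value` so that the observed KL
# drifts toward `target` over roughly `horizon` steps, while the fixed controller never changes it.
#
#     adaptive = AdaptiveKLController(init_kl_coef=0.05, target=6.0, horizon=10000)
#     adaptive.update(current=12.0, n_steps=256)  # KL above target -> coefficient grows
#     adaptive.update(current=3.0, n_steps=256)   # KL below target -> coefficient shrinks
#
#     fixed = FixedKLController(kl_coef=0.05)
#     fixed.update(current=12.0, n_steps=256)     # no-op; fixed.value stays at 0.05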
# PPO Configs
@dataclass
@register_method
class PPOConfig(MethodConfig):
"""
Config for PPO method
:param ppo_epochs: Number of updates per batch
:type ppo_epochs: int
:param num_rollouts: Number of experiences to observe before learning
:type num_rollouts: int
:param init_kl_coef: Initial value for KL coefficient
:type init_kl_coef: float
:param target: Target value for KL coefficient
:type target: float
:param horizon: Number of steps for KL coefficient to reach target
:type horizon: int
:param gamma: Discount factor
:type gamma: float
:param lam: GAE lambda
:type lam: float
:param cliprange: Clipping range for PPO policy loss (1 - cliprange, 1 + cliprange)
:type cliprange: float
:param cliprange_value: Clipping range for predicted values
(observed values - cliprange_value, observed values + cliprange_value)
:type cliprange_value: float
:param vf_coef: Value loss scale w.r.t policy loss
:type vf_coef: float
:param gen_kwargs: Additional kwargs for the generation
:type gen_kwargs: Dict[str, Any]
    :param gen_experience_kwargs: if not None, these kwargs are used instead of `gen_kwargs` when generating experience
    :type gen_experience_kwargs: Dict[str, Any]
    :param chunk_size: Number of samples per chunk when generating experience
    :type chunk_size: int
    :param scale_reward: How to scale rewards: "ref", "running", or None for no scaling
    :type scale_reward: Optional[str]
    :param ref_mean: Reference mean used when scaling rewards
    :type ref_mean: Optional[float]
    :param ref_std: Reference standard deviation used when scaling rewards
    :type ref_std: Optional[float]
    :param cliprange_reward: Clip rewards to the range [-cliprange_reward, cliprange_reward]
    :type cliprange_reward: float
    :param num_value_layers_unfrozen: Number of decoder layers to duplicate into a separate value branch
        (0 means a single linear value head on the final hidden state)
    :type num_value_layers_unfrozen: int
    """
ppo_epochs: int
num_rollouts: int
chunk_size: int
init_kl_coef: float
target: float
horizon: int
gamma: float
lam: float
cliprange: float
cliprange_value: float
vf_coef: float
scale_reward: Optional[str]
ref_mean: Optional[float]
ref_std: Optional[float]
cliprange_reward: float
gen_kwargs: dict
gen_experience_kwargs: Optional[dict] = None
num_value_layers_unfrozen: int = 0
    def get_advantages_and_returns(
self,
values: TensorType["batch_size", "response_size"],
rewards: TensorType["batch_size", "response_size"],
response_length: int,
use_whitening: Optional[bool] = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function that computes advantages and returns from rewards and values.
Calculated as in the original PPO paper: https://arxiv.org/abs/1707.06347
Note that rewards may include a KL divergence loss term.
        Advantages looks like this:
        Adv1 = R1 + γ * λ * R2     + γ^2 * λ^2 * R3       + ...
             - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...
        Returns looks like this:
        Ret1 = R1 + γ * λ * R2     + γ^2 * λ^2 * R3       + ...
                  + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...
Args:
values: Tensor of shape (batch_size, response_size)
rewards: Tensor of shape (batch_size, response_size)
response_length: Length of the response sequence
            use_whitening: Whether to use whitening (i.e. normalize advantages) or not
"""
lastgaelam = 0
advantages_reversed = []
for t in reversed(range(response_length)):
nextvalues = values[:, t + 1] if t < response_length - 1 else 0.0
delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
lastgaelam = delta + self.gamma * self.lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values
if use_whitening:
advantages = whiten(advantages)
return advantages.detach(), returns
    def loss(
self,
logprobs: TensorType["batch_size", "response_size"],
values: TensorType["batch_size", "response_size"],
old_logprobs: TensorType["batch_size", "response_size"],
old_values: TensorType["batch_size", "response_size"],
advantages: TensorType["batch_size", "response_size"],
returns: TensorType["batch_size", "response_size"],
mask: TensorType["batch_size", "response_size"],
):
"""PPO objective function.
References:
- https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html
"""
values_clipped = torch.clamp(
values,
old_values - self.cliprange_value,
old_values + self.cliprange_value,
)
n = mask.sum()
vf_loss1 = (values - returns) ** 2
vf_loss2 = (values_clipped - returns) ** 2
vf_loss = 0.5 * torch.sum(torch.max(vf_loss1, vf_loss2) * mask) / n
vf_clipfrac = torch.sum((vf_loss2 > vf_loss1).float() * mask) / n
log_ratio = (logprobs - old_logprobs) * mask
ratio = torch.exp(log_ratio)
# Unbiased KL-div estimates (`k3`). Ref: http://joschu.net/blog/kl-approx.html
with torch.no_grad():
approx_kl = torch.mean((ratio - 1) - log_ratio)
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(
ratio,
1.0 - self.cliprange,
1.0 + self.cliprange,
)
pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / n
pg_clipfrac = torch.sum((pg_loss2 > pg_loss1).float() * mask) / n
loss = pg_loss + self.vf_coef * vf_loss
stats = dict(
losses=dict(
total_loss=loss.item(),
policy_loss=pg_loss.item(),
value_loss=vf_loss.item(),
),
values=dict(
get_tensor_stats(values, mask, n),
values_error=torch.sum(((values - returns) * mask) ** 2) / n,
values_mape_error=torch.sum((abs(values - returns) * mask) / abs(returns * mask + 1e-2)) / n,
clipfrac=vf_clipfrac,
),
old_values=get_tensor_stats(old_values, mask, n),
returns=get_tensor_stats(returns, mask, n),
policy=dict(approx_kl=approx_kl.item(), clipfrac=pg_clipfrac.item()),
ratio=(ratio * mask).sum() / n,
padding_percentage=1 - n / mask.numel(),
)
return loss, flatten_dict(stats)
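# Illustrative sketch (toy tensors; assumes a populated `config: PPOConfig` with, e.g.,
# gamma=0.99 and lam=0.95): computing GAE advantages and returns for a batch of responses.
#
#     values = torch.zeros(2, 5)         # (batch_size, response_size)
#     rewards = torch.full((2, 5), 0.1)  # per-token rewards, possibly including the KL penalty
#     advantages, returns = config.get_advantages_and_returns(values, rewards, response_length=5)
#     advantages.shape, returns.shape    # both (2, 5)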
# CausalLM architectures
@dataclass
class CausalLMOutputWithValue(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
value: Optional[torch.FloatTensor] = None
def make_value_branch(base_model, num_value_layers_unfrozen):
value_head = make_head(hf_get_hidden_size(base_model.config), 1)
if num_value_layers_unfrozen == 0:
return value_head
config = base_model.config
branch_class = hf_get_branch_class(config)
value_branch = branch_class(base_model, num_layers_unfrozen=num_value_layers_unfrozen, frozen=False)
value_branch.lm_head = value_head
return value_branch
class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
"""An `AutoModel` class wrapper for `transformers` causal models that have a
language modeling head and a value head
"""
_auto_model_parent_class = transformers.AutoModelForCausalLM
_supported_modules = ["v_head"]
_supported_args = ["peft_config", "num_value_layers_unfrozen"]
def __init__(
self,
base_model: transformers.PreTrainedModel,
peft_config=None,
num_value_layers_unfrozen=0,
):
super().__init__(base_model, peft_config=peft_config)
self.num_value_layers_unfrozen = num_value_layers_unfrozen
self.v_head = make_value_branch(base_model, num_value_layers_unfrozen)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
position_ids: Optional[List[torch.FloatTensor]] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
ignore_peft_adapter: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithValue]:
forward_kwargs = self.get_compatible_forward_kwargs(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
forward_kwargs["output_hidden_states"] = True
forward_kwargs["return_dict"] = True
if self.peft_type == "PREFIX_TUNING":
# In this case peft redefines past_key_values, remove it to avoid an exception.
forward_kwargs.pop("past_key_values", None)
if self.peft_type and ignore_peft_adapter:
if "LORA" in self.peft_type:
# For LORA, temporarily disable the adapter
lora_model = self.base_model.base_model
lora_model.disable_adapter_layers()
outputs = self.base_model(**forward_kwargs)
lora_model.enable_adapter_layers()
else:
# For prompt or prefix adapters, just use the base model of PeftModel
outputs = self.base_model.base_model(**forward_kwargs)
else:
outputs = self.base_model(**forward_kwargs)
# TODO: Apply PEFT to value branch
if self.num_value_layers_unfrozen > 0:
output_shape = outputs.hidden_states[-1].size()
forward_kwargs.pop("input_ids", None)
forward_kwargs.pop("inputs_embeds", None)
forward_kwargs["return_dict"] = False
value = self.v_head(
outputs.hidden_states[-(self.num_value_layers_unfrozen + 1)],
output_shape=output_shape,
**forward_kwargs,
)[0].squeeze(-1)
else:
value = self.v_head(outputs.hidden_states[-(self.num_value_layers_unfrozen + 1)]).squeeze(-1)
if not return_dict:
outputs = (outputs.logits,) + outputs[1:] + (value,)
return outputs
return CausalLMOutputWithValue(**outputs, value=value)
def generate(self, *args, **kwargs) -> Union[ModelOutput, torch.LongTensor]:
return self.base_model.generate(*args, **kwargs)
def state_dict(self, *args, heads_only=False, **kwargs):
"""
Returns the state dictionary of the model. We add the state dictionary of the value head
to the state dictionary of the wrapped model by prepending the key with `v_head.`.
"""
state_dict = self.v_head.state_dict(*args, **dict(prefix="v_head.", **kwargs))
if not heads_only:
state_dict = {**state_dict, **self.base_model.state_dict(*args, **dict(prefix="base_model.", **kwargs))}
        return state_dict
def post_init(self, state_dict):
"""
Adds the state dictionary of the value head to the state dictionary of the wrapped model
by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
keys of the value head state dictionary.
"""
super().post_init()
trlx_checkpoint = any(k.startswith("base_model.") or k.startswith("v_head.") for k in state_dict)
self.load_state_dict(state_dict, strict=trlx_checkpoint)
del state_dict
gc.collect() # noqa: E702
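# Illustrative usage (assumes the `from_pretrained` entry point inherited from
# PreTrainedModelWrapper and a small Hugging Face checkpoint such as "gpt2"):
#
#     model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
#     tokens = torch.tensor([[1, 2, 3, 4]])
#     out = model(input_ids=tokens, return_dict=True)
#     out.logits  # (1, 4, vocab_size) from the language modeling head
#     out.value   # (1, 4) scalar value per token from `v_head`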
class AutoModelForCausalLMWithHydraValueHead(AutoModelForCausalLMWithValueHead):
_supported_modules = ["v_head", "frozen_head"]
_supported_args = ["num_layers_unfrozen", "peft_config", "num_value_layers_unfrozen"]
def __init__(
self,
base_model: transformers.PreTrainedModel,
*,
num_layers_unfrozen: int = -1,
peft_config=None,
num_value_layers_unfrozen: int = 0,
):
super().__init__(base_model, peft_config=peft_config, num_value_layers_unfrozen=num_value_layers_unfrozen)
self.num_layers_unfrozen = num_layers_unfrozen
if self.num_layers_unfrozen > 0 and not self.peft_type:
config = self.base_model.config
branch_class = hf_get_branch_class(config)
self.frozen_head = branch_class(
self.base_model,
num_layers_unfrozen=self.num_layers_unfrozen,
).eval()
else:
self.frozen_head = None
def forward_hydra(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
position_ids: Optional[List[torch.FloatTensor]] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[torch.FloatTensor, CausalLMOutputWithValue]:
forward_kwargs = self.get_compatible_forward_kwargs(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
return_dict = forward_kwargs.get("return_dict", True)
forward_kwargs["return_dict"] = True
forward_kwargs["output_hidden_states"] = True
if self.peft_type:
hydra_outputs = self.forward(**forward_kwargs, ignore_peft_adapter=True)
else:
outputs = self.forward(**forward_kwargs)
# Select the hidden state before the first branching layer
input_hidden_state = outputs.hidden_states[-(self.num_layers_unfrozen + 1)]
output_shape = outputs.hidden_states[-1].size()
forward_kwargs.pop("input_ids", None) # Ignore `input_ids` for branch head
forward_kwargs.pop("inputs_embeds", None) # Ignore `inputs_embeds` for branch head
hydra_outputs = self.frozen_head(input_hidden_state, output_shape, **forward_kwargs)
if not return_dict:
return hydra_outputs.logits
return hydra_outputs
def state_dict(self, *args, heads_only=False, **kwargs):
"""
Returns the state dictionary of the model. We add the state dictionary of the value head
to the state dictionary of the wrapped model by prepending the key with `v_head.`.
"""
state_dict = self.v_head.state_dict(*args, **dict(prefix="v_head.", **kwargs))
if not heads_only:
state_dict = {
**state_dict,
**self.base_model.state_dict(*args, **dict(prefix="" if self.peft_type else "base_model.", **kwargs)),
}
if self.frozen_head:
state_dict = {
**state_dict,
**self.frozen_head.state_dict(*args, **dict(prefix="frozen_head.", **kwargs)),
}
return state_dict
def post_init(self, state_dict):
"""
Load `state_dict` into the model. If peft was used to train the model,
only the value head would be present in the loaded `state_dict`, so the
loading has to be not strict. Also `frozen_head` will be recreated and
loaded from the checkpoint, to comply with deepspeed checkpoint loading.
"""
strict = not self.peft_type and any(k.startswith("base_model.") or k.startswith("v_head.") for k in state_dict)
if not self.peft_type and self.frozen_head is None:
for k in state_dict:
match = re.search(r"^frozen_head\..+\.(\d+)\.", k)
if match:
self.num_layers_unfrozen = max(self.num_layers_unfrozen, int(match.group(1)) + 1)
config = self.base_model.config
branch_class = hf_get_branch_class(config)
self.frozen_head = branch_class(
self.base_model,
num_layers_unfrozen=self.num_layers_unfrozen,
).eval()
self.load_state_dict(state_dict, strict=strict)
del state_dict
gc.collect() # noqa: E702
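# Illustrative usage (hypothetical tensors): the hydra model keeps a frozen copy of the top
# `num_layers_unfrozen` layers (`frozen_head`), so reference-policy logits for the PPO KL
# penalty can be computed with `forward_hydra` instead of holding a second full model in memory.
#
#     model = AutoModelForCausalLMWithHydraValueHead.from_pretrained("gpt2", num_layers_unfrozen=2)
#     tokens = torch.tensor([[1, 2, 3, 4]])
#     policy_out = model(input_ids=tokens)                                  # trainable logits + values
#     ref_logits = model.forward_hydra(input_ids=tokens, return_dict=True).logits  # frozen reference logits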
class ModelBranch(transformers.PreTrainedModel):
"""Implements the upper trunk of the pretrained reference model used
when computing the PPO KL-divergence penalty.
"""
def __init__(
self,
base_model: transformers.PreTrainedModel,
*,
num_layers_unfrozen: int,
frozen=True,
):
"""
Args:
base_model (transformers.PreTrainedModel): The pretrained model to extract upper trunk from
num_layers_unfrozen (int): The number of trainable layers
"""
super().__init__(base_model.config)
# The branch is defined by the last `num_layers_unfrozen` layers of the pretrained model
decoder_blocks = hf_get_decoder_blocks(base_model)[-num_layers_unfrozen:]
final_norm = hf_get_decoder_final_norm(base_model)
lm_head = hf_get_lm_head(base_model)
with deepspeed.zero.GatheredParameters(
list(decoder_blocks.parameters()) + list(final_norm.parameters()) + list(lm_head.parameters()),
modifier_rank=None,
):
self.decoder_blocks = deepcopy(decoder_blocks)
self.final_norm = deepcopy(final_norm)
self.lm_head = deepcopy(lm_head)
self.hidden_size = hf_get_hidden_size(self.config)
self.model_parallel = False
self.device_map = None
self.last_device = None
self.gradient_checkpointing = False
# Freeze the entire branch
if frozen:
for parameter in self.parameters():
parameter.requires_grad_(False)
class GPTModelBranch(ModelBranch):
def forward( # noqa: max-complexity
self,
hidden_states: torch.Tensor, # Takes as input hidden_states instead of input_ids
output_shape: torch.Tensor, # output_size given by main trunk
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = False,
) -> Union[Tuple, CausalLMOutputWithValue]:
"""Reference:
https://github.com/huggingface/transformers/blob/2411f0e465e761790879e605a4256f3d4afb7f82/src/transformers/models/gpt2/modeling_gpt2.py#L743 # noqa: E501
"""
batch_size, seq_length = hidden_states.shape[:2]
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
device = hidden_states.device
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.decoder_blocks))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length)
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask[:, None, None, :]
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
if self.config.add_cross_attention and encoder_hidden_states is not None:
(
encoder_batch_size,
encoder_sequence_length,
_,
) = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_attention_mask = None
head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config))
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.decoder_blocks, past_key_values)):
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
kwargs = dict(
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
# Assumes we are never training the branch
block_params = inspect.getfullargspec(block.forward).args
if "encoder_hidden_states" not in block_params:
kwargs.pop("encoder_hidden_states")
kwargs.pop("encoder_attention_mask")
# Remove position_ids for GPT2Block
if "position_ids" not in block_params:
kwargs.pop("position_ids")
outputs = block(hidden_states, **kwargs)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
hidden_states = self.final_norm(hidden_states)
hidden_states = hidden_states.view(output_shape)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.model_parallel:
torch.cuda.set_device(self.transformer.first_device)
hidden_states = hidden_states.to(self.lm_head.weight.device)
lm_logits = self.lm_head(hidden_states)
if not return_dict:
outputs = (lm_logits,) + (None,) + (None,)
return outputs
return CausalLMOutputWithValue(
logits=lm_logits,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
class OPTModelBranch(ModelBranch):
def forward( # noqa: max-complexity
self,
hidden_states: torch.Tensor,
output_shape: torch.Tensor,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = False,
) -> Union[Tuple, CausalLMOutputWithValue]:
"""Reference:
https://github.com/huggingface/transformers/blob/bdb84e2bada3658f99c6a81c963ec562f8485151/src/transformers/models/opt/modeling_opt.py#L840 # noqa: E501
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(hidden_states.shape[:2], dtype=torch.bool, device=hidden_states.device)
input_shape = hidden_states.size()[:-1]
combined_attention_mask = None
if input_shape[-1] > 1:
# `modeling_opt._make_causal_mask` @ transformers==4.27.1 doesn't have the `device` argument
if "device" in inspect.getfullargspec(modeling_opt._make_causal_mask).args:
kwargs = dict(device=hidden_states.device)
else:
kwargs = {}
combined_attention_mask = modeling_opt._make_causal_mask(
input_shape,
hidden_states.dtype,
past_key_values_length=past_key_values_length,
**kwargs,
).to(hidden_states.device)
if attention_mask is not None:
expanded_attn_mask = modeling_opt._expand_mask(
attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]
).to(hidden_states.device)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
attention_mask = combined_attention_mask
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.decoder_blocks)):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.decoder_blocks)} layers, but it is for"
f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.decoder_blocks):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
layer_outputs = decoder_layer(
hidden_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if self.final_norm is not None:
hidden_states = self.final_norm(hidden_states)
# TODO: Add output projection support
# https://github.com/huggingface/transformers/blob/699e90437f984d69ad3c9b891dd2e9d0fc2cffe4/src/transformers/models/opt/modeling_opt.py#L499 # noqa: E501
# if self.project_out is not None:
# hidden_states = self.project_out(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
lm_logits = self.lm_head(hidden_states).contiguous()
if not return_dict:
return tuple(
v
for v in [
lm_logits,
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return CausalLMOutputWithValue(
logits=lm_logits,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class BloomModelBranch(ModelBranch):
def forward( # noqa: max-complexity
self,
hidden_states: torch.Tensor, # Takes as input hidden_states instead of input_ids
output_shape: torch.Tensor,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = False,
) -> Union[Tuple, CausalLMOutputWithValue]:
"""Reference:
https://github.com/huggingface/transformers/blob/2411f0e465e761790879e605a4256f3d4afb7f82/src/transformers/models/bloom/modeling_bloom.py#L623 # noqa: E501
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size, seq_length = hidden_states.shape[:2]
if past_key_values is None:
past_key_values = tuple([None] * len(self.decoder_blocks))
head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config))
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
alibi = modeling_bloom.build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)
combined_attention_mask = None
device = attention_mask.device
input_shape = (batch_size, seq_length)
_, src_length = input_shape
if src_length > 1:
combined_attention_mask = modeling_bloom._make_causal_mask(
input_shape,
device=device,
past_key_values_length=past_key_values_length,
)
expanded_attn_mask = modeling_bloom._expand_mask(attention_mask, tgt_length=src_length)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
)
causal_mask = combined_attention_mask
for i, (block, layer_past) in enumerate(zip(self.decoder_blocks, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
hidden_states = self.final_norm(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
lm_logits = self.lm_head(hidden_states)
if not return_dict:
return tuple(
v
for v in [
lm_logits,
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
]
if v is not None
)
return CausalLMOutputWithValue(
logits=lm_logits,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class LlamaModelBranch(ModelBranch):
def _make_causal_mask(self, input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, hidden_states, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = self._make_causal_mask(
input_shape, hidden_states.dtype, past_key_values_length=past_key_values_length
).to(hidden_states.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = self._expand_mask(attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]).to(
hidden_states.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
hidden_states: torch.Tensor,
output_shape: torch.Tensor,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = False,
) -> Union[Tuple, CausalLMOutputWithValue]:
"""Reference:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L491
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size, seq_length = hidden_states.shape[:2]
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = hidden_states.device if hidden_states is not None else encoder_hidden_states.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.decoder_blocks):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.final_norm(hidden_states)
hidden_states = hidden_states.view(output_shape)
lm_logits = self.lm_head(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
outputs = (lm_logits,) + (None,) + (None,)
return outputs
return CausalLMOutputWithValue(
logits=lm_logits,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class GPTBigCodeModelBranch(ModelBranch):
def __init__(
self,
base_model: transformers.PreTrainedModel,
*,
num_layers_unfrozen: int,
):
"""
Args:
base_model (transformers.PreTrainedModel): The pretrained model to extract upper trunk from
num_layers_unfrozen (int): The number of trainable layers
"""
super().__init__(base_model, num_layers_unfrozen=num_layers_unfrozen)
self.config = base_model.transformer.config
self.bias = base_model.transformer.bias
self.multi_query = base_model.transformer.multi_query
self.get_head_mask = base_model.transformer.get_head_mask
def forward( # noqa: C901
self,
hidden_states: torch.Tensor, # Takes as input hidden_states instead of input_ids
output_shape: torch.Tensor, # output_size given by main trunk
past_key_values: Optional[List[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithValue]:
"""Reference:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L539
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size, seq_length = hidden_states.shape[:2]
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
device = hidden_states.device
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.decoder_blocks))
else:
past_length = past_key_values[0].size(-2)
# Self-attention mask.
query_length = seq_length
key_length = past_length + query_length
self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length].to(device)
if attention_mask is not None:
self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to(
dtype=torch.bool, device=self_attention_mask.device
)
# MQA models: (batch_size, query_length, n_heads, key_length)
# MHA models: (batch_size, n_heads, query_length, key_length)
attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.add_cross_attention and encoder_hidden_states is not None and encoder_attention_mask is not None:
if encoder_attention_mask.dim() == 2:
                encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
assert encoder_attention_mask.dim() == 3
encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1)
else:
encoder_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
presents = [] if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.decoder_blocks, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache:
presents.append(outputs[1])
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
hidden_states = self.final_norm(hidden_states)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
lm_logits = self.lm_head(hidden_states)
if not return_dict:
return tuple(
v
for v in [
lm_logits,
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return CausalLMOutputWithValue(
logits=lm_logits,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
# Seq2Seq architectures
@dataclass
class Seq2SeqLMOutputWithValue(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
value: Optional[torch.FloatTensor] = None
class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
"""An `AutoModel` class wrapper for `transformers` sequence-to-sequence
models that have a language modeling head and a value head
"""
_auto_model_parent_class = transformers.AutoModelForSeq2SeqLM
_supported_modules = ["v_head"]
_supported_args = ["peft_config", "num_value_layers_unfrozen"]
def __init__(
self,
base_model: transformers.PreTrainedModel,
peft_config=None,
num_value_layers_unfrozen=0,
):
super().__init__(base_model, peft_config=peft_config)
# TODO: Support Seq2Seq value branching
if num_value_layers_unfrozen > 0:
raise NotImplementedError("Value branches unsupported for Seq2Seq architecture")
self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.FloatTensor] = None,
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = True,
output_hidden_states: Optional[bool] = True,
return_dict: Optional[bool] = None,
ignore_peft_adapter: Optional[bool] = None,
) -> Seq2SeqLMOutputWithValue:
forward_kwargs = self.get_compatible_forward_kwargs(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
forward_kwargs["output_hidden_states"] = True
forward_kwargs["return_dict"] = True
if self.peft_type == "PREFIX_TUNING":
# In this case peft redefines past_key_values, remove it to avoid an exception.
forward_kwargs.pop("past_key_values", None)
if self.peft_type and ignore_peft_adapter:
if "LORA" in self.peft_type:
# For LORA, temporarily disable the adapter
lora_model = self.base_model.base_model
lora_model.disable_adapter_layers()
outputs = self.base_model(**forward_kwargs)
lora_model.enable_adapter_layers()
else:
# For prompt or prefix adapters, just use the base model of PeftModel
outputs = self.base_model.base_model(**forward_kwargs)
else:
outputs = self.base_model(**forward_kwargs)
last_hidden_state = outputs.decoder_hidden_states[-1]
value = self.v_head(last_hidden_state).squeeze(-1)
return Seq2SeqLMOutputWithValue(**outputs, value=value)
def generate(self, *args, **kwargs) -> Union[ModelOutput, torch.LongTensor]:
return self.base_model.generate(*args, **kwargs)
def state_dict(self, *args, heads_only=False, **kwargs):
"""
Returns the state dictionary of the model. We add the state dictionary of the value head
to the state dictionary of the wrapped model by prepending the key with `v_head.`.
"""
state_dict = self.v_head.state_dict(*args, **dict(prefix="v_head.", **kwargs))
if not heads_only:
state_dict = {**state_dict, **self.base_model.state_dict(*args, **dict(prefix="base_model.", **kwargs))}
return state_dict
def post_init(self, state_dict):
"""
Adds the state dictionary of the value head to the state dictionary of the wrapped model
by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
keys of the value head state dictionary.
"""
super().post_init()
trlx_checkpoint = any(k.startswith("base_model.") or k.startswith("v_head.") for k in state_dict)
self.load_state_dict(state_dict, strict=trlx_checkpoint)
del state_dict
gc.collect() # noqa: E702
class AutoModelForSeq2SeqLMWithHydraValueHead(AutoModelForSeq2SeqLMWithValueHead):
_supported_modules = ["v_head", "frozen_head"]
_supported_args = ["num_layers_unfrozen", "peft_config", "num_value_layers_unfrozen"]
def __init__(
self,
base_model: transformers.PreTrainedModel,
*,
num_layers_unfrozen: int = -1,
peft_config=None,
num_value_layers_unfrozen: int = 0,
):
super().__init__(base_model, peft_config=peft_config, num_value_layers_unfrozen=num_value_layers_unfrozen)
self.num_layers_unfrozen = num_layers_unfrozen
if self.num_layers_unfrozen > 0 and not self.peft_type:
branch_class = T5Branch # TODO: Add support for other model branches
self.frozen_head = branch_class(
self.base_model,
num_layers_unfrozen=self.num_layers_unfrozen,
).eval()
else:
self.frozen_head = None
def forward_hydra(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.FloatTensor] = None,
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Seq2SeqLMOutputWithValue:
forward_kwargs = self.get_compatible_forward_kwargs(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
return_dict = forward_kwargs.get("return_dict", True)
forward_kwargs["output_hidden_states"] = True
forward_kwargs["return_dict"] = True
if self.peft_type:
hydra_outputs = self.forward(**forward_kwargs, ignore_peft_adapter=True)
else:
outputs = self.forward(**forward_kwargs)
# Select the hidden state before the first branching layer
input_hidden_state = outputs.decoder_hidden_states[-(self.num_layers_unfrozen + 1)]
hydra_outputs = self.frozen_head(
hidden_states=input_hidden_state,
attention_mask=decoder_attention_mask,
encoder_hidden_states=outputs.encoder_last_hidden_state,
encoder_attention_mask=attention_mask,
use_cache=False,
output_attentions=False,
output_hidden_states=True,
return_dict=return_dict,
)
if not return_dict:
return hydra_outputs.logits
return hydra_outputs
def state_dict(self, *args, heads_only=False, **kwargs):
"""
Returns the state dictionary of the model. We add the state dictionary of the value head
to the state dictionary of the wrapped model by prepending the key with `v_head.`.
"""
state_dict = self.v_head.state_dict(*args, **dict(prefix="v_head.", **kwargs))
if not heads_only:
state_dict = {
**state_dict,
**self.base_model.state_dict(*args, **dict(prefix="" if self.peft_type else "base_model.", **kwargs)),
}
if self.frozen_head:
state_dict = {
**state_dict,
**self.frozen_head.state_dict(*args, **dict(prefix="frozen_head.", **kwargs)),
}
return state_dict
def post_init(self, state_dict):
"""
Load `state_dict` into the model. If peft was used to train the model,
only the value head would be present in the loaded `state_dict`, so the
loading has to be not strict. Also `frozen_head` will be recreated and
loaded from the checkpoint, to comply with deepspeed checkpoint loading.
"""
strict = not self.peft_type and any(k.startswith("base_model.") or k.startswith("v_head.") for k in state_dict)
if not self.peft_type and self.frozen_head is None:
for k in state_dict:
match = re.search(r"^frozen_head\.decoder_blocks\.(\d+)", k)
if match:
self.num_layers_unfrozen = max(self.num_layers_unfrozen, int(match.group(1)) + 1)
branch_class = T5Branch # TODO: Add support for other model branches
self.frozen_head = branch_class(
self.base_model,
num_layers_unfrozen=self.num_layers_unfrozen,
).eval()
self.load_state_dict(state_dict, strict=strict)
del state_dict
gc.collect() # noqa: E702
class T5Branch(ModelBranch):
"""Decoder only T5 branch"""
def __init__(
self,
base_model: transformers.PreTrainedModel,
*,
num_layers_unfrozen: int,
):
super().__init__(base_model, num_layers_unfrozen=num_layers_unfrozen)
self.dropout = hf_get_decoder(base_model).dropout
self.is_decoder = True
def forward( # noqa: max-complexity
self,
hidden_states: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Seq2SeqLMOutputWithValue]:
"""Reference:
https://github.com/huggingface/transformers/blob/bc21aaca789f1a366c05e8b5e111632944886393/src/transformers/models/t5/modeling_t5.py#L899 # noqa: E501
"""
batch_size, seq_length = hidden_states.shape[:2]
input_shape = (batch_size, seq_length)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if attention_mask is None:
attention_mask = torch.ones(batch_size, seq_length, device=hidden_states.device)
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
encoder_seq_length = encoder_hidden_states.shape[1]
encoder_attention_mask = torch.ones(
batch_size, encoder_seq_length, device=hidden_states.device, dtype=torch.long
)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
if self.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=hidden_states.device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
position_bias = None
encoder_decoder_position_bias = None
for _, layer_module in enumerate(self.decoder_blocks):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states,
attention_mask=extended_attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
use_cache=use_cache,
output_attentions=output_attentions,
)
if use_cache is False:
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
hidden_states, present_key_value_state = layer_outputs[:2]
position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[3],)
hidden_states = self.final_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
sequence_output = hidden_states
if self.config.tie_word_embeddings:
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 # noqa: E501
sequence_output = sequence_output * (self.config.d_model**-0.5)
lm_logits = self.lm_head(sequence_output)
if not return_dict:
return (lm_logits,)
return Seq2SeqLMOutputWithValue(
logits=lm_logits,
decoder_hidden_states=all_hidden_states,
decoder_attentions=all_attentions,
)
# Branch class utils
def hf_get_branch_class(
config: transformers.PretrainedConfig,
) -> "ModelBranch":
"""Returns the model branch class for the given config."""
gpt_branch_supported_archs = [
"GPTJForCausalLM",
"GPT2LMHeadModel",
"GPTNeoForCausalLM",
"GPTNeoXForCausalLM",
]
opt_branch_supported_archs = ["OPTForCausalLM"]
bloom_branch_supported_archs = ["BloomModel", "BloomForCausalLM"]
llama_branch_supported_archs = ["LlamaModel", "LlamaForCausalLM"]
bigcode_branch_supported_archs = ["GPTBigCodeModel", "GPTBigCodeForCausalLM"]
arch = config.architectures[0]
if arch in gpt_branch_supported_archs:
return GPTModelBranch
elif arch in opt_branch_supported_archs:
return OPTModelBranch
elif arch in bloom_branch_supported_archs:
return BloomModelBranch
elif arch in llama_branch_supported_archs:
return LlamaModelBranch
elif arch in bigcode_branch_supported_archs:
return GPTBigCodeModelBranch
else:
all_supported_archs = sum(
[
gpt_branch_supported_archs,
opt_branch_supported_archs,
bloom_branch_supported_archs,
llama_branch_supported_archs,
bigcode_branch_supported_archs,
],
[],
)
raise ValueError(
f"Unsupported architecture: `{arch}`. The following architectures are "
f"available for model branching:\n{all_supported_archs}"
)
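# Illustrative dispatch (assumes a config whose `architectures` field is populated, as is the
# case for configs loaded via `transformers.AutoConfig.from_pretrained`):
#
#     config = transformers.AutoConfig.from_pretrained("gpt2")
#     hf_get_branch_class(config)  # -> GPTModelBranch (architecture "GPT2LMHeadModel")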