Source code for trlx.models.modeling_ilql

import gc
import os
from copy import deepcopy
from dataclasses import dataclass
from functools import reduce
from typing import Any, Optional, Tuple

import deepspeed  # type: ignore
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from torch import nn
from torchtyping import TensorType
from transformers.modeling_outputs import ModelOutput

from trlx.data.ilql_types import ILQLBatch
from trlx.data.method_configs import MethodConfig, register_method
from trlx.models.modeling_base import PreTrainedModelWrapper
from trlx.utils.modeling import (
    flatten_dict,
    get_tensor_stats,
    hf_get_hidden_size,
    hf_get_lm_head,
    make_head,
)


def topk_mask(xs: torch.FloatTensor, k: int):
    """
    Keep only the ``k`` largest entries along the last dimension of ``xs``; set the rest to ``-inf``.

    An entry is masked when it is strictly smaller than the k-th largest value of its row.
    Entries tied with that value are kept, so a row may retain more than ``k`` entries.
    Works for inputs of any rank (1-D, 2-D ``(batch, vocab)``, or higher); the top-k is always
    taken over the last dimension.

    :param xs: Floating point scores to mask.
    :param k: Number of largest entries to keep per row. If ``k`` exceeds the size of the last
        dimension, ``xs`` is returned unchanged (the same object).
    :return: Tensor with the same shape as ``xs``.
    """
    if k > xs.shape[-1]:
        return xs
    # k-th largest value of each row, kept as a trailing singleton dim so it broadcasts over xs
    # (the former `[:, -1].unsqueeze(-1)` only worked for 2-D inputs).
    mintop = torch.topk(xs, k, dim=-1).values[..., -1:]
    return torch.where(xs < mintop, torch.full_like(xs, -np.inf), xs)


def batched_index_select(
    x: TensorType["batch", "seq_len", "hidden"],
    idxs: TensorType["batch", "index_len"],
    dim: int,
) -> TensorType["batch", "index_len", "hidden"]:
    """
    For every batch row, pick the hidden vectors of ``x`` located at the positions in ``idxs``.

    :param x: Tensor whose last dimension holds the vectors to copy.
    :param idxs: 2-D tensor of positions, one row per batch element.
    :param dim: Dimension of ``x`` along which ``idxs`` indexes.
    :return: Tensor of shape ``(idxs.shape[0], idxs.shape[1], x.shape[-1])``.
    """
    batch_size, index_len = idxs.shape[0], idxs.shape[1]
    hidden_size = x.shape[-1]
    # Repeat every index across the hidden dimension so gather copies whole vectors.
    gather_index = idxs.unsqueeze(-1).expand(batch_size, index_len, hidden_size)
    return torch.gather(x, dim, gather_index)


@dataclass
@register_method
class ILQLConfig(MethodConfig):
    """
    Configuration for the ILQL method.

    :param tau: Expectile used when regressing the value function onto the Q estimates, in (0, 1).
        ``tau=0.5`` is equivalent to the mean square error and ``tau=1`` is equivalent to taking a
        maximum over the Q estimates.
    :type tau: float

    :param gamma: Discount factor.
    :type gamma: float

    :param cql_scale: Scale of the CQL loss (conservative Q-learning loss).
    :type cql_scale: float

    :param awac_scale: Scale of the AWAC loss (advantage-weighted cross-entropy loss).
    :type awac_scale: float

    :param alpha: Polyak-averaging coefficient used when syncing the target Q-heads, in (0, 1).
    :type alpha: float

    :param beta: Magnitude of the weighting effect in the AWAC loss, in (0, 1).
    :type beta: float

    :param steps_for_target_q_sync: Number of steps between target Q-head syncs.
    :type steps_for_target_q_sync: int

    :param two_qs: Whether to use two Q-heads and take the minimum of their estimates, or a single Q-head.
    :type two_qs: bool

    :param gen_kwargs: Keyword arguments for the generation method.
    :type gen_kwargs: dict
    """

    tau: float
    gamma: float
    cql_scale: float
    awac_scale: float
    alpha: float
    beta: float
    steps_for_target_q_sync: int
    two_qs: bool
    gen_kwargs: dict

    def loss(self, outputs, labels):
        """
        Compute the total ILQL loss together with flattened training statistics.

        The total is ``loss_q + loss_v + cql_scale * loss_cql + awac_scale * loss_awac``.

        :param outputs: ``(logits, (qs, target_qs, vs))`` as returned by the model with ILQL heads.
        :param labels: Either an ``ILQLBatch`` (actions are gathered from ``input_ids`` at
            ``actions_ixs``) or another batch type whose actions are its shifted
            ``decoder_input_ids``. ``dones``, ``rewards`` and ``actions_ixs`` are read in both cases.
        :return: ``(loss, stats)`` where ``stats`` is the flattened statistics dictionary.
        """
        logits, (qs, target_qs, vs) = outputs

        # dones[:, :-1] multiplies every per-step term; the normalizer is clamped to >= 1
        # so an all-zero mask does not divide by zero.
        step_mask = labels.dones[:, :-1]
        denom = max(1, step_mask.sum())

        if isinstance(labels, ILQLBatch):
            next_tokens = labels.input_ids[:, 1:]
            actions = next_tokens.gather(dim=1, index=labels.actions_ixs).unsqueeze(-1)
        else:
            actions = labels.decoder_input_ids[:, 1:].unsqueeze(-1)

        n_actions = actions.shape[1]
        batch_size, _, n_classes = logits.shape

        # Q estimates of the taken actions, from the online heads and the (detached) target heads.
        q_taken = [q.gather(-1, actions).squeeze(-1) for q in qs]
        target_q_taken = [tq.gather(-1, actions).squeeze(-1).detach() for tq in target_qs]
        target_q = reduce(torch.minimum, target_q_taken)

        # The Q loss assumes len(states) == len(rewards) + 1:
        # vs[:, :-1] holds values of the current states, vs[:, 1:] those of the next states.
        v_now = vs[:, :-1, 0]
        v_next = vs[:, 1:, 0] * labels.dones[:, 1:].to(vs.dtype)
        # Regression target for the Q-heads.
        q_target = labels.rewards + self.gamma * v_next.detach()

        loss_q = sum(((q - q_target) * step_mask).pow(2).sum() / denom for q in q_taken)

        target_q = target_q.detach()

        # Expectile regression of V onto the target Q: tau weights the case target_q >= V,
        # (1 - tau) the case target_q < V.
        above = (target_q >= v_now).int() * self.tau * (target_q - v_now).pow(2)
        below = (target_q < v_now).int() * (1 - self.tau) * (target_q - v_now).pow(2)
        loss_v = ((above + below) * step_mask).sum() / denom

        def _cql_term(q_logits):
            # Cross-entropy of the taken actions when the Q-values are used as logits.
            ce = F.cross_entropy(q_logits.reshape(-1, n_classes), actions.reshape(-1), reduction="none")
            return (ce.reshape(batch_size, n_actions) * step_mask).sum() / denom

        loss_cql = sum(_cql_term(q) for q in qs)

        # Language-model logits at the action positions.
        action_logits = batched_index_select(logits, labels.actions_ixs, dim=1)
        ce_actions = F.cross_entropy(
            action_logits.reshape(-1, n_classes),
            actions.reshape(-1),
            reduction="none",
        ).reshape(batch_size, n_actions)

        # The AWAC weights do not receive gradients.
        with torch.no_grad():
            awac_weight = torch.exp(self.beta * (target_q - v_now))

        loss_awac = torch.sum(ce_actions * awac_weight * step_mask) / denom
        loss = loss_q + loss_v + self.cql_scale * loss_cql + self.awac_scale * loss_awac

        stats = {
            "losses": {
                "loss": loss.item(),
                "loss_q": loss_q.item(),
                "loss_v": loss_v.item(),
                "loss_cql": loss_cql.item(),
                "loss_awac": loss_awac.item(),
            },
            "values": get_tensor_stats(v_now, step_mask, denom),
            "qvalues": {str(i): get_tensor_stats(q, step_mask, denom) for i, q in enumerate(q_taken)},
            "awac_weight": get_tensor_stats(awac_weight, step_mask, denom),
        }
        return loss, flatten_dict(stats)
class ILQLHeads(nn.Module):
    """
    Value head, Q-heads and frozen target Q-heads used by ILQL on top of a model's hidden states.

    The value head maps ``hidden_size -> 1``, each Q-head maps ``hidden_size -> vocab_size``
    (both built with ``make_head``). There are two Q-heads when ``two_qs`` is set, otherwise one.
    Every target Q-head starts as a deep copy of its Q-head and has gradients disabled.
    """

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        two_qs: bool,
        alpha: float,
        dtype: type,
    ):
        """
        :param hidden_size: Size of the incoming hidden states.
        :param vocab_size: Output size of every Q-head.
        :param two_qs: Build two Q-heads (and two target Q-heads) instead of one.
        :param alpha: Polyak-averaging coefficient used by ``sync_target_q_heads``.
        :param dtype: dtype forwarded to ``make_head``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.two_qs = two_qs
        self.alpha = alpha

        # Creation order (value head, Q-heads, target Q-heads) also fixes the state_dict key order.
        self.v_head = make_head(self.hidden_size, 1, dtype)

        num_q_heads = 2 if self.two_qs else 1
        self.q_heads = nn.ModuleList([make_head(self.hidden_size, self.vocab_size, dtype) for _ in range(num_q_heads)])
        self.target_q_heads = nn.ModuleList([deepcopy(head) for head in self.q_heads])
        for frozen_head in self.target_q_heads:
            frozen_head.requires_grad_(False)

    def forward(
        self,
        hs: TensorType["batch", "seq_len", "hidden"],
        states_ixs: Optional[TensorType["batch", "states_seq_len"]] = None,
        actions_ixs: Optional[TensorType["batch", "actions_seq_len"]] = None,
        **kwargs,
    ) -> Tuple[
        Tuple[TensorType["batch", "actions_seq_len", "hidden"]],
        Tuple[TensorType["batch", "actions_seq_len", "hidden"]],
        TensorType["batch", "states_seq_len", "hidden"],
    ]:
        """
        Apply the heads to the hidden states.

        When ``states_ixs`` is given, values are computed only at ``states_ixs`` and Q-values only
        at ``actions_ixs``. Otherwise both are computed at every position, and ``actions_ixs`` is
        ignored.

        :return: ``(qs, target_qs, vs)``, where ``qs`` and ``target_qs`` hold one tensor per Q-head.
        """
        if states_ixs is None:
            states_hs = actions_hs = hs
        else:
            states_hs = batched_index_select(hs, states_ixs, 1)
            actions_hs = batched_index_select(hs, actions_ixs, 1)

        qs = tuple(head(actions_hs) for head in self.q_heads)
        target_qs = tuple(head(actions_hs) for head in self.target_q_heads)
        vs = self.v_head(states_hs)
        return qs, target_qs, vs

    def _sync_target_q_heads(self, alpha):
        """Polyak update in place: ``target <- alpha * online + (1 - alpha) * target``."""
        for target_head, online_head in zip(self.target_q_heads, self.q_heads):
            for target_param, online_param in zip(target_head.parameters(), online_head.parameters()):
                blended = (alpha * online_param.data) + (1.0 - alpha) * target_param.data
                target_param.data.copy_(blended)

    def sync_target_q_heads(self):
        """
        Blend the Q-heads into the target Q-heads using ``self.alpha``.

        Under DeepSpeed ZeRO stage 3, detected via ``ACCELERATE_DEEPSPEED_ZERO_STAGE``, the
        parameters are first gathered and the update is performed by rank 0.
        """
        zero_stage_3 = os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE", "0") == "3"
        if not zero_stage_3:
            self._sync_target_q_heads(self.alpha)
            return

        with deepspeed.zero.GatheredParameters(list(self.parameters()), modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:
                self._sync_target_q_heads(self.alpha)
@dataclass
class CausalILQLOutput(ModelOutput):
    """
    Output of the causal model with ILQL heads.

    :param logits: Logits of the causal model.
    :type logits: torch.FloatTensor

    :param past_key_values: Tuple of past key values of the causal model.
    :type past_key_values: Tuple[Tuple[torch.FloatTensor]]

    :param hidden_states: Tuple of hidden states returned by the causal model (requested with
        ``output_hidden_states=True``). This is the whole tuple, not only the last hidden state.
        Its last entry is the tensor the ILQL heads are applied to.
    :type hidden_states: Tuple[torch.FloatTensor]

    :param value: Value function estimates, computed at every position of the input sequence, or
        only at ``states_ixs`` when those are given.
    :type value: torch.FloatTensor

    :param qs: Q-function estimates, one tensor per Q-head, computed at every position of the input
        sequence, or only at ``actions_ixs`` when ``states_ixs`` is given.
    :type qs: Tuple[torch.FloatTensor]

    :param target_qs: Q-function estimates from the target Q-heads, laid out like ``qs``.
    :type target_qs: Tuple[torch.FloatTensor]
    """

    # Field order matters: the model constructs this output positionally.
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    value: Optional[torch.FloatTensor] = None
    qs: Optional[Tuple[torch.FloatTensor]] = None
    target_qs: Optional[Tuple[torch.FloatTensor]] = None
class AutoModelForCausalLMWithILQLHeads(PreTrainedModelWrapper):
    """An `AutoModel` class wrapper for `transformers` causal models with a language modeling head and
    ILQL heads.

    References:
        [1] Snell et al., "Offline RL for Natural Language Generation with Implicit Language Q Learning",
            https://arxiv.org/abs/2206.11871, 2022
    """

    _auto_model_parent_class = transformers.AutoModelForCausalLM
    _supported_modules = ["ilql_heads"]
    _supported_args = ["two_qs", "alpha", "peft_config"]

    def __init__(
        self,
        base_model: transformers.PreTrainedModel,
        *,
        two_qs: bool = True,
        alpha: float = 0.99,
        peft_config=None,
    ):
        """
        :param base_model: Causal language model to wrap.
        :param two_qs: Use two Q-heads (taking the minimum of their target estimates) instead of one.
        :param alpha: Polyak-averaging coefficient used when syncing the target Q-heads.
        :param peft_config: Optional peft configuration forwarded to the base wrapper.
        """
        super().__init__(base_model, peft_config=peft_config)
        hidden_size = hf_get_hidden_size(self.base_model.config)
        vocab_size = self.base_model.config.vocab_size
        # The ILQL heads are created with the same dtype as the LM head.
        dtype = next(hf_get_lm_head(self.base_model).parameters()).dtype
        self.two_qs = two_qs
        self.alpha = alpha
        self.ilql_heads = ILQLHeads(hidden_size, vocab_size, self.two_qs, self.alpha, dtype=dtype)

    def forward(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        actions_ixs=None,
        states_ixs=None,
        return_dict=False,
        bypass_peft_prompt_adapter=False,
    ):
        """
        Run the base model and apply the ILQL heads to its last hidden state.

        :param actions_ixs: Positions at which Q-values are computed (used only with ``states_ixs``).
        :param states_ixs: Positions at which values are computed. When ``None``, values and
            Q-values are computed at every position.
        :param return_dict: Return a ``CausalILQLOutput`` instead of a tuple.
        :param bypass_peft_prompt_adapter: Call ``self.base_model.base_model`` directly instead of
            ``self.base_model``.
        :return: ``CausalILQLOutput`` if ``return_dict`` else
            ``(logits, qs, target_qs, vs, past_key_values)``.
        """
        forward_kwargs = self.get_compatible_forward_kwargs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )
        # The heads need the last hidden state.
        forward_kwargs["output_hidden_states"] = True

        if self.peft_type == "PREFIX_TUNING" and not bypass_peft_prompt_adapter:
            # Peft redefines past_key_values, remove it to avoid an exception.
            forward_kwargs.pop("past_key_values", None)

        if bypass_peft_prompt_adapter:
            outputs = self.base_model.base_model(**forward_kwargs)
        else:
            outputs = self.base_model(**forward_kwargs)

        qs, target_qs, vs = self.ilql_heads(outputs.hidden_states[-1], states_ixs=states_ixs, actions_ixs=actions_ixs)

        if return_dict:
            return CausalILQLOutput(outputs.logits, outputs.past_key_values, outputs.hidden_states, vs, qs, target_qs)

        return outputs.logits, qs, target_qs, vs, outputs.past_key_values

    def generate(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        beta=1,
        max_new_tokens=32,
        max_length=1024,
        temperature=1,
        top_k=20,
        logit_mask=None,
        pad_token_id=None,
        eos_token_id=None,
    ):
        """
        Generates samples akin to hf's `.generate` but with custom logp preprocessing: changing token probabilities
        as to how advantageous they would be according to value functions estimations.

        At each step the scores are ``log_softmax(logits) + beta * (Q - V)``, where Q comes from the
        target Q-heads (minimum over both when ``two_qs``). Scores are restricted to the ``top_k``
        best tokens and then argmax-ed (``temperature == 0``) or sampled from
        ``softmax(scores / temperature)``.

        :param logit_mask: Optional tensor indexed by each sequence's previous token. The tokens
            flagged in the selected row get a ``-inf`` logit.
        :param max_new_tokens: Maximum number of new tokens, further capped by
            ``max_length - input_ids.shape[1]``.
        :param pad_token_id: Defaults to ``base_model.config.pad_token_id``.
        :param eos_token_id: Defaults to ``base_model.config.eos_token_id``. Sequences that produced
            it keep emitting it.
        :return: ``input_ids`` with the generated tokens appended.
        """
        pad_token_id = pad_token_id if pad_token_id is not None else self.base_model.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.base_model.config.eos_token_id

        if attention_mask is None:
            attention_mask = input_ids.not_equal(pad_token_id)

        if position_ids is None:
            position_ids = attention_mask.cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask.eq(0), 0)

        samples = input_ids.clone()
        max_new_tokens = min(max_new_tokens, max_length - input_ids.shape[1])

        finished = torch.zeros(input_ids.shape[0], 1, dtype=torch.long, device=input_ids.device)
        bypass_peft = False
        for token in range(max_new_tokens):
            out = self.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                bypass_peft_prompt_adapter=bypass_peft,
            )

            logits, _, target_qs, vs, past_key_values = out
            if self.two_qs:
                qs = torch.minimum(target_qs[0][:, -1, :], target_qs[1][:, -1, :])
            else:
                qs = target_qs[0][:, -1, :]

            logits = logits[:, -1, :]
            vs = vs[:, -1, :]

            if logit_mask is not None:
                # input_ids[:, -1] is already 1-D (one previous token per sequence), so `mask`
                # keeps one row per sequence even for a batch of one. The former `.squeeze()`
                # collapsed it to 1-D at batch size 1, and `torch.where` then indexed the batch dim.
                mask = logit_mask[input_ids[:, -1].to(logit_mask.device)]
                logits[torch.where(mask)] = -np.inf

            adv = qs - vs
            pi_beta = F.log_softmax(logits, -1)
            pi_top_k = topk_mask(pi_beta + beta * adv, top_k)

            if temperature == 0.0:
                input_ids = pi_top_k.argmax(dim=-1, keepdim=True)
            else:
                pi = F.softmax(pi_top_k / temperature, -1)
                input_ids = torch.multinomial(pi, num_samples=1)

            # Sequences that already emitted EOS keep emitting it.
            input_ids = (1 - finished) * input_ids + finished * eos_token_id
            finished = (input_ids == eos_token_id).long()

            samples = torch.hstack((samples, input_ids))
            attention_mask = torch.hstack((attention_mask, (input_ids != eos_token_id).long()))
            position_ids = (position_ids[:, -1] + 1).view(-1, 1)

            # Some peft models add a prefix to the prompt at each forward pass.
            # We need to bypass it so that it doesn't add multiple times the prefix.
            if self.peft_type and token == 0 and "LORA" not in self.peft_type:
                bypass_peft = True
                prefix_attention_mask = torch.ones(input_ids.shape[0], self.peft_config.num_virtual_tokens).to(
                    input_ids.device
                )
                attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

            # NOTE(review): early exit is skipped under ZeRO-3, presumably so every rank keeps
            # running forward in lockstep — confirm.
            if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE", "0") != "3" and torch.all(finished):
                break

        return samples

    def sync_target_q_heads(self):
        """Polyak-update the target Q-heads from the Q-heads (delegates to ``ILQLHeads``)."""
        self.ilql_heads.sync_target_q_heads()

    def state_dict(self, *args, heads_only=False, **kwargs):
        """
        Returns the state dictionary of the model. We add the state dictionary of the ilql heads
        to the state dictionary of the wrapped model by prepending the key with `ilql_heads.`.

        :param heads_only: Return only the ILQL heads' entries.
        """
        state_dict = self.ilql_heads.state_dict(*args, **dict(prefix="ilql_heads.", **kwargs))
        if not heads_only:
            # Base-model keys are prefixed with "base_model." unless a peft model is wrapped.
            state_dict = {
                **state_dict,
                **self.base_model.state_dict(*args, **dict(prefix="" if self.peft_type else "base_model.", **kwargs)),
            }

        return state_dict

    def post_init(self, state_dict):
        """
        We add the state dictionary of the ilql heads to the state dictionary of the wrapped model
        by prepending the key with `ilql_heads.`.

        Loading is strict only without peft and when some key carries a ``base_model.`` or
        ``ilql_heads.`` prefix.
        """
        super().post_init()

        strict = not self.peft_type and any(
            k.startswith("base_model.") or k.startswith("ilql_heads.") for k in state_dict
        )
        self.load_state_dict(state_dict, strict=strict)
        del state_dict
        gc.collect()
@dataclass
class Seq2SeqILQLOutput(ModelOutput):
    """
    Output of the seq2seq model with ILQL heads.

    :param logits: Logits of the seq2seq model.
    :type logits: torch.FloatTensor

    :param past_key_values: Tuple of past key values of the seq2seq model.
    :type past_key_values: Tuple[Tuple[torch.FloatTensor]]

    :param hidden_states: Tuple of decoder hidden states returned by the seq2seq model. This is the
        whole tuple, not only the last hidden state. Its last entry is the tensor the LM head and
        the ILQL heads are applied to.
    :type hidden_states: Tuple[torch.FloatTensor]

    :param value: Value function estimates, computed at every decoder position, or only at
        ``states_ixs`` when those are given.
    :type value: torch.FloatTensor

    :param qs: Q-function estimates, one tensor per Q-head, computed at every decoder position, or
        only at ``actions_ixs`` when ``states_ixs`` is given.
    :type qs: Tuple[torch.FloatTensor]

    :param target_qs: Q-function estimates from the target Q-heads, laid out like ``qs``.
    :type target_qs: Tuple[torch.FloatTensor]

    :param encoder_outputs: Tuple ``(encoder_last_hidden_state, encoder_hidden_states,
        encoder_attentions)`` of the seq2seq model.
    :type encoder_outputs: Tuple[Any]
    """

    # Field order matters: the model constructs this output positionally.
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    value: Optional[torch.FloatTensor] = None
    qs: Optional[Tuple[torch.FloatTensor]] = None
    target_qs: Optional[Tuple[torch.FloatTensor]] = None
    encoder_outputs: Optional[Tuple[Any]] = None
class AutoModelForSeq2SeqLMWithILQLHeads(PreTrainedModelWrapper):
    """This is a wrapper around huggingface AutoModelForSeq2Seq with two additional scalar heads"""

    _auto_model_parent_class = transformers.AutoModelForSeq2SeqLM
    _supported_modules = ["ilql_heads"]
    _supported_args = ["two_qs", "alpha", "peft_config"]

    def __init__(
        self,
        base_model: transformers.PreTrainedModel,
        *,
        two_qs: bool = True,
        alpha: float = 0.99,
        peft_config=None,
    ):
        """
        :param base_model: Seq2seq language model to wrap.
        :param two_qs: Use two Q-heads (taking the minimum of their target estimates) instead of one.
        :param alpha: Polyak-averaging coefficient used when syncing the target Q-heads.
        :param peft_config: Optional peft configuration forwarded to the base wrapper.
        """
        super().__init__(base_model, peft_config=peft_config)
        hidden_size = hf_get_hidden_size(self.base_model.config)
        vocab_size = self.base_model.config.vocab_size
        # The ILQL heads are created with the same dtype as the LM head.
        dtype = next(hf_get_lm_head(self.base_model).parameters()).dtype
        self.two_qs = two_qs
        self.alpha = alpha
        self.ilql_heads = ILQLHeads(hidden_size, vocab_size, self.two_qs, self.alpha, dtype=dtype)

    def sync_target_q_heads(self):
        """Polyak-update the target Q-heads from the Q-heads (delegates to ``ILQLHeads``)."""
        self.ilql_heads.sync_target_q_heads()

    def state_dict(self, *args, heads_only=False, **kwargs):
        """
        Returns the state dictionary of the model. We add the state dictionary of the ilql heads
        to the state dictionary of the wrapped model by prepending the key with `ilql_heads.`.

        :param heads_only: Return only the ILQL heads' entries.
        """
        state_dict = self.ilql_heads.state_dict(*args, **dict(prefix="ilql_heads.", **kwargs))
        if not heads_only:
            # Base-model keys are prefixed with "base_model." unless a peft model is wrapped.
            state_dict = {
                **state_dict,
                **self.base_model.state_dict(*args, **dict(prefix="" if self.peft_type else "base_model.", **kwargs)),
            }
        return state_dict

    def post_init(self, state_dict):
        """
        We add the state dictionary of the ilql heads to the state dictionary of the wrapped model
        by prepending the key with `ilql_heads.`.

        Loading is strict only without peft and when some key carries a ``base_model.`` or
        ``ilql_heads.`` prefix.
        """
        super().post_init()

        strict = not self.peft_type and any(
            k.startswith("base_model.") or k.startswith("ilql_heads.") for k in state_dict
        )
        self.load_state_dict(state_dict, strict=strict)
        del state_dict
        gc.collect()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        decoder_attention_mask=None,
        decoder_input_ids=None,
        past_key_values=None,
        decoder_inputs_embeds=None,
        encoder_outputs=None,
        actions_ixs=None,
        states_ixs=None,
        output_attentions=True,
        output_hidden_states=True,
        return_dict=False,
        bypass_peft_prompt_adapter=False,
    ):
        """
        Run the seq2seq model and apply the LM head and ILQL heads to the last decoder hidden state.

        :param actions_ixs: Positions at which Q-values are computed (used only with ``states_ixs``).
        :param states_ixs: Positions at which values are computed. When ``None``, values and
            Q-values are computed at every decoder position.
        :param return_dict: Return a ``Seq2SeqILQLOutput`` instead of a tuple.
        :param bypass_peft_prompt_adapter: Call ``self.base_model.base_model`` directly instead of
            ``self.base_model``.
        :return: ``Seq2SeqILQLOutput`` if ``return_dict`` else
            ``(logits, qs, target_qs, vs, past_key_values, encoder_outputs)``, where
            ``encoder_outputs`` is ``(encoder_last_hidden_state, encoder_hidden_states, encoder_attentions)``.
        """
        forward_kwargs = self.get_compatible_forward_kwargs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            decoder_input_ids=decoder_input_ids,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            encoder_outputs=encoder_outputs,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        if self.peft_type == "PREFIX_TUNING" and not bypass_peft_prompt_adapter:
            # Peft redefines past_key_values, remove it to avoid an exception.
            forward_kwargs.pop("past_key_values", None)

        if bypass_peft_prompt_adapter:
            out = self.base_model.base_model(**forward_kwargs)
        else:
            out = self.base_model(**forward_kwargs)
        hs = out.decoder_hidden_states[-1]

        # NOTE(review): logits are recomputed from the last decoder hidden state via lm_head rather
        # than read from out.logits. Models that transform the hidden state before lm_head or add a
        # logits bias would produce different values — confirm this is intended.
        logits = self.base_model.lm_head(hs)
        qs, target_qs, vs = self.ilql_heads(hs, states_ixs=states_ixs, actions_ixs=actions_ixs)

        encoder_outputs = (out.encoder_last_hidden_state, out.encoder_hidden_states, out.encoder_attentions)
        if return_dict:
            return Seq2SeqILQLOutput(
                logits, out.past_key_values, out.decoder_hidden_states, vs, qs, target_qs, encoder_outputs
            )

        return logits, qs, target_qs, vs, out.past_key_values, encoder_outputs

    def generate(
        self,
        input_ids,
        attention_mask=None,
        decoder_attention_mask=None,
        decoder_input_ids=None,
        past_key_values=None,
        encoder_outputs=None,
        beta=1,
        max_new_tokens=32,
        max_length=1024,
        temperature=1,
        top_k=20,
        logit_mask=None,
        pad_token_id=None,
        eos_token_id=None,
    ):
        """
        Generates samples akin to hf's `.generate` but with custom logp preprocessing: changing token probabilities
        as to how advantageous they would be according to value functions estimations.

        At each step the scores are ``log_softmax(logits) + beta * (Q - V)``, where Q comes from the
        target Q-heads (minimum over both when ``two_qs``). Scores are restricted to the ``top_k``
        best tokens and then argmax-ed (``temperature == 0``) or sampled from
        ``softmax(scores / temperature)``.

        :param decoder_input_ids: Initial decoder tokens. Defaults to a single ``0`` per sequence.
        :param logit_mask: Optional tensor indexed by each sequence's previous decoder token. The
            tokens flagged in the selected row get a ``-inf`` logit. Previously this argument was
            accepted but ignored.
        :param max_new_tokens: Maximum number of new tokens, further capped by
            ``max_length - input_ids.shape[1]``.
        :raises ValueError: If ``eos_token_id`` or ``pad_token_id`` is missing.
        :return: The decoder token ids, including the initial ones. If no step runs, a copy of
            ``input_ids`` is returned instead.
        """
        if eos_token_id is None or pad_token_id is None:
            raise ValueError("eos_token_id and pad_token_id must be provided")

        if attention_mask is None:
            if decoder_attention_mask is not None:
                attention_mask = decoder_attention_mask
            else:
                attention_mask = input_ids.not_equal(pad_token_id)

        samples = input_ids.clone()
        max_new_tokens = min(max_new_tokens, max_length - input_ids.shape[1])
        if decoder_input_ids is None:
            decoder_input_ids = input_ids.new_zeros(input_ids.shape[0], 1)

        finished = torch.zeros(input_ids.shape[0], 1, dtype=torch.long, device=input_ids.device)
        bypass_peft = False
        for token in range(max_new_tokens):
            out = self.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids[:, -1].unsqueeze(-1),
                past_key_values=past_key_values,
                encoder_outputs=encoder_outputs,
                bypass_peft_prompt_adapter=bypass_peft,
            )
            logits, _, target_qs, vs, past_key_values, encoder_outputs = out
            if self.two_qs:
                qs = torch.minimum(target_qs[0][:, -1, :], target_qs[1][:, -1, :])
            else:
                qs = target_qs[0][:, -1, :]

            logits = logits[:, -1, :]
            vs = vs[:, -1, :]

            if logit_mask is not None:
                # One mask row per sequence, selected by its previous decoder token (1-D index, so
                # `mask` keeps one row per sequence even for a batch of one).
                mask = logit_mask[decoder_input_ids[:, -1].to(logit_mask.device)]
                logits[torch.where(mask)] = -np.inf

            adv = qs - vs
            pi_beta = F.log_softmax(logits, -1)
            pi_top_k = topk_mask(pi_beta + beta * adv, top_k)

            if temperature == 0.0:
                next_tokens = pi_top_k.argmax(dim=-1, keepdim=True)
            else:
                pi = F.softmax(pi_top_k / temperature, -1)
                next_tokens = torch.multinomial(pi, num_samples=1)

            # Finished sequences keep emitting EOS; a sequence finishes on EOS or PAD.
            next_tokens = (1 - finished) * next_tokens + finished * eos_token_id
            finished = (next_tokens == eos_token_id).long() | (next_tokens == pad_token_id).long()
            decoder_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1)
            samples = decoder_input_ids

            # Some peft models add a prefix to the prompt at each forward pass.
            # We need to bypass it so that it doesn't add multiple times the prefix.
            if self.peft_type and token == 0 and "LORA" not in self.peft_type:
                bypass_peft = True
                prefix_attention_mask = torch.ones(input_ids.shape[0], self.peft_config.num_virtual_tokens).to(
                    input_ids.device
                )
                attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

            # NOTE(review): early exit is skipped under ZeRO-3, presumably so every rank keeps
            # running forward in lockstep — confirm.
            if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE", "0") != "3" and torch.all(finished):
                break

        return samples