# Source code for trlx.model.nn.ilql_models

import os
from collections import defaultdict
from copy import deepcopy
from itertools import chain
from typing import Union

import deepspeed
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from torch import nn
from transformers import AutoModelForCausalLM, PretrainedConfig

import wandb

def topk_mask(xs: torch.FloatTensor, k: int):
    """Keep the ``k`` largest entries along the last axis, set the rest to -inf.

    Works for tensors of any rank; the top-k selection is always taken over
    the last dimension. Entries that tie with the k-th largest value are kept.

    Args:
        xs: Tensor of scores.
        k: Number of entries to keep per slice of the last axis.

    Returns:
        ``xs`` itself (not a copy) when ``k`` exceeds the size of the last
        axis, otherwise a new tensor with every entry strictly below the
        k-th largest value replaced by ``-inf``.
    """
    if k > xs.shape[-1]:
        return xs
    # Smallest of the k largest values; the trailing slice (rather than an
    # integer index) keeps the last axis so the comparison broadcasts for
    # inputs of any rank.
    mintop = torch.topk(xs, k)[0][..., -1:]
    return torch.where(xs < mintop, torch.full_like(xs, -np.inf), xs)

def make_head(n_embd: int, out: int):
    """Build a two-layer MLP head: Linear -> ReLU -> Linear.

    Args:
        n_embd: Size of the input features; the hidden layer is twice as wide.
        out: Size of the output features.

    Returns:
        An ``nn.Sequential`` mapping ``(..., n_embd)`` to ``(..., out)``.
    """
    return nn.Sequential(
        nn.Linear(n_embd, n_embd * 2), nn.ReLU(), nn.Linear(n_embd * 2, out)
    )

[docs]class CausalLMWithValueHeads(nn.Module): """This is a wrapper around huggingface AutoModelForCausalLM with two additional scalar heads""" def __init__( self, config: Union[PretrainedConfig, str], params, num_layers_unfrozen=-1 ): super().__init__() # enable zero3 init within from_pretrained if os.environ.get("DEEPSPEED_ZERO_STAGE", "0") == "3": config_path = os.environ.get("DEEPSPEED_CONFIG_FILE", "") if config_path: _hfconfig = transformers.deepspeed.HfDeepSpeedConfig( # noqa: F841 config_path ) if isinstance(config, PretrainedConfig): self.gpt = AutoModelForCausalLM.from_config(config) else: self.gpt = AutoModelForCausalLM.from_pretrained(config) if hasattr(self.gpt, "gpt_neox"): self.gpt.transformer = self.gpt.gpt_neox self.gpt.lm_head = self.gpt_embed_out self.n_embd = self.gpt.config.hidden_size gpt_blocks = self.gpt.gpt_neox.layers else: self.n_embd = self.gpt.config.n_embd gpt_blocks = self.gpt.transformer.h if num_layers_unfrozen == 0: gpt_blocks_to_freeze = list(gpt_blocks) elif num_layers_unfrozen > 0: gpt_blocks_to_freeze = list(gpt_blocks)[:-num_layers_unfrozen] else: gpt_blocks_to_freeze = [] for m in gpt_blocks_to_freeze: m.requires_grad_(False) self.vocab_size = self.gpt.config.vocab_size self.v_head = make_head(self.n_embd, 1) self.q1_head = make_head(self.n_embd, self.vocab_size) self.target_q1_head = deepcopy(self.q1_head) self.target_q1_head.requires_grad_(False) self.tau = params.tau self.alpha = params.alpha self.gamma = params.gamma self.awac_scale = params.awac_scale self.cql_scale = params.cql_scale self.two_qs = params.two_qs if self.two_qs: self.q2_head = make_head(self.n_embd, self.vocab_size) self.target_q2_head = deepcopy(self.q2_head) self.target_q2_head.requires_grad_(False)
[docs] def forward( self, input_ids, attention_mask=None, position_ids=None, past_key_values=None, actions_ixs=None, states_ixs=None, ): out = self.gpt.transformer( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, ) hs = out.last_hidden_state if states_ixs is not None: states_hs = hs.gather( dim=1, index=states_ixs.unsqueeze(-1).repeat(1, 1, hs.shape[-1]) ) actions_hs = hs.gather( dim=1, index=actions_ixs.unsqueeze(-1).repeat(1, 1, hs.shape[-1]) ) else: states_hs = actions_hs = hs if self.two_qs: qs = (self.q1_head(actions_hs), self.q2_head(actions_hs)) target_qs = ( self.target_q1_head(actions_hs), self.target_q2_head(actions_hs), ) else: qs = self.q1_head(actions_hs) target_qs = self.target_q1_head(actions_hs) logits = self.gpt.lm_head(hs) vs = self.v_head(states_hs) return logits, qs, target_qs, vs, out.past_key_values
def _sync_target_q_heads(self, alpha): for target_param, copy_param in zip( self.target_q1_head.parameters(), self.q1_head.parameters() ): (alpha * + (1.0 - alpha) * ) if self.two_qs: for target_param, copy_param in zip( self.target_q2_head.parameters(), self.q2_head.parameters() ): (alpha * + (1.0 - alpha) * ) def sync_target_q_heads(self): if os.environ.get("DEEPSPEED_ZERO_STAGE", "0") == "3": params = chain( self.q1_head.parameters(), self.target_q1_head.parameters(), self.q2_head.parameters() if self.two_qs else [], self.target_q2_head.parameters() if self.two_qs else [], ) with, modifier_rank=0): if deepspeed.comm.get_rank() == 0: self._sync_target_q_heads(self.alpha) else: self._sync_target_q_heads(self.alpha)
[docs] def generate( self, input_ids, attention_mask=None, position_ids=None, past_key_values=None, beta=1, max_length=32, temperature=1, top_k=20, logit_mask=None, logs=True, pad_token_id=50256, eos_token_id=50256, ): """ Generates samples akin to hf's `.generate` but with custom logp prepossessing: changing token probabilities as to how advantageous they would be according to value functions estimations. """ if attention_mask is None: attention_mask = input_ids.not_equal(pad_token_id) if position_ids is None: position_ids = attention_mask.cumsum(-1) - 1 position_ids.masked_fill_(attention_mask.eq(0), 0) samples = input_ids.clone() tensors = defaultdict(list) n_new_tokens = max_length - input_ids.shape[1] finished = torch.zeros( input_ids.shape[0], 1, dtype=torch.long, device=input_ids.device ) for _ in range(n_new_tokens): out = self.forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, ) logits, _, target_qs, vs, past_key_values = out if self.two_qs: qs = torch.minimum(target_qs[0][:, -1, :], target_qs[1][:, -1, :]) else: qs = target_qs[:, -1, :] logits = logits[:, -1, :] vs = vs[:, -1, :] if logit_mask is not None: logits[torch.where(logit_mask[input_ids[:, -1].squeeze()])] = -np.inf adv = qs - vs pi_beta = F.log_softmax(logits, -1) pi_top_k = topk_mask(pi_beta + beta * adv, top_k) pi = F.softmax(pi_top_k / temperature, -1) input_ids = torch.multinomial(pi, num_samples=1) input_ids = (1 - finished) * input_ids + finished * eos_token_id finished = (input_ids == eos_token_id).long() samples = torch.hstack((samples, input_ids)) attention_mask = torch.hstack( (attention_mask, (input_ids != eos_token_id).long()) ) position_ids = (position_ids[:, -1] + 1).view(-1, 1) if logs: tensors["qs"].append(qs) tensors["vs"].append(vs) tensors["adv"].append(adv) tensors["pi"].append(pi) if torch.all(finished): break stats = {} for name, xs in tensors.items(): xs = torch.vstack(xs) xs = 
torch.where(torch.isfinite(xs), xs, 0) stats.update( { f"tensors/{name}/{beta}": wandb.Histogram( xs.cpu().float().view(-1) ), } ) return samples, stats
@property def dummy_inputs(self): return {"input_ids": torch.ones(1, 1, device=self.gpt.device, dtype=torch.long)} @property def device(self): return self.gpt.device