Source code for trlx.trainer.accelerate_sft_trainer

from dataclasses import dataclass

from transformers import AutoModelForCausalLM, PretrainedConfig

from trlx.data.configs import TRLConfig
from trlx.data.method_configs import MethodConfig, register_method
from trlx.pipeline.offline_pipeline import (
    DialogStore,
    PromptPipeline,
    tokenize_dialogue,
)
from trlx.trainer import register_trainer
from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer


@dataclass
@register_method
class SFTConfig(MethodConfig):
    """
    Config for supervised fine-tuning (SFT)

    :param gen_kwargs: keyword arguments for generation
    :type gen_kwargs: Dict[str, Any]
    """

    gen_kwargs: dict
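
# Illustrative sketch, not part of the original module: `gen_kwargs` holds Hugging Face
# generation arguments, which `AccelerateSFTTrainer.__init__` below merges with the
# tokenizer's eos/pad token ids. The `name` field is inherited from `MethodConfig`; the
# particular values here are assumptions for demonstration, not recommended defaults.
example_sft_method = SFTConfig(
    name="SFTConfig",
    gen_kwargs={"max_new_tokens": 128, "do_sample": True, "top_p": 0.95},
)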


@register_trainer
class AccelerateSFTTrainer(AccelerateRLTrainer):
    def __init__(self, config: TRLConfig, **kwargs):
        super().__init__(config, **kwargs)

        self.generate_kwargs = dict(
            config.method.gen_kwargs,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
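
    # Note on `__init__` above: `dict(mapping, **overrides)` copies `config.method.gen_kwargs`
    # and adds the eos/pad token ids on top, e.g.
    # dict({"max_new_tokens": 64}, eos_token_id=2) == {"max_new_tokens": 64, "eos_token_id": 2}
    # (values made up for illustration).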
    def get_arch(self, config):
        from_fn = AutoModelForCausalLM.from_pretrained
        if issubclass(type(config.model.model_path), PretrainedConfig):
            from_fn = AutoModelForCausalLM.from_config

        model = from_fn(config.model.model_path, **config.model.model_extra_configs)

        if config.model.peft_config is not None:
            # Initialize the peft adapter
            import peft

            peft_config = config.model.peft_config
            if not isinstance(peft_config, peft.PeftConfig):
                if isinstance(peft_config, dict):
                    peft_config = peft.get_peft_config(peft_config)
                else:
                    raise ValueError("`peft_config` should be an instance of `peft.PeftConfig` or a dict.")
            model = peft.get_peft_model(model, peft_config)
            if self.accelerator.is_main_process:
                model.print_trainable_parameters()

        return model
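
    # Illustrative sketch for `get_arch` above (comment only; values are assumptions, not
    # trlx defaults): when `peft_config` is given as a plain dict it is converted with
    # `peft.get_peft_config`, which expects a `peft_type` key identifying the adapter type.
    # A LoRA setup for a causal LM could look like:
    #
    #     peft_config = {
    #         "peft_type": "LORA",
    #         "task_type": "CAUSAL_LM",
    #         "r": 8,
    #         "lora_alpha": 32,
    #         "lora_dropout": 0.05,
    #     }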
    def loss(self, batch):
        if "labels" in batch:
            labels = batch.labels.clone()
        else:
            labels = batch.input_ids.clone()
            labels[~batch.attention_mask.bool()] = -100

        loss = self.model(input_ids=batch.input_ids, attention_mask=batch.attention_mask, labels=labels).loss
        stats = {"loss": loss.item()}

        return loss, stats
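
    # Illustrative sketch for `loss` above (comment only; tensors are made up): padding
    # positions, where the attention mask is 0, get label -100 so the Hugging Face
    # cross-entropy loss ignores them:
    #
    #     input_ids      = [[5, 6, 7, 0, 0]]
    #     attention_mask = [[1, 1, 1, 0, 0]]
    #     labels after masking -> [[5, 6, 7, -100, -100]]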
    def create_train_dataloader(self):
        return self.accelerator.prepare(self.store.create_loader(self.config.train.batch_size))
    def prepare_learning(self):
        self.train_dataloader = self.create_train_dataloader()

        eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size)
        (
            self.model,
            self.opt,
            self.eval_dataloader,
        ) = self.accelerator.prepare(self.model, self.opt, eval_dataloader)

        self.n_inner_epochs = 1
        self.total_steps = self.config.train.epochs * len(self.train_dataloader)
        self.total_steps = min(self.total_steps, self.config.train.total_steps)
    def make_experience(self, samples, seq_length):
        if isinstance(samples[0], str):
            self.store = PromptPipeline(samples, seq_length, self.tokenizer)
        else:
            dialogs = [tokenize_dialogue(d, self.tokenizer, seq_length) for d in samples]
            self.store = DialogStore(dialogs, self.tokenizer)
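
# Illustrative sketch, not part of the original module: `make_experience` accepts either
# plain strings (tokenized by `PromptPipeline`) or dialogues given as alternating
# prompt/response segments (tokenized per turn by `tokenize_dialogue` and stored in a
# `DialogStore`). The sample contents below are made-up examples for demonstration only.
example_string_samples = [
    "Question: What is 2 + 2?\nAnswer: 4",
]
example_dialogue_samples = [
    ["Question: What is 2 + 2?\nAnswer:", " 4"],
]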