Orchestrators

Orchestrators read data from a pipeline and create RL data elements (i.e. trlx.data.RLElement) to push to a model's rollout storage. Use the trlx.orchestrator.register_orchestrator decorator when creating new orchestrators.

General

class trlx.orchestrator.Orchestrator(pipeline: trlx.pipeline.BasePipeline, rl_model: trlx.model.BaseRLModel)
abstract make_experience()

Draw from the pipeline, get an action, generate a reward, and push the result to the model's RolloutStorage.
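The abstract contract can be illustrated with a small stdlib-only sketch. Orchestrator below is a simplified stand-in for the trlx base class. ToyModel, act, push_to_store and ToyOrchestrator are hypothetical names, and a list stands in for the model's RolloutStorage. Real trlx models expose their own storage API, which is not reproduced here.

```python
from abc import ABC, abstractmethod
from typing import Callable, Iterable, List, Tuple


class Orchestrator(ABC):
    """Sketch of the contract: pull from a pipeline, act, score, store."""

    def __init__(self, pipeline: Iterable[str], rl_model):
        self.pipeline = pipeline
        self.rl_model = rl_model

    @abstractmethod
    def make_experience(self) -> None:
        """Fill the model's rollout storage with new experience."""


class ToyModel:
    """Hypothetical model with a trivial policy and a list-backed store."""

    def __init__(self):
        self.store: List[Tuple[str, str, float]] = []

    def act(self, prompt: str) -> str:
        return prompt + " ok"  # stand-in for sampling a continuation

    def push_to_store(self, elements) -> None:
        self.store.extend(elements)


class ToyOrchestrator(Orchestrator):
    def __init__(self, pipeline, rl_model, reward_fn: Callable[[str], float]):
        super().__init__(pipeline, rl_model)
        self.reward_fn = reward_fn

    def make_experience(self) -> None:
        elements = []
        for prompt in self.pipeline:              # draw from the pipeline
            action = self.rl_model.act(prompt)    # get an action
            reward = self.reward_fn(action)       # generate a reward
            elements.append((prompt, action, reward))
        self.rl_model.push_to_store(elements)     # push to rollout storage


model = ToyModel()
ToyOrchestrator(["hi", "yo"], model, reward_fn=len).make_experience()
print(model.store)  # [('hi', 'hi ok', 5), ('yo', 'yo ok', 5)]
```

Declaring make_experience with abstractmethod means the base class cannot be instantiated until a subclass supplies its own experience-collection loop.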

PPO

class trlx.orchestrator.ppo_orchestrator.PPOOrchestrator(model: trlx.model.BaseRLModel, pipeline: trlx.pipeline.BasePipeline, reward_fn: Callable, metric_fn: Optional[Callable] = None, chunk_size: int = 512)

Orchestrator that prepares data for PPO training: transforms samples from the pipeline into a PPOBatch and pushes them into the model's store.

make_experience(num_rollouts: int = 1024, iter_count: int = 0)

Takes num_rollouts prompts from the pipeline, samples the model, computes the KL divergence against a reference model, and appends PPOElements to the model's store.
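The KL term is commonly folded into the per-token rewards that PPO optimizes. The following is a sketch of that widely used KL-penalized reward, not a transcription of trlx's code. Each generated token is penalized by kl_coef times the log-ratio between the policy and reference log-probabilities, and the scalar score from the reward function is added to the final token. kl_penalized_rewards and kl_coef are hypothetical names, and plain lists stand in for tensors.

```python
from typing import List


def kl_penalized_rewards(
    logprobs: List[float],
    ref_logprobs: List[float],
    score: float,
    kl_coef: float,
) -> List[float]:
    """Per-token rewards for one sampled continuation.

    log pi(a|s) - log pi_ref(a|s) is a single-sample estimate of the KL
    divergence from the reference model; each token is penalized by
    kl_coef times that log-ratio, and the reward-function score is added
    to the last token only.
    """
    if len(logprobs) != len(ref_logprobs):
        raise ValueError("logprobs and ref_logprobs must have the same length")
    if not logprobs:
        raise ValueError("need at least one generated token")
    rewards = [-kl_coef * (lp - ref) for lp, ref in zip(logprobs, ref_logprobs)]
    rewards[-1] += score
    return rewards


rewards = kl_penalized_rewards(
    [-1.0, -2.0, -0.5], [-1.5, -2.5, -1.0], score=1.0, kl_coef=0.1
)
# roughly [-0.05, -0.05, 0.95]: a 0.5 log-ratio per token, score on the last
```

Penalizing the log-ratio keeps the fine-tuned policy from drifting far from the reference model while it chases the scalar score.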

score(samples)

Batched scoring function that takes a batch of text samples and returns a scalar score for each.
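A reward_fn passed to the constructor is assumed here to follow the same batched contract: a list of strings in, one float per string out, in the same order. The toy reward below scores the fraction of words equal to "good". It is purely illustrative and only shows the shape such a function might take.

```python
from typing import List


def reward_fn(samples: List[str]) -> List[float]:
    """Toy batched reward: fraction of words in each sample that are 'good'.

    Returns exactly one float per input text, in input order, so a whole
    batch of generations can be scored in a single call.
    """
    scores = []
    for text in samples:
        words = text.lower().split()
        scores.append(words.count("good") / len(words) if words else 0.0)
    return scores


print(reward_fn(["good good day", "bad day", ""]))  # [0.6666666666666666, 0.0, 0.0]
```

Scoring in batches lets an expensive reward model, such as a classifier, process many samples per forward pass instead of one at a time.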

ILQL

class trlx.orchestrator.offline_orchestrator.OfflineOrchestrator(model, split_token=None)

Orchestrator that creates a static dataset for offline training.

make_experience(samples, rewards)

Tokenizes samples, shapes rewards into the proper tensors, and then inserts the resulting dataset into the model.
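The tokenize-and-shape step can be sketched without any ML dependencies, under several assumptions that are not taken from trlx:

- Whitespace splitting stands in for a real tokenizer.
- Rewards are z-score normalized across the dataset.
- Each scalar reward is placed on the final action token, with zeros elsewhere.
- When split_token is given, the text up to and including it is treated as the prompt, and only the following tokens count as actions.

make_offline_dataset is a hypothetical name. It takes split_token as an argument rather than from a constructor, and it returns lists instead of tensors.

```python
from typing import List, Optional, Tuple

Example = Tuple[List[str], List[int], List[float]]


def make_offline_dataset(
    samples: List[str],
    rewards: List[float],
    split_token: Optional[str] = None,
) -> List[Example]:
    """Turn (text, scalar reward) pairs into per-token training examples.

    Each example holds the sample's tokens, the indices of the tokens
    treated as actions, and one reward per action: the normalized scalar
    reward on the final action and zeros elsewhere.
    """
    if len(samples) != len(rewards):
        raise ValueError("samples and rewards must have the same length")
    if not samples:
        return []

    # Assumption: z-score normalize rewards across the whole dataset.
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = var ** 0.5 or 1.0  # fall back to 1.0 when all rewards are equal

    dataset: List[Example] = []
    for text, reward in zip(samples, rewards):
        if split_token is not None and split_token in text:
            # Prompt = everything up to and including split_token.
            prompt, continuation = text.split(split_token, 1)
            prompt_tokens = (prompt + split_token).split()
            tokens = prompt_tokens + continuation.split()
            first_action = len(prompt_tokens)
        else:
            # Assumption: with no split, every token counts as an action.
            tokens = text.split()
            first_action = 0
        actions = list(range(first_action, len(tokens)))
        if not actions:
            raise ValueError(f"sample has no action tokens: {text!r}")
        normalized = (reward - mean) / std
        action_rewards = [0.0] * (len(actions) - 1) + [normalized]
        dataset.append((tokens, actions, action_rewards))
    return dataset


data = make_offline_dataset(
    ["Q: hi A: hello there", "Q: hi A: go away"], [1.0, -1.0], split_token="A:"
)
print(data[0])  # (['Q:', 'hi', 'A:', 'hello', 'there'], [3, 4], [0.0, 1.0])
```

Placing the whole reward on the last action models the sequence-level score as a terminal reward. An offline method can then propagate it backward through its value estimates.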