# v0.53.11 — GRPO variants live, PRM live, LongLoRA live, weighted preference combine live
The capstone of the v0.50 / v0.49 / v0.40 deferred-stub debt. Five deep TRL trainer-subclassing items lifted.
## Six GRPO objective variants live
`apply_variant_loss` ships real math kernels for every non-standard variant (one is sketched after the list):

- `gspo` — group-stabilised importance ratio
- `dapo` — decoupled asymmetric clip
- `dr_grpo` — no length normalization
- `bnpo` — length-normalised PPO
- `two_sided` — symmetric `grpo_delta` clip
- `rft` — rejection-sampling fine-tuning
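As an illustration of what these kernels look like, here is a minimal sketch of a `dr_grpo`-style objective, assuming token-level importance ratios, group-relative advantages, and a completion mask are already computed. The tensor names and the `max_completion_length` constant are illustrative, not the shipped signature:

```python
import torch

def dr_grpo_loss(
    ratio: torch.Tensor,        # (batch, seq) token-level importance ratios
    advantages: torch.Tensor,   # (batch, seq) group-relative advantages
    mask: torch.Tensor,         # (batch, seq) 1 for completion tokens, 0 for padding
    clip_eps: float = 0.2,
    max_completion_length: int = 1024,
) -> torch.Tensor:
    # Standard clipped PPO surrogate, per token
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    per_token = -torch.minimum(unclipped, clipped) * mask
    # Dr. GRPO: sum token losses and divide by a constant budget instead of
    # each sequence's length, removing the length-normalization bias
    return per_token.sum(dim=-1).mean() / max_completion_length
```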
Subclassing is done via `make_grpo_trainer_variant` (an `lru_cache` factory over `_GRPOTrainerVariant`) that overrides `compute_loss` to route through the kernel, with a defensive fallback to the original loss if TRL renames input attributes.
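A sketch of that factory pattern, assuming TRL's `GRPOTrainer` as the base class; `apply_variant_loss` is the kernel dispatcher named above, but its exact signature here is an assumption:

```python
from functools import lru_cache
from trl import GRPOTrainer

@lru_cache(maxsize=None)
def make_grpo_trainer_variant(variant: str) -> type:
    # Cached factory: the same variant string always yields the same subclass
    class _GRPOTrainerVariant(GRPOTrainer):
        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            try:
                # Route the loss through the variant's math kernel
                # (signature assumed for illustration)
                return apply_variant_loss(variant, self, model, inputs)
            except (AttributeError, KeyError):
                # Defensive fallback: if TRL renames the input attributes the
                # kernel reads, fall back to the stock GRPO loss
                return super().compute_loss(
                    model, inputs, return_outputs=return_outputs, **kwargs
                )

    _GRPOTrainerVariant.__name__ = f"GRPOTrainer_{variant}"
    return _GRPOTrainerVariant
```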
## `task: prm` Process Reward Model — live
`PRMTrainerWrapper` loads an `AutoModelForCausalLM` plus an `nn.Linear(hidden, 1)` reward head. The `make_prm_trainer_class` factory subclasses the HF `Trainer` and overrides `compute_loss` to:
1. Gather hidden states at step-boundary tokens
2. Project to scalars through the linear head
3. MSE-loss against per-step labels
`_prepare_prm_dataset` tokenizes `{prompt, completions, labels}` with reserved-token truncation; `_build_collator` pads with a `-1` sentinel for missing step positions.
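A condensed sketch of the three-step override, assuming the collator provides `step_positions` indices alongside the `-1`-padded `labels`, and that the wrapper exposes the linear head as `self.reward_head` (field and attribute names are illustrative):

```python
import torch
import torch.nn.functional as F

def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
    step_positions = inputs.pop("step_positions")   # (batch, max_steps)
    labels = inputs.pop("labels").float()           # (batch, max_steps), -1 = missing
    outputs = model(**inputs, output_hidden_states=True)
    hidden = outputs.hidden_states[-1]              # (batch, seq, hidden)

    # 1. Gather hidden states at step-boundary tokens
    #    (clamp keeps sentinel positions gather-safe; they are masked below)
    idx = step_positions.clamp(min=0).unsqueeze(-1).expand(-1, -1, hidden.size(-1))
    step_hidden = hidden.gather(1, idx)             # (batch, max_steps, hidden)

    # 2. Project to scalars through the linear reward head
    scores = self.reward_head(step_hidden).squeeze(-1)  # (batch, max_steps)

    # 3. MSE against per-step labels, ignoring -1 sentinel positions
    valid = labels.ne(-1)
    loss = F.mse_loss(scores[valid], labels[valid])
    return (loss, outputs) if return_outputs else loss
```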
```yaml
task: prm
base: deepseek-ai/deepseek-math-7b-base
data:
  train: prm_steps.jsonl
  format: prm  # {prompt, completions, labels}
```

## `GRPOStabilityCallback` live
- EMA reference-model update fires post-step: `(1 - α) * ref + α * policy` per parameter (`load_state_dict(strict=False)` so LoRA state stays valid) — sketched after this list
- Bounded-deque replay buffer
- TIS (truncated importance sampling) alert counter
- `ema_alpha` surfaced via `state.log_history` so `soup why` can flag instability
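A minimal sketch of the EMA reference update, assuming the callback holds a reference to the frozen reference model and that state-dict keys line up between the two models (constructor and attribute names are illustrative):

```python
import torch
from transformers import TrainerCallback

class GRPOStabilityCallback(TrainerCallback):
    def __init__(self, ref_model, ema_alpha: float = 0.01):
        self.ref_model = ref_model
        self.ema_alpha = ema_alpha

    def on_step_end(self, args, state, control, model=None, **kwargs):
        # Per-parameter EMA: ref <- (1 - α) * ref + α * policy
        with torch.no_grad():
            policy_state = model.state_dict()
            ema_state = {
                name: (1 - self.ema_alpha) * param
                      + self.ema_alpha * policy_state[name]
                for name, param in self.ref_model.state_dict().items()
                if name in policy_state
            }
        # strict=False so a LoRA-only policy state dict stays valid
        self.ref_model.load_state_dict(ema_state, strict=False)
```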
## LongLoRA S² forward override live
The v0.49.0 schema is now backed by a runtime. `shift_heads_for_s2` rolls the second half of the attention heads by `group_size // 2` along the sequence axis (LongLoRA paper §3.2). `LongLoRAForwardOverride` is a context manager that monkey-patches every Llama / Mistral / Qwen / Phi attention module's `forward` and restores it on exit. Idempotent `__del__` cleanup and best-effort exception swallowing ensure training never crashes from a shape mismatch.
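The core of the shift is small; a sketch assuming a `(batch, num_heads, seq_len, head_dim)` layout (the roll direction is an implementation detail):

```python
import torch

def shift_heads_for_s2(states: torch.Tensor, group_size: int) -> torch.Tensor:
    # Roll the second half of the attention heads by half a group along the
    # sequence axis, so neighbouring groups exchange information (S² attention)
    num_heads = states.shape[1]
    shifted = states.clone()
    shifted[:, num_heads // 2:] = torch.roll(
        states[:, num_heads // 2:], shifts=-(group_size // 2), dims=2
    )
    return shifted
```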
## True weighted-sum preference combine
`attach_weighted_preference_combine` now reads the four logprob tensors from TRL inputs (`policy_chosen_logps` / `policy_rejected_logps` / `reference_chosen_logps` / `reference_rejected_logps`) and computes each requested loss via the matching kernel:

- `compute_dpo_term`
- `compute_ipo_term`
- `compute_simpo_term`
- `compute_orpo_term`
The terms are then combined via `combine_losses(terms, weights)`. BCO mixed with paired losses is still rejected at validation time, and a defence-in-depth fallback to the v0.40.1 primary-loss scaling covers the case where TRL renames the logps attributes.
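A sketch of the combine step, with the DPO term spelled out as an example of the kernel shape (the `beta` default and exact signatures here are assumptions; the other kernels follow the same pattern):

```python
import torch
import torch.nn.functional as F

def compute_dpo_term(policy_chosen_logps, policy_rejected_logps,
                     reference_chosen_logps, reference_rejected_logps,
                     beta: float = 0.1) -> torch.Tensor:
    # Standard DPO loss: -log sigmoid of the beta-scaled reward margin
    logits = (policy_chosen_logps - policy_rejected_logps) - (
        reference_chosen_logps - reference_rejected_logps
    )
    return -F.logsigmoid(beta * logits).mean()

def combine_losses(terms: dict, weights: dict) -> torch.Tensor:
    # Weighted sum over the requested preference losses
    return sum(weights[name] * loss for name, loss in terms.items())
```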
```yaml
task: preference
training:
  preference_loss_weights:
    dpo: 0.7
    simpo: 0.3
```

## Tests
- 8,330 → 8,400 (+75) in `test_v05311.py` (54 initial + 21 review-fix coverage gaps from python / code / security / tdd review agents).
- Math kernels are real and unit-tested. Full GPU smoke runs are documented in the v0.53.11.1 follow-up plan.
## See also
- [GRPO Plus reference](/docs/grpo-plus)
- [Long Context](/docs/long-context)
- [Preference Variety](/docs/preference-variety)