REINFORCE++#
Note
We present REINFORCE++, an enhanced variant of the classical REINFORCE algorithm that incorporates key optimization techniques from PPO while eliminating the need for a critic network.
Background: The REINFORCE Algorithm#
The algorithm operates as follows:
Trajectory Sampling: Trajectories are collected by sampling actions from the current policy \(\pi_{\theta}\) and recording the resulting states and rewards.
Return Calculation: The discounted cumulative reward (return) from each time step \(t\) onward is computed as:
\[
G_t = \sum_{k=t}^{T} \gamma^{\,k-t}\, r_k
\]
where \(\gamma\) is the discount factor and \(r_k\) is the reward received at step \(k\).
Policy Gradient Estimation: The gradient of the expected return with respect to the policy parameters is estimated using:
\[
\nabla_{\theta} J(\theta) = \mathbb{E}_{\pi_{\theta}}\!\left[\sum_{t=0}^{T} G_t \,\nabla_{\theta} \log \pi_{\theta}(a_t \mid s_t)\right]
\]
Tip
Don’t Let the Past Distract You: use the reward-to-go \(G_t = \sum_{k=t}^{T} \gamma^{\,k-t} r_k\) instead of the sum of all rewards, so each action is reinforced only by the rewards that follow it.
Policy Update: The policy parameters are updated via gradient ascent:
\[
\theta \leftarrow \theta + \alpha \,\nabla_{\theta} J(\theta)
\]
where \(\alpha\) is the learning rate.
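For concreteness, here is a minimal PyTorch-style sketch of one REINFORCE update using the reward-to-go form from the tip above; the function name `reinforce_update` and its argument layout are illustrative, not taken from any particular library.

```python
import torch

def reinforce_update(policy_optimizer, log_probs, rewards, gamma=1.0):
    """One REINFORCE update for a single sampled trajectory.

    log_probs: 1-D tensor of log pi_theta(a_t | s_t), with grad enabled
    rewards:   list of per-step scalar rewards r_t
    """
    # Reward-to-go: G_t = sum_{k >= t} gamma^(k - t) * r_k
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.insert(0, g)
    returns = torch.tensor(returns, dtype=log_probs.dtype)

    # Policy-gradient surrogate: minimize -sum_t G_t * log pi(a_t | s_t)
    loss = -(returns * log_probs).sum()

    policy_optimizer.zero_grad()
    loss.backward()
    policy_optimizer.step()  # gradient-ascent step on J(theta)
    return loss.item()
```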
Despite its simplicity, REINFORCE suffers from high variance in gradient estimates, which can hinder its scalability to complex tasks such as aligning LLMs.
REINFORCE++ Enhancements#
Token-Level KL Penalty#
We implement a token-level KL divergence penalty between the RL model and the SFT model distributions. This penalty is incorporated into the reward function as follows:
\[
r(s_t, a_t) = \mathbf{I}(s_t = [\text{EOS}])\, r(x, y) - \beta \,\text{KL}(t),
\qquad
\text{KL}(t) = \log\!\left(\frac{\pi^{\text{RL}}_{\theta_{\text{old}}}(a_t \mid s_t)}{\pi^{\text{SFT}}(a_t \mid s_t)}\right)
\]
where:
\(x\) represents the input prompt
\(y\) denotes the generated response
\(r(x, y)\) is the sequence-level reward from the reward model
\(\mathbf{I}(s_t=[\text{EOS}])\) is an indicator that equals 1 only when \(s_t\) is the final ([EOS]) token, so the sequence-level reward is added once at the end of the response
\(\beta\) is the KL penalty coefficient
Spreading the KL penalty over individual tokens in this way facilitates better credit assignment than applying a single sequence-level penalty.
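As an illustration of how the per-token rewards might be assembled under this formulation, here is a small sketch that assumes per-token log-probabilities from the RL policy and the SFT model are already available; the function name and tensor shapes are assumptions, not tied to any specific library.

```python
import torch

def token_level_rewards(seq_reward, rl_logprobs, sft_logprobs, beta=0.01):
    """Per-token rewards: r_t = I(s_t == [EOS]) * r(x, y) - beta * KL(t).

    seq_reward:   scalar r(x, y) from the reward model
    rl_logprobs:  (T,) log-probs of the generated tokens under the RL policy
    sft_logprobs: (T,) log-probs of the same tokens under the SFT model
    (both assumed to be detached rollout statistics)
    """
    # Per-token KL estimate: log pi_RL(a_t|s_t) - log pi_SFT(a_t|s_t)
    kl = rl_logprobs - sft_logprobs
    rewards = -beta * kl
    # The sequence-level reward is added only at the final ([EOS]) token.
    rewards[-1] += seq_reward
    return rewards
```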
PPO-Clip Integration#
We incorporate PPO’s clipping mechanism to constrain policy updates:
\[
L^{\text{CLIP}}(\theta) = \mathbb{E}_{t}\!\left[\min\!\left(r_{t}(\theta)\,\hat{A}_{t},\; \text{clip}\big(r_{t}(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\,\hat{A}_{t}\right)\right]
\]
where:
\(r_{t}(\theta) = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}\) is the probability ratio of taking action \(a_t\) in state \(s_t\) under the new policy versus the old policy.
\(\hat{A}_{t}\) is the estimated advantage for token \(t\).
\(\mbox{clip}(r_{t}(\theta), 1-\epsilon, 1 + \epsilon)\) restricts the probability ratio to be within the range of \([1-\epsilon,1+\epsilon]\), where \(\epsilon\) is a small hyperparameter.
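The clipped objective can be implemented in a few lines; the following sketch assumes per-token log-probabilities under the current and old policies and precomputed advantages (names are illustrative).

```python
import torch

def ppo_clip_loss(logprobs, old_logprobs, advantages, eps=0.2):
    """Clipped surrogate objective (negated so it can be minimized)."""
    ratio = torch.exp(logprobs - old_logprobs)               # r_t(theta)
    surr_unclipped = ratio * advantages
    surr_clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages
    # Take the element-wise minimum of the two surrogates, then average.
    return -torch.min(surr_unclipped, surr_clipped).mean()
```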
Advantage Normalization#
The advantage function in REINFORCE++ is defined as:
\[
A_t(s_t, a_t) = r(x, y) - \beta \sum_{i=t}^{T} \text{KL}(i)
\]
We normalize these advantages using z-score normalization:
\[
A_{\text{normalized}} = \frac{A - \mu_{A}}{\sigma_{A}}
\]
where \(\mu_{A}\) and \(\sigma_{A}\) denote the batch mean and standard deviation, respectively. Normalization ensures stable gradients and prevents divergence during training.
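A sketch of the advantage computation and z-score normalization described above, assuming the per-token KL estimates and the sequence reward are already available; the names are illustrative.

```python
import torch

def reinforce_pp_advantages(seq_reward, kl, beta=0.01, eps=1e-8):
    """A_t = r(x, y) - beta * sum_{i=t}^{T} KL(i), then z-score normalized.

    seq_reward: scalar reward r(x, y) for the whole response
    kl:         (T,) per-token KL estimates
    """
    # Reverse cumulative sum yields sum_{i=t}^{T} KL(i) at every position t.
    kl_to_go = torch.flip(torch.cumsum(torch.flip(kl, dims=[0]), dim=0), dims=[0])
    adv = seq_reward - beta * kl_to_go
    # Z-score normalization (in practice mu_A / sigma_A are batch statistics
    # over all sequences; here a single sequence stands in for the batch).
    return (adv - adv.mean()) / (adv.std() + eps)
```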
Reward Normalization and Clipping#
We implement comprehensive reward processing to stabilize training:
Normalization: Standardizes rewards using z-score normalization to mitigate outliers.
Clipping: Constrains reward values within predefined bounds to avoid instability.
Scaling: Applies appropriate scaling factors for numerical stability during updates.
Caution
The exact scaling factor is not specified here.
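A possible shape for this reward-processing pipeline is sketched below; the clip bounds and the scaling factor are placeholder values, since the exact settings are not given in this section.

```python
import torch

def process_rewards(rewards, clip_range=(-5.0, 5.0), scale=1.0, eps=1e-8):
    """Normalize, clip, and scale a batch of sequence-level rewards."""
    # 1) Z-score normalization to mitigate outliers.
    rewards = (rewards - rewards.mean()) / (rewards.std() + eps)
    # 2) Clip to predefined bounds to avoid instability.
    rewards = torch.clamp(rewards, min=clip_range[0], max=clip_range[1])
    # 3) Apply a scaling factor for numerical stability (placeholder value).
    return rewards * scale
```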
Mini-Batch Updates#
To enhance training efficiency, policy updates are performed on mini-batches sampled from the collected rollout data rather than on the full batch at once, allowing multiple gradient steps per rollout; a minimal sketch follows.
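The sketch below shows one way such a mini-batch loop could look, assuming a list-like rollout buffer and a loss function such as the PPO-clip loss above; the batch size and epoch count are illustrative.

```python
import torch

def minibatch_update(rollout, loss_fn, optimizer, micro_batch_size=8, epochs=1):
    """Several small gradient steps over one batch of collected rollouts.

    rollout: list of per-sequence training samples from the current policy
    loss_fn: callable mapping a mini-batch (list of samples) to a scalar loss
    """
    for _ in range(epochs):
        order = torch.randperm(len(rollout)).tolist()
        for start in range(0, len(order), micro_batch_size):
            batch = [rollout[i] for i in order[start:start + micro_batch_size]]
            optimizer.zero_grad()
            loss_fn(batch).backward()
            optimizer.step()
```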