Generalized Advantage Estimation

Note

Generalized Advantage Estimation (GAE) describes an advantage estimator with two separate parameters, \(\gamma\) and \(\lambda\), both of which contribute to the bias-variance tradeoff when using an approximate value function.

Preliminaries

There are several different related expressions for the policy gradient, which have the form

\[g = \mathbb{E}\left[\sum_{t=0}^{\infty}\Psi_{t}\nabla_{\theta}\log\pi_{\theta}(a_t|s_t)\right]\]

where \(\Psi_{t}\) may be one of the following:

  1. \(\sum_{t=0}^{\infty}r_t\): total reward of the trajectory.

  2. \(\sum_{t'=t}^{\infty}r_{t'}\): reward following action \(a_t\).

  3. \(\sum_{t'=t}^{\infty}r_{t'} - b(s_t)\): baselined version of previous formula.

  4. \(Q_{\pi}(s_t, a_t)\): state-action value function.

  5. \(A_{\pi}(s_t, a_t)\): advantage function.

  6. \(r_{t} + V_{\pi}(s_{t+1}) - V_{\pi}(s_t)\): TD residual.

The choice \(\Psi_{t} = A_{\pi}(s_t, a_t)\) yields almost the lowest possible variance, though in practice, the advantage function is not known and must be estimated.
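
To make choices 2 and 3 above concrete, here is a minimal numpy sketch for a single finite-length trajectory; the reward values and the constant baseline are hypothetical placeholders rather than anything prescribed by the text.

```python
import numpy as np

# Hypothetical rewards r_0, ..., r_{T-1} from a single finite trajectory.
rewards = np.array([1.0, 0.0, 2.0, -1.0])

# Choice 2: reward following action a_t, i.e. sum_{t' >= t} r_{t'}.
reward_to_go = np.cumsum(rewards[::-1])[::-1]

# Choice 3: subtract a state-dependent baseline b(s_t); a constant
# stands in here for a learned baseline.
baseline = np.full_like(rewards, rewards.mean())
baselined = reward_to_go - baseline

print(reward_to_go)  # [ 2.  1.  1. -1.]
print(baselined)     # [ 1.5  0.5  0.5 -1.5]
```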

We will introduce a parameter \(\gamma\) that allows us to reduce variance by downweighting rewards corresponding to delayed effects, at the cost of introducing bias. This parameter corresponds to the discount factor used in discounted formulations of MDPs, but we treat it as a variance reduction parameter in an undiscounted problem.

\[\begin{split} \begin{aligned} V_{\pi,\gamma}(s_t) :=& \mathbb{E}_{s_{t+1}:\infty, a_{t}:\infty}\left[\sum_{l=0}^{\infty}\gamma^{l}r_{t+l}\right]\\ Q_{\pi,\gamma}(s_t, a_t) :=& \mathbb{E}_{s_{t+1}:\infty, a_{t+1}:\infty}\left[\sum_{l=0}^{\infty}\gamma^{l}r_{t+l}\right]\\ A_{\pi,\gamma}(s_t, a_t) :=& Q_{\pi,\gamma}(s_t, a_t) - V_{\pi,\gamma}(s_t) \end{aligned} \end{split}\]

The colon notation \(a: b\) refers to the inclusive range \((a, a+1,\dots, b)\). The discounted approximation to the policy gradient is defined as follows:

\[g^{\gamma} := \mathbb{E}_{s_0:\infty,a_0:\infty}\left[\sum_{t=0}^{\infty}A_{\pi,\gamma}(s_t, a_t)\nabla_{\theta}\log\pi_{\theta}(a_t|s_t)\right]\]

Before proceeding, we will introduce the notion of a \(\gamma\)-just estimator of the advantage function, which is an estimator that does not introduce bias when we use it in place of \(A_{\pi,\gamma}\).

Definition 1. The estimator \(\hat{A}_{t}\) is \(\gamma\)-just if

\[\mathbb{E}_{s_0:\infty,a_0:\infty}\left[\hat{A}_{t}\nabla_{\theta}\log\pi_{\theta}(a_t|s_t)\right]=\mathbb{E}_{s_0:\infty,a_0:\infty}\left[A_{\pi,\gamma}(s_t, a_t)\nabla_{\theta}\log\pi_{\theta}(a_t|s_t)\right] = g^{\gamma}\]

We can verify that the following expressions are \(\gamma\)-just advantage estimators for \(\hat{A}_t\):

  • \(\sum_{l=0}^{\infty}\gamma^{l}r_{t+l}\)

  • \(Q_{\pi,\gamma}(s_t, a_t)\)

  • \(A_{\pi,\gamma}(s_t, a_t)\)

  • \(r_t + \gamma V_{\pi,\gamma}(s_{t+1}) - V_{\pi,\gamma}(s_{t})\)
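
For instance, the first estimator in the list, the discounted sum of rewards following time \(t\), can be computed for a finite trajectory with a single backward pass. A sketch under assumed array names:

```python
import numpy as np

def discounted_reward_to_go(rewards, gamma):
    """Return sum_{l >= 0} gamma^l * r_{t+l} for every t of a finite trajectory."""
    out = np.zeros_like(rewards)
    running = 0.0
    # Accumulate from the end of the trajectory backwards.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        out[t] = running
    return out

rewards = np.array([1.0, 0.0, 2.0, -1.0])  # hypothetical r_0, ..., r_3
print(discounted_reward_to_go(rewards, gamma=0.99))
```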

Advantage function estimation

This section will be concerned with producing an accurate estimate \(\hat{A}_{t}\) which will then be used to construct a policy gradient estimator of the following form:

\[\hat{g} = \frac{1}{N}\sum_{n=1}^{N}\sum_{t=0}^{T}\hat{A}_{t}^{n}\nabla_{\theta}\log\pi_{\theta}(a_{t}^{n}|s_{t}^{n})\]

where \(n\) indexes over a batch of episodes.
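
A minimal sketch of assembling \(\hat{g}\), assuming the advantage estimates and the score-function gradients \(\nabla_{\theta}\log\pi_{\theta}(a_{t}^{n}|s_{t}^{n})\) have already been computed and stacked into arrays; all shapes and names below are illustrative assumptions.

```python
import numpy as np

N, T, P = 8, 50, 100  # episodes, timesteps, policy parameters (hypothetical sizes)
rng = np.random.default_rng(0)

# Placeholder inputs: A_hat[n, t] is an advantage estimate, and
# grad_log_pi[n, t, :] = grad_theta log pi_theta(a_t^n | s_t^n).
A_hat = rng.normal(size=(N, T))
grad_log_pi = rng.normal(size=(N, T, P))

# g_hat = (1/N) * sum_n sum_t A_hat[n, t] * grad_log_pi[n, t, :]
g_hat = (A_hat[:, :, None] * grad_log_pi).sum(axis=(0, 1)) / N
print(g_hat.shape)  # (P,)
```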

Let \(V\) be an approximate value function. Define

\[\delta_{t}^{V} = r_t + \gamma V(s_{t+1}) - V(s_t)\]

i.e., the TD residual. \(\delta_{t}^{V}\) can be considered an estimate of the advantage of the action \(a_t\). In fact, if we have the correct value function \(V=V_{\pi,\gamma}\), then \(\delta_{t}^{V}\) is a \(\gamma\)-just advantage estimator; for any other \(V\), it yields biased policy gradient estimates.
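
Given arrays of rewards and value predictions along one trajectory, the TD residuals can be computed in a single vectorized step. A sketch, assuming `values` holds \(V(s_0),\dots,V(s_T)\) (one more entry than `rewards`):

```python
import numpy as np

rewards = np.array([1.0, 0.0, 2.0, -1.0])     # hypothetical r_0, ..., r_{T-1}
values = np.array([0.5, 0.4, 1.2, 0.1, 0.0])  # hypothetical V(s_0), ..., V(s_T)
gamma = 0.99

# delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
deltas = rewards + gamma * values[1:] - values[:-1]
print(deltas)
```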

Next, let:

\[\begin{split} \begin{aligned} \hat{A}_{t}^{(1)} &:= \delta_{t}^{V} &&=-V(s_t) + r_t + \gamma V(s_{t+1})\\ \hat{A}_{t}^{(2)} &:= \delta_{t}^{V} + \gamma\delta_{t+1}^{V} &&=-V(s_t) + r_t + \gamma r_{t+1} + \gamma^{2} V(s_{t+2})\\ \hat{A}_{t}^{(k)} &:= \sum_{l=0}^{k-1}\gamma^{l}\delta_{t+l}^{V} &&=-V(s_t) + r_t + \gamma r_{t+1} + \dots + \gamma^{k-1}r_{t+k-1} + \gamma^{k} V(s_{t+k})\\ \end{aligned} \end{split}\]

We can consider \(\hat{A}_{t}^{(k)}\) to be an estimator of the advantage function, which is only \(\gamma\)-just when \(V=V_{\pi,\gamma}\). However, note that the bias generally becomes smaller as \(k\to\infty\), since the term \(\gamma^{k}V(s_{t+k})\) becomes more heavily discounted, and the term \(-V(s_t)\) does not affect the bias.
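
The two expressions for \(\hat{A}_{t}^{(k)}\) are equal because the sum of discounted TD residuals telescopes; a quick numerical check with the same kind of placeholder arrays as above:

```python
import numpy as np

rewards = np.array([1.0, 0.0, 2.0, -1.0, 0.5])     # hypothetical r_0, ..., r_4
values = np.array([0.5, 0.4, 1.2, 0.1, 0.3, 0.0])  # hypothetical V(s_0), ..., V(s_5)
gamma, t, k = 0.99, 0, 3

deltas = rewards + gamma * values[1:] - values[:-1]

# Left-hand side: sum_{l=0}^{k-1} gamma^l * delta_{t+l}
lhs = sum(gamma**l * deltas[t + l] for l in range(k))

# Right-hand side: -V(s_t) + r_t + ... + gamma^{k-1} r_{t+k-1} + gamma^k V(s_{t+k})
rhs = (-values[t]
       + sum(gamma**l * rewards[t + l] for l in range(k))
       + gamma**k * values[t + k])

assert np.isclose(lhs, rhs)
```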

The generalized advantage estimator \(\text{GAE}(\gamma, \lambda)\) is defined as the exponentially-weighted average of these \(k\)-step estimators:

\[\begin{split} \begin{aligned} \hat{A}_{t}^{\text{GAE}(\gamma, \lambda)} &:= (1 - \lambda)\left(\hat{A}_{t}^{(1)} + \lambda\hat{A}_{t}^{(2)} + \lambda^{2}\hat{A}_{t}^{(3)} + \dots\right) \\ &= \sum_{l=0}^{\infty}(\gamma\lambda)^{l}\delta_{t+l}^{V} \end{aligned} \end{split}\]
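
In practice the infinite sum is truncated at the end of a finite trajectory and computed with the backward recursion \(\hat{A}_{t} = \delta_{t}^{V} + \gamma\lambda\hat{A}_{t+1}\). A minimal sketch, with hypothetical array names and the convention that `values` has one more entry than `rewards` so the final bootstrap value is included:

```python
import numpy as np

def gae_advantages(rewards, values, gamma, lam):
    """Truncated GAE(gamma, lambda) advantages for one trajectory.

    rewards: r_0, ..., r_{T-1}
    values:  V(s_0), ..., V(s_T)  (one more entry than rewards)
    """
    deltas = rewards + gamma * values[1:] - values[:-1]
    advantages = np.zeros_like(rewards)
    running = 0.0
    # A_hat_t = delta_t + gamma * lambda * A_hat_{t+1}, computed backwards.
    for t in reversed(range(len(rewards))):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    return advantages

rewards = np.array([1.0, 0.0, 2.0, -1.0])
values = np.array([0.5, 0.4, 1.2, 0.1, 0.0])
print(gae_advantages(rewards, values, gamma=0.99, lam=0.95))
```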

There are two notable special cases of this formula, obtained by setting \(\lambda=0\) and \(\lambda=1\).

\[\begin{split} \begin{aligned} &\text{GAE}(\gamma, 0):\hat{A}_{t} = \delta_{t}^{V} &&= r_t + \gamma V(s_{t+1}) - V(s_t) \\ &\text{GAE}(\gamma, 1):\hat{A}_{t} = \sum_{l=0}^{\infty}\gamma^{l}\delta_{t+l}^{V} &&= \sum_{l=0}^{\infty}\gamma^{l}r_{t+l} - V(s_t) \end{aligned} \end{split}\]

Tip

\(\text{GAE}(\gamma, 1)\) is \(\gamma\)-just regardless of the accuracy of \(V\), but it has high variance due to the sum of terms. \(\text{GAE}(\gamma, 0)\) is \(\gamma\)-just for \(V=V_{\pi,\gamma}\) and otherwise induces bias, but it typically has much lower variance. The generalized advantage estimator for \(0<\lambda<1\) makes a compromise between bias and variance, controlled by the parameter \(\lambda\).
Taking \(\gamma<1\) introduces bias into the policy gradient estimate, regardless of the value function’s accuracy.
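
Both special cases are easy to confirm numerically by plugging \(\lambda=0\) and \(\lambda=1\) into the truncated backward recursion; a self-contained sketch with placeholder arrays, where exactness of the \(\lambda=1\) case relies on the trajectory ending with \(V(s_T)=0\):

```python
import numpy as np

rewards = np.array([1.0, 0.0, 2.0, -1.0])
values = np.array([0.5, 0.4, 1.2, 0.1, 0.0])  # V(s_0), ..., V(s_T), with V(s_T) = 0
gamma = 0.99
deltas = rewards + gamma * values[1:] - values[:-1]

def gae(lam):
    adv, running = np.zeros_like(rewards), 0.0
    for t in reversed(range(len(rewards))):
        running = deltas[t] + gamma * lam * running
        adv[t] = running
    return adv

# lambda = 0: the one-step TD residual.
assert np.allclose(gae(0.0), deltas)

# lambda = 1: discounted empirical return minus the value baseline.
returns = np.array([sum(gamma**l * rewards[t + l] for l in range(len(rewards) - t))
                    for t in range(len(rewards))])
assert np.allclose(gae(1.0), returns - values[:-1])
```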