This is a short primer on Off-Dynamics Reinforcement Learning. I also talk a little bit about density ratio estimation.
What is this Off-Dynamics Reinforcement Learning you speak of?
The Off-Policy[1] setting is cool, but let us consider operating in an environment where it is really hard to collect data, because collecting it is really expensive or takes too long. Off-Policy methods can reuse old experience, but they still need plenty of it from the environment we care about, so they struggle in this setting.
In such settings, your friendly neighborhood engineer is going to suggest training inexpensively on a simulator.
That’s a great idea! And like all great ideas, it has a few drawbacks.
For one, I doubt we’re ever going to have perfect simulators. The dream of running behaviour policies in simulation to obtain good enough target policies is likely to remain just that.
Unless we solve this problem with a little something called Off-Dynamics Reinforcement Learning.
Let us assume that we will never be able to perfectly model the real world dynamics for tasks that we care about.
This forces us to consider two Markov Decision Processes (MDPs)[2] in our treatment of this setting:
\(\mathcal{M}_{\text{source}}\) represents the source domain, which is the simulator, practice domain or the approximate model of the target domain. \(\mathcal{M}_{\text{target}}\) is the target domain.
We assume that both these MDPs have the same state space \(\mathcal{S}\), action space \(\mathcal{A}\), reward function \(r\) and initial state distribution \(p_1(s_1)\).
The only difference between the domains is their dynamics. The dynamics are represented by \(p_{\text{target}} (s_{t+1} \vert s_t, a_t)\) and \(p_{\text{source}} (s_{t+1} \vert s_t, a_t)\).
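We also make the assumption of coverage: every transition that is possible in the target domain is also possible in the source domain.

\[\begin{align*} p_{\text{target}} (s_{t+1} \vert s_t, a_t) > 0 \implies p_{\text{source}} (s_{t+1} \vert s_t, a_t) > 0 \end{align*}\]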
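For small discrete MDPs, the coverage assumption can be checked directly. The sketch below assumes (hypothetically) that dynamics are stored as nested lists indexed `[s][a][s_next]`; intuitively, it confirms that any transition the agent might rely on in the target domain can also be experienced in the cheap source domain.

```python
def satisfies_coverage(p_target, p_source):
    """Return True if p_target(s'|s,a) > 0 implies p_source(s'|s,a) > 0.

    Both arguments are hypothetical tabular dynamics indexed [s][a][s_next].
    """
    for s, actions in enumerate(p_target):
        for a, next_probs in enumerate(actions):
            for s_next, prob in enumerate(next_probs):
                # A target transition the source domain can never produce.
                if prob > 0 and p_source[s][a][s_next] == 0:
                    return False
    return True


# One state, one action, two possible next states (made-up numbers).
print(satisfies_coverage([[[0.8, 0.2]]], [[[0.5, 0.5]]]))  # True
print(satisfies_coverage([[[0.8, 0.2]]], [[[1.0, 0.0]]]))  # False
```

Note the asymmetry: the source domain may allow transitions that never happen in the target domain; only the reverse is ruled out.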
The core objective is that we’d like to learn a Markovian policy \(\pi_{\ta} (a \vert s)\) that maximises the expected discounted sum of rewards on \(\mathcal{M}_{\text{target}}\):
\[\begin{align*} \E_{\pi_{\ta}, \mathcal{M}_{\text{target}}} \lb \sum_t \y^t r(s_t, a_t) \rb \end{align*}\]

We’d like to achieve this objective using mostly inexpensive interactions in the source MDP, \(\mathcal{M}_{\text{source}}\), and only a small number of interactions in the target MDP, \(\mathcal{M}_{\text{target}}\).
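As a reminder of what this objective means in practice, here is a minimal sketch of its Monte Carlo estimate: run the policy, record each rollout's rewards, and average the discounted returns. It assumes the sum starts at \(t = 0\) and that each rollout is simply a list of rewards; in our setting, collecting those rollouts in \(\mathcal{M}_{\text{target}}\) is exactly the expensive part.

```python
def discounted_return(rewards, gamma):
    """sum_t gamma^t * r_t for one trajectory, counting t from 0."""
    total = 0.0
    for t, r in enumerate(rewards):
        total += (gamma ** t) * r
    return total


def estimate_objective(rollouts, gamma):
    """Monte Carlo estimate of E[sum_t gamma^t r_t]: the mean discounted
    return over reward sequences collected by running the policy."""
    return sum(discounted_return(rs, gamma) for rs in rollouts) / len(rollouts)


# Two made-up reward sequences.
print(estimate_objective([[1, 1, 1], [0, 0, 2]], 0.5))  # 1.125
```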
Our final policy should be near-optimal in the target MDP, \(\mathcal{M}_{\text{target}}\). We don’t really care about optimality in the source MDP(s).
A recent paper (Eysenbach et al. 2020)[3] did an excellent job of tackling this setting. They called it Off-Dynamics Reinforcement Learning, and I like that name.
It’s a really cool paper, and I highly recommend reading it.
This post is going to walk through a few bits from this paper that I found interesting, and I’ll try to share some of the gotchas that leapt out at me after a lot of staring at a whiteboard and more than a little muttering.
Achieving our objective
Consider the probabilistic inference interpretation of RL (Levine 2018)[4]. Here, the reward function defines a desired distribution over trajectories: each trajectory is weighted by its probability under the dynamics times its exponentiated total reward, so high-reward trajectories are exponentially preferred. An optimal agent samples trajectories from this distribution.
For our treatment, let us define \(p(\tau)\) to be the desired distribution over trajectories in the target domain:
\[\begin{align*} p(\tau) \propto p_1(s_1) \lp \prod_t p_{\text{target}}(s_{t+1} \vert s_t, a_t) \rp \exp \lp \sum_t r(s_t, a_t) \rp \end{align*}\]

We’d like our policy \(\pi_\ta\) to pick the best trajectories (those that maximize the expected reward) in this distribution.
Now consider our agent’s distribution over trajectories in the source domain. Let’s call it \(q(\tau)\):
\[\begin{align*} q(\tau) = p_1 (s_1) \prod_t p_{\text{source}} (s_{t+1} \vert s_t, a_t ) \pi_\theta (a_t \vert s_t ) \end{align*}\]

\(q(\tau)\) is parameterized by \(\ta\), the parameters of our policy.
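To make the two distributions concrete, here is a sketch that scores a single trajectory under each, in log space, for a hypothetical tabular MDP pair. The table layout (`[s][a][s_next]` for dynamics, `[s][a]` for policy and reward) and all numbers are made up for illustration. Note that \(p(\tau)\) uses the target dynamics and adds the rewards, while \(q(\tau)\) uses the source dynamics and the policy; \(p(\tau)\) is only computed up to its normalizing constant.

```python
import math

# Trajectories are lists of (s_t, a_t, s_{t+1}) tuples; dynamics are nested
# lists indexed [s][a][s_next]; policy and reward are indexed [s][a].


def log_q(traj, p1, p_source, policy):
    """log q(tau) = log p1(s_1) + sum_t [log p_source(s'|s,a) + log pi(a|s)]."""
    total = math.log(p1[traj[0][0]])
    for s, a, s_next in traj:
        total += math.log(p_source[s][a][s_next]) + math.log(policy[s][a])
    return total


def log_p_unnormalized(traj, p1, p_target, reward):
    """log p(tau) up to its normalizing constant:
    log p1(s_1) + sum_t [log p_target(s'|s,a) + r(s,a)]."""
    total = math.log(p1[traj[0][0]])
    for s, a, s_next in traj:
        total += math.log(p_target[s][a][s_next]) + reward[s][a]
    return total


# Hypothetical two-state, two-action tables (made-up numbers).
p1 = [0.5, 0.5]
p_source = [[[0.9, 0.1], [0.1, 0.9]], [[0.9, 0.1], [0.1, 0.9]]]
p_target = [[[0.8, 0.2], [0.2, 0.8]], [[0.8, 0.2], [0.2, 0.8]]]
policy = [[0.5, 0.5], [0.5, 0.5]]
reward = [[0.0, 1.0], [1.0, 0.0]]
traj = [(0, 1, 1), (1, 0, 0)]
print(log_q(traj, p1, p_source, policy), log_p_unnormalized(traj, p1, p_target, reward))
```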
Minimizing the reverse KL divergence[5] between these two distributions gives us a way to achieve the objective we set up:
\[\begin{align*} D_{\text{KL}} (q(\tau) \vert \vert p(\tau) ) = &- \underset{\tau \sim q(\tau)}{\E} \lb \log p(\tau) - \log q(\tau) \rb \\ = &- \underset{\tau \sim q(\tau)}{\E} \bigg [ \log p_1(s_1) + \sum_t \log p_{\text{target}} (s_{t+1} \vert s_t, a_t) + \sum_t r(s_t, a_t) - \log Z \\ &- \log p_1(s_1) - \sum_t \log p_{\text{source}} (s_{t+1} \vert s_t, a_t) - \sum_t \log \pi_\ta (a_t \vert s_t) \bigg ] \\ = &- \underset{\tau \sim q(\tau)}{\E}\bigg [ \sum_t \Big( r(s_t, a_t) - \log \pi_\ta (a_t \vert s_t) + \log p_{\text{target}} (s_{t+1} \vert s_t, a_t) - \log p_{\text{source}} (s_{t+1} \vert s_t, a_t) \Big) \bigg ] + \log Z \\ = &- \underset{\tau \sim q(\tau)}{\E} \lb \sum_t \Big( r(s_t, a_t) + \mathcal{H}_{\pi}[a_t \vert s_t] + \Delta r (s_{t+1} , s_t, a_t ) \Big) \rb + c \end{align*}\]

Here \(Z\) is the normalizing constant that turns the proportionality defining \(p(\tau)\) into a proper distribution, so \(c = \log Z\) does not depend on the policy. The dynamics inside \(q\) are fixed to \(p_{\text{source}}\), so the only thing we optimize over is the policy:

\[\begin{align*} \underset{\pi_\ta}{\min}\, D_{\text{KL}} (q \vert \vert p) = \underset{\pi_\ta}{\min} \lp - \E_q \lb \sum_t \lp r(s_t, a_t) + \mathcal{H}_\pi[a_t \vert s_t] + \Delta r (s_{t+1} , s_t, a_t ) \rp \rb \rp + c \end{align*}\]

where, inside the expectation, \(\mathcal{H}_\pi[a_t \vert s_t] = - \log \pi_{\ta} (a_t \vert s_t)\) (its expected value is the entropy of the policy at \(s_t\)) and
\[\begin{align*} \Delta r (s_{t+1} , s_t, a_t ) = \log p_{\text{target}} (s_{t+1} \vert s_t, a_t) - \log p_{\text{source}} (s_{t+1} \vert s_t, a_t) \end{align*}\]

If there are no differences in dynamics, \(\Delta r = 0\), and we are left with the usual maximum-entropy RL objective.
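Here is a sketch of the per-step quantity inside that expectation, assuming we could evaluate both transition densities exactly (which, as we are about to see, is the catch). The tabular layout and numbers are hypothetical. Summed over a trajectory, these per-step terms add up to \(\log p(\tau) - \log q(\tau) + \log Z\), which is the middle step of the derivation above.

```python
import math


def delta_r(s, a, s_next, p_target, p_source):
    """Delta r = log p_target(s'|s,a) - log p_source(s'|s,a), dynamics indexed [s][a][s_next].

    Source-domain samples have p_source > 0; math.log raises ValueError when
    p_target(s'|s,a) == 0 (where the true Delta r is -infinity).
    """
    return math.log(p_target[s][a][s_next]) - math.log(p_source[s][a][s_next])


def augmented_reward(s, a, s_next, reward, policy, p_target, p_source):
    """One step of the objective: r(s,a) + H_pi[a|s] + Delta r, using the
    single-sample entropy term H_pi[a|s] = -log pi(a|s)."""
    return (reward[s][a] - math.log(policy[s][a])
            + delta_r(s, a, s_next, p_target, p_source))


# Hypothetical two-state, two-action tables (made-up numbers).
p_source = [[[0.9, 0.1], [0.1, 0.9]], [[0.9, 0.1], [0.1, 0.9]]]
p_target = [[[0.8, 0.2], [0.2, 0.8]], [[0.8, 0.2], [0.2, 0.8]]]
policy = [[0.5, 0.5], [0.5, 0.5]]
reward = [[0.0, 1.0], [1.0, 0.0]]
traj = [(0, 1, 1), (1, 0, 0)]  # (s_t, a_t, s_{t+1}) tuples

# Summed over a trajectory this equals log p(tau) - log q(tau) + log Z,
# where Z is the normalizing constant of p(tau).
total = sum(augmented_reward(s, a, s2, reward, policy, p_target, p_source)
            for s, a, s2 in traj)
print(total)
```

If `p_target` and `p_source` are identical, every `delta_r` is exactly zero and only the reward and entropy terms remain.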
Looks good, eh?
No. The whole reason we can’t build perfect simulators is that it’s really, really hard to learn the transition functions, and computing \(\Delta r\) requires evaluating both \(p_{\text{target}}\) and \(p_{\text{source}}\).
This \(\Delta r\) term needs some work…
Dealing with the \(\Delta r\) term
If we cannot compute a term exactly, the next best thing we can do is find a good estimate of it. Let’s try doing that with our \(\Delta r\) term. Folding the two logs together, we get the log of a ratio of two intractable densities: \(\frac{p_{\text{target}} (s_{t+1} \vert s_t, a_t) }{p_{\text{source}} (s_{t+1} \vert s_t, a_t) }\).
Hang on. We can view \(\frac{p_{\text{target}} (s_{t+1} \vert s_t, a_t) }{p_{\text{source}} (s_{t+1} \vert s_t, a_t) }\) as a density ratio[6]. And there are some cool ways to estimate density ratios using classifiers[7]: train a classifier to tell samples from the two distributions apart, and its odds give you the ratio. Fortunately, we’ve gotten REALLY good at classifying things in the last few years.
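Here is a minimal sketch of the density ratio trick in one dimension. The toy setup is an assumption for illustration: target samples come from \(\mathcal{N}(1, 1)\), source samples from \(\mathcal{N}(0, 1)\), so the true log ratio is exactly \(x - 0.5\). The classifier is logistic regression fit with Newton’s method (my choice; any well-calibrated probabilistic classifier would do). Its logit estimates \(\log \frac{p(\text{target} \vert x)}{p(\text{source} \vert x)}\), which equals the log density ratio plus the prior log-odds, so we subtract \(\log(n_{\text{target}} / n_{\text{source}})\).

```python
import math
import random


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def fit_logistic_1d(xs, ys, iters=25):
    """Fit p(target | x) = sigmoid(w * x + b) by Newton's method."""
    w, b = 0.0, 0.0
    for _ in range(iters):
        gw = gb = hww = hwb = hbb = 0.0
        for x, y in zip(xs, ys):
            p = sigmoid(w * x + b)
            # Gradient of the negative log-likelihood.
            gw += (p - y) * x
            gb += p - y
            # Hessian entries, weighted by p * (1 - p).
            c = p * (1.0 - p)
            hww += c * x * x
            hwb += c * x
            hbb += c
        det = hww * hbb - hwb * hwb
        # Newton step [w, b] -= H^{-1} g, using the closed-form 2x2 inverse.
        w -= (hbb * gw - hwb * gb) / det
        b -= (hww * gb - hwb * gw) / det
    return w, b


def log_density_ratio(x, w, b, n_target, n_source):
    """Estimate log p_target(x) / p_source(x): the classifier's logit minus
    the prior log-odds log(n_target / n_source)."""
    return w * x + b - math.log(n_target / n_source)


# Toy data (made up): target samples ~ N(1, 1), source samples ~ N(0, 1).
rng = random.Random(0)
n_target, n_source = 2000, 2000
xs = ([rng.gauss(1.0, 1.0) for _ in range(n_target)]
      + [rng.gauss(0.0, 1.0) for _ in range(n_source)])
ys = [1] * n_target + [0] * n_source
w, b = fit_logistic_1d(xs, ys)

# For these two Gaussians the true log ratio is exactly x - 0.5.
print(w, b, log_density_ratio(2.0, w, b, n_target, n_source))
```

The fitted `w` and `b` should land close to 1 and -0.5. With equal sample sizes the prior correction is zero, but keeping it explicit shows where the class proportions enter.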
In this case, we can use two classifiers trained on data from our replay buffers, each predicting whether a tuple came from the target or the source domain. One classifier looks at tuples of the form \(\langle s_t, a_t, s_{t+1} \rangle\), and the other looks at tuples of the form \(\langle s_t, a_t \rangle\).
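One way to set up the training data, as a sketch: each replay buffer is assumed (hypothetically) to be a plain list of `(s, a, s_next)` tuples, and every transition contributes one labelled example to each classifier, with label 1 for target and 0 for source.

```python
def make_classifier_datasets(target_buffer, source_buffer):
    """Turn two replay buffers of (s, a, s_next) transitions into labelled
    datasets for the two domain classifiers (label 1 = target, 0 = source)."""
    sas_dataset = []  # inputs for the classifier on <s_t, a_t, s_{t+1}>
    sa_dataset = []   # inputs for the classifier on <s_t, a_t>
    for label, buffer in ((1, target_buffer), (0, source_buffer)):
        for s, a, s_next in buffer:
            sas_dataset.append(((s, a, s_next), label))
            sa_dataset.append(((s, a), label))
    return sas_dataset, sa_dataset


# Tiny made-up buffers.
target_buffer = [(0, 1, 1)]
source_buffer = [(0, 1, 0), (1, 0, 0)]
sas_dataset, sa_dataset = make_classifier_datasets(target_buffer, source_buffer)
print(sas_dataset)
print(sa_dataset)
```

Because both datasets are built from the same transitions, the two classifiers see identical class proportions.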
Now the ratio that we’d like to estimate is given by:
\[\begin{align*} \Delta r &= \log \lb \frac{p_{\text{target}} (s_{t+1} \vert s_t, a_t)}{p_{\text{source}} (s_{t+1} \vert s_t, a_t)} \rb \end{align*}\]

Let us pause to define our terms a little more clearly:
\(p_{\text{target}}(s_{t+1} \vert s_t, a_t)\): this is the environment transition probability in the ‘target’ environment. It can also be written as \(p(s_{t+1}\vert s_t, a_t, \text{target})\). It tells us that if our agent in \(\mathcal{M}_{\text{target}}\) takes action \(a_t\) while in state \(s_t\), it moves to state \(s_{t+1}\) with probability \(p_{\text{target}}(s_{t+1} \vert s_t, a_t)\).
\(p(\text{target} \vert s_t, a_t, s_{t+1})\): this is the probability that the tuple \(\langle s_t, a_t, s_{t+1} \rangle\) was drawn from the target domain. It is what a binary classifier fed that tuple is trained to output[8].
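Okay, let us focus a little more on the conditional probability arising from our binary classifier that takes the tuple \(\langle s_t, a_t, s_{t+1} \rangle\). Bayes’ rule tells us that:

\[\begin{align*} p(\text{target} \vert s_t, a_t, s_{t+1}) = \frac{p(s_t, a_t, s_{t+1} \vert \text{target}) \, p(\text{target})}{p(s_t, a_t, s_{t+1})} \end{align*}\]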
Now, let us take a closer look at the \(p(s_t, a_t, s_{t+1} \vert \text{target})\) term in the RL setting:
Given a state \(s_t\) and action \(a_t\), the probability of getting \(s_{t+1}\) depends on the environment transition probability. Since the label tells us that we are in the target MDP, we use \(p_{\text{target}}(s_{t+1} \vert s_t, a_t)\):

\[\begin{align*} p(s_t, a_t, s_{t+1} \vert \text{target}) = p_{\text{target}}(s_{t+1} \vert s_t, a_t) \, p(s_t, a_t \vert \text{target}) \end{align*}\]

where \(p(s_t, a_t \vert \text{target})\) is the probability of the state-action pair \((s_t, a_t)\) occurring in the target MDP.
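Once we make this substitution in Bayes’ rule, we get:

\[\begin{align*} p(\text{target} \vert s_t, a_t, s_{t+1}) = \frac{p_{\text{target}}(s_{t+1} \vert s_t, a_t) \, p(s_t, a_t \vert \text{target}) \, p(\text{target})}{p(s_t, a_t, s_{t+1})} \end{align*}\]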
This implies that we can make the substitution:
\[\begin{align*} p_{\text{target}} (s_{t+1} \vert s_t, a_t) = \frac{p(\text{target} \vert s_t , a_t, s_{t+1}) p(s_t,a_t,s_{t+1})}{ p(s_t, a_t \vert \text{target}) p(\text{target} )} \end{align*}\]

Now, for the \(p(s_t, a_t \vert \text{target})\) term, remember that we have a second classifier that takes in the tuple \(\langle s_t, a_t \rangle\). The output of that classifier is \(p(\text{target} \vert s_t, a_t)\). Applying Bayes’ rule here tells us that:
\[\begin{align*} p(s_t,a_t \vert \text{target}) = \frac{p (\text{target} \vert s_t, a_t) p(s_t, a_t)}{p(\text{target})} \end{align*}\]

The same steps go through for the source domain, with \(\text{source}\) in place of \(\text{target}\). We’re now ready to deal with the \(\Delta r\) term that we wanted to estimate:
\[\begin{align*} \Delta r &= \log \lb \frac{p_{\text{target}} (s_{t+1} \vert s_t, a_t)}{p_{\text{source}} (s_{t+1} \vert s_t, a_t)} \rb \\ &= \log \lb \frac{p(\text{target}\vert s_t, a_t, s_{t+1}) p(s_t,a_t,s_{t+1}) }{ p(s_t, a_t \vert \text{target}) p(\text{target} ) } \times \frac{ p(s_t, a_t \vert \text{source}) p(\text{source} ) }{p(\text{source}\vert s_t, a_t, s_{t+1}) p(s_t,a_t,s_{t+1}) } \rb \\ &= \log \lb \frac{p(\text{target}\vert s_t, a_t, s_{t+1}) }{ p(s_t, a_t \vert \text{target}) p(\text{target} ) } \times \frac{ p(s_t, a_t \vert \text{source}) p(\text{source} ) }{p(\text{source}\vert s_t, a_t, s_{t+1}) } \rb \\ &= \log \lb \frac{p(\text{target}\vert s_t, a_t, s_{t+1}) p(\text{target} ) }{ p(\text{target} \vert s_t, a_t ) p(s_t, a_t) p(\text{target} ) } \times \frac{ p(\text{source} \vert s_t, a_t) p(s_t, a_t) p(\text{source} ) }{p(\text{source}\vert s_t, a_t, s_{t+1}) p(\text{source} ) } \rb\\ &= \log \lb \frac{p(\text{target}\vert s_t, a_t, s_{t+1}) }{ p(\text{target} \vert s_t, a_t ) } \times \frac{ p(\text{source} \vert s_t, a_t) }{p(\text{source}\vert s_t, a_t, s_{t+1}) } \rb\\ &= \log p (\text{target} \vert s_t, a_t, s_{t+1} ) - \log p (\text{target} \vert s_t, a_t) - \log p (\text{source} \vert s_t, a_t, s_{t+1} ) + \log p (\text{source} \vert s_t, a_t) \end{align*}\]

Hang on!
Did we just get an estimate for \(\Delta r\) that depends solely on the predictions of our two classifiers?
\[\begin{align*} \Delta r (s_{t+1}, s_t, a_t) = \color{red}{\lb \log p (\text{target} \vert s_t, a_t, s_{t+1} ) - \log p (\text{source} \vert s_t, a_t, s_{t+1} ) \rb} - \color{blue}{\lb \log p (\text{target} \vert s_t, a_t) - \log p (\text{source} \vert s_t, a_t) \rb} \end{align*}\]
We did!
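The red terms, \(\log p(\text{target} \vert s_t, a_t, s_{t+1}) - \log p(\text{source} \vert s_t, a_t, s_{t+1})\), are the difference in logits (the log-odds) of the classifier conditioned on \(\langle s_t, a_t, s_{t+1} \rangle\), while the blue terms, \(\log p(\text{target} \vert s_t, a_t) - \log p(\text{source} \vert s_t, a_t)\), are the same quantity for the classifier conditioned on just \(\langle s_t, a_t \rangle\). \(\Delta r\) is the first minus the second.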
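A quick sanity check of the algebra, in a hypothetical two-state, two-action example where every number is made up. Instead of training classifiers, it builds the Bayes-optimal ones exactly from the joint distribution, deliberately giving the two buffers different state-action distributions and a 10%/90% size split, then compares the four-term formula against the true log ratio of the dynamics.

```python
import itertools
import math

# Hypothetical two-state, two-action example; every number is made up.
p_target = [[[0.8, 0.2], [0.3, 0.7]], [[0.6, 0.4], [0.1, 0.9]]]  # [s][a][s_next]
p_source = [[[0.5, 0.5], [0.5, 0.5]], [[0.9, 0.1], [0.2, 0.8]]]
d_target = [[0.1, 0.4], [0.3, 0.2]]      # p(s, a | target): state-action mix in the target buffer
d_source = [[0.25, 0.25], [0.25, 0.25]]  # p(s, a | source)
prior_target = 0.1                       # p(target): the target buffer is the small one
prior_source = 1.0 - prior_target        # p(source)


def classifier_sas(s, a, s_next):
    """Bayes-optimal p(target | s, a, s_next) for this joint distribution."""
    t = prior_target * d_target[s][a] * p_target[s][a][s_next]
    u = prior_source * d_source[s][a] * p_source[s][a][s_next]
    return t / (t + u)


def classifier_sa(s, a):
    """Bayes-optimal p(target | s, a)."""
    t = prior_target * d_target[s][a]
    u = prior_source * d_source[s][a]
    return t / (t + u)


def delta_r_from_classifiers(s, a, s_next):
    """The four-term formula, with p(source | .) = 1 - p(target | .)."""
    p_sas = classifier_sas(s, a, s_next)
    p_sa = classifier_sa(s, a)
    return (math.log(p_sas) - math.log(p_sa)
            - math.log(1.0 - p_sas) + math.log(1.0 - p_sa))


max_error = 0.0
for s, a, s_next in itertools.product(range(2), range(2), range(2)):
    exact = math.log(p_target[s][a][s_next] / p_source[s][a][s_next])
    max_error = max(max_error, abs(delta_r_from_classifiers(s, a, s_next) - exact))
print(max_error)  # only floating-point round-off remains
```

Neither the buffer sizes nor the state-action distributions leak into the result, exactly as the cancellations in the derivation predict.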
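So, given enough samples from both environments, taken from their replay buffers, we can estimate the ratio of our transition functions using nothing but two classifiers. Notice also that the priors \(p(\text{target})\) and \(p(\text{source})\) cancelled out in the derivation, so it doesn’t matter that the target buffer is much smaller than the source buffer, as long as both classifiers are trained on the same data.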
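Here is what that looks like from samples, as a sketch. With discrete states and actions, the maximum-likelihood “classifier” is just a ratio of counts, \(p(\text{target} \vert x) = n_{\text{target}}(x) / (n_{\text{target}}(x) + n_{\text{source}}(x))\), so each log-odds term becomes a log ratio of counts. The one-state, one-action toy and its dynamics are hypothetical, and the target buffer is deliberately ten times smaller than the source buffer.

```python
import math
import random
from collections import Counter

# Hypothetical one-state, one-action toy: only the dynamics differ.
P_TARGET = [0.8, 0.2]  # p_target(s_next | s=0, a=0)
P_SOURCE = [0.5, 0.5]  # p_source(s_next | s=0, a=0)


def sample_buffer(p_next, n, rng):
    """Collect n transitions (s=0, a=0, s_next) with s_next ~ p_next."""
    return [(0, 0, rng.choices([0, 1], weights=p_next)[0]) for _ in range(n)]


def estimate_delta_r(target_buffer, source_buffer, s, a, s_next):
    """Delta r from count-based (maximum-likelihood) domain classifiers.

    Each log-odds term reduces to a log ratio of counts. Assumes every count
    is positive; tuples missing from a buffer would need smoothing.
    """
    n_t_sas = Counter(target_buffer)[(s, a, s_next)]
    n_s_sas = Counter(source_buffer)[(s, a, s_next)]
    n_t_sa = sum(1 for x in target_buffer if x[:2] == (s, a))
    n_s_sa = sum(1 for x in source_buffer if x[:2] == (s, a))
    return math.log(n_t_sas / n_s_sas) - math.log(n_t_sa / n_s_sa)


rng = random.Random(0)
target_buffer = sample_buffer(P_TARGET, 2_000, rng)   # expensive, so small
source_buffer = sample_buffer(P_SOURCE, 20_000, rng)  # cheap, so large
for s_next in (0, 1):
    exact = math.log(P_TARGET[s_next] / P_SOURCE[s_next])
    print(s_next, estimate_delta_r(target_buffer, source_buffer, 0, 0, s_next), exact)
```

The 10x size imbalance shows up in both the \(\langle s_t, a_t, s_{t+1} \rangle\) counts and the \(\langle s_t, a_t \rangle\) counts, so it cancels in the difference.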
And I think that is absolutely NEAT-O!
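Equipped with this substitution for \(\Delta r\), Off-Dynamics Reinforcement Learning starts to look like a cinch: we train the two classifiers on transitions from both replay buffers, and let the agent maximize \(r + \Delta r\) (plus the entropy bonus) on inexpensive source-domain transitions.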
It also opens up the possibility of learning from different MDPs. Think cheap robot test beds!
Like I said before: neat-o.
References and Footnotes
1. There are a few distinctions to be made within Reinforcement Learning methods. When addressing the problem of Exploration vs Exploitation, one neat technique is off-policy reinforcement learning, where exploration of the state and action spaces is carried out by a policy called the behaviour policy, while the policy we really care about, called the target policy, is learned from that data (in Q-learning, for example, it is the greedy policy). Sutton and Barto is your friend.
2. Refer to Sutton and Barto’s Reinforcement Learning: An Introduction (second edition) for a great introduction to all things RL if you aren’t familiar with any terms in this section.
3. Benjamin Eysenbach et al. “Off-Dynamics Reinforcement Learning: Training for Transfer with Domain Classifiers”, ICML BIG Workshop.
4. Sergey Levine, “Reinforcement Learning and Control as Probabilistic Inference: Tutorial and Review”.
5. KL divergence can be tricky. The following post helped me a lot, and is an excellent resource: Dibya Ghosh on KL divergence in Machine Learning.
6. A density ratio is a cool way to compare probabilities. Probabilities, when left to themselves, are often bland and uninteresting. However, comparing probabilities lets us form judgements. Think about it. When we want to compare two numbers, we look at either their difference or their ratio. The same goes for comparing probability densities.
7. “Machine Learning Trick of the Day (7): Density Ratio Trick”, Shakir Mohamed’s excellent blog.
8. A simple way of thinking about ‘target’ or ‘source’ when seen as an event is to imagine a label attached to each tuple that indicates the MDP the tuple was pulled from (\(y=\)target). The notation used here drops the ‘\(y=\)’ part.