A look at the “Direct Preference Optimization: Your Language Model is Secretly a Reward Model” paper and its findings
This blog post was inspired by a discussion I recently had with some friends about the Direct Preference Optimization (DPO) paper. The discussion was lively and went over many important topics in LLMs and Machine Learning in general. Below is an expansion on some of those ideas and the concepts discussed in the paper.
Direct Preference Optimization (DPO) has become a popular way to fine-tune new foundation models. Famously, Mixtral 8x7B, the Sparse Mixture of Experts model created by Mistral, reached Llama 2 70B levels of performance with significantly fewer parameters, and its instruction-tuned variant was trained with DPO. Naturally, this success has led many in the community to begin fine-tuning their own models with DPO.
Let’s dive into what exactly DPO is and how we got here.
High Level Discussion
Let’s begin by setting out what fine-tuning should do from a high level. Once you have pre-trained a model to have strong generative capacities, you typically want to control its output somehow. Whether that means optimizing it to respond in dialogue as a chat-bot or to respond in code rather than English, the goal is to take an LLM that is already functional and find a way to be more selective with its output. As this is machine learning, the way we show it the right behavior is with data.
There are some key terms here I’ll define before we start diving into the technicals:
Loss Function: a function we use as a guide to optimize the performance of our model. It is chosen based on what has been found to be effective.
KL Divergence: stands for Kullback–Leibler divergence, which is a way to measure how much one probability distribution differs from another. For an LLM, the distributions are over tokens, so they are discrete. To learn more about this, there is a wonderful post by Aparna Dhinakaran on the topic.
Policy: an abstraction that describes how a neural network makes decisions. For an LLM, the policy is the probability it assigns to each possible completion given a prompt. Put a different way, if a neural network is trained 3 times, each run produces a different policy, and you can compare their performance.
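To make the KL divergence definition a little more concrete, here is a minimal sketch in plain Python for two discrete distributions. The function name and the two toy distributions are invented for illustration; they do not come from the paper.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as equal-length lists.

    Assumes q[i] > 0 wherever p[i] > 0; terms with p[i] == 0 contribute nothing.
    """
    if len(p) != len(q):
        raise ValueError("p and q must have the same number of outcomes")
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Toy next-token distributions over a 3-token vocabulary (made-up numbers).
reference = [0.5, 0.3, 0.2]
shifted = [0.4, 0.4, 0.2]

print(kl_divergence(reference, reference))  # identical distributions
print(kl_divergence(shifted, reference))    # a small shift away from the reference
print(kl_divergence(reference, shifted))    # KL is not symmetric
```

Identical distributions give a divergence of zero, and the value grows as the two distributions move apart, which is what makes it useful as a penalty for drifting away from a reference model.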
The Status Quo before DPO (PPO)
Before DPO, we had to train an entirely separate model to help us fine-tune, typically called the reward model. We would sample completions from our LLM and then have the reward model give us a score for each completion. The idea here was simple. Humans are expensive to have evaluate your LLM's outputs, but the quality of your LLM will ultimately be determined by humans. To keep costs down and quality high, you would train the reward model to approximate the humans' feedback. This overall approach is called Reinforcement Learning from Human Feedback (RLHF). The reinforcement learning algorithm typically used for the update step is Proximal Policy Optimization (PPO), so named because it keeps each policy update close to the previous policy. The approach lives or dies based on the strength of your reward model.
The Math behind PPO
To find the ideal reward model, we assume human preferences are more probabilistic than deterministic, so we can represent this symbolically in the Bradley-Terry model like below.
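For reference, this is equation 1 as written in the DPO paper [1]; the symbol ≻ reads as "is preferred over":

```latex
p^*(y_1 \succ y_2 \mid x) =
  \frac{\exp\left(r^*(x, y_1)\right)}
       {\exp\left(r^*(x, y_1)\right) + \exp\left(r^*(x, y_2)\right)}
\tag{1}
```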
Going variable by variable, p* is the true probability distribution of human preferences, the one the model should treat as the source of truth. y₁ and y₂ are 2 completions from the model that we are going to compare, and x is the prompt given to the LLM. r* is the optimal reward function: the further r* scores y₁ above y₂, the more likely it is that humans prefer y₁. Put another way, to train the model to approximate the optimal probability distribution, you give it the rewards from the optimal reward function.
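As a rough illustration of the Bradley-Terry model, here is a minimal Python sketch. The function name and the reward values are made up for this example; the only thing taken from the paper is the formula, which simplifies to a sigmoid of the difference between the two rewards.

```python
import math

def bradley_terry_prob(reward_1, reward_2):
    """Probability that completion 1 is preferred over completion 2.

    exp(r1) / (exp(r1) + exp(r2)) simplifies to 1 / (1 + exp(r2 - r1)).
    """
    return 1.0 / (1.0 + math.exp(reward_2 - reward_1))

# Hypothetical rewards for two completions of the same prompt.
print(bradley_terry_prob(2.0, 1.0))    # completion 1 scored higher
print(bradley_terry_prob(1.0, 1.0))    # equal scores, no preference
print(bradley_terry_prob(12.0, 11.0))  # same gap as the first call
```

Only the gap between the two rewards matters, so adding a constant to both scores leaves the preference probability unchanged. That property becomes important later in the DPO derivation.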
Nevertheless, the perfect probability distribution of human preference is difficult, if not impossible, to know. For this reason, we focus on the reward model, so we need to find a way to estimate r*. In machine learning, we often use loss minimization to estimate quantities we cannot compute directly. If we have access to training data that shows us what human preferences truly are, and can therefore be treated as samples drawn from the p* distribution, then we can use those samples to train the reward model like below:
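For reference, the reward model loss (equation 2 in the DPO paper [1]) is the negative log-likelihood of the observed preferences, with σ denoting the logistic (sigmoid) function:

```latex
\mathcal{L}_R(r_\phi, \mathcal{D}) =
  -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}
  \left[ \log \sigma\left( r_\phi(x, y_w) - r_\phi(x, y_l) \right) \right]
\tag{2}
```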
Here rϕ is the reward model we are training, D is the set of samples we are training on, yw is the preferred completion and yl is the dispreferred completion. The authors have chosen to frame the problem as a binary classification problem. We will see why later on, but for now just remember this is why we have yw and yl.
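A minimal sketch of this loss in plain Python might look like the following. The scores are invented stand-ins for the outputs of a hypothetical reward model on (preferred, dispreferred) pairs; a real implementation would compute them with a neural network and backpropagate through the loss.

```python
import math

def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))

def reward_model_loss(score_pairs):
    """Average negative log-likelihood over (score_preferred, score_dispreferred) pairs."""
    total = sum(log_sigmoid(s_w - s_l) for s_w, s_l in score_pairs)
    return -total / len(score_pairs)

# Made-up scores: one model ranks the pairs correctly, the other does not.
good_ranking = [(2.0, -1.0), (0.5, 0.0)]
bad_ranking = [(-1.0, 2.0), (0.0, 0.5)]

print(reward_model_loss(good_ranking))
print(reward_model_loss(bad_ranking))
```

The loss is small when the preferred completion is scored above the dispreferred one and large otherwise, which is exactly the behavior of a binary classifier on the pair.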
Once we have optimized our reward model, we use it to fine-tune the LLM: the new policy (π θ) is trained to maximize the reward, while a KL divergence penalty keeps it close to the old policy (π ref). Importantly, this KL term prevents the model from shifting too much.
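For reference, this KL-constrained objective is equation 3 in the DPO paper [1], where β controls how strongly the new policy is penalized for drifting away from the reference policy:

```latex
\max_{\pi_\theta} \;
  \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}
  \left[ r_\phi(x, y) \right]
  - \beta \, \mathbb{D}_{\mathrm{KL}}
  \left[ \pi_\theta(y \mid x) \,\|\, \pi_{\mathrm{ref}}(y \mid x) \right]
\tag{3}
```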
Why don’t we want it shifting too much? Remember the model is already mostly functional, and it has taken quite a lot of compute resources to reach this level. Consequently, we want to make sure the model retains many of the good traits it currently has while we focus on having it follow instructions better.
While the above methodology is effective (LLaMa2, for instance, was fine-tuned this way), it has one major weakness: it requires training an entirely separate model, which is costly and requires huge amounts of additional data.
How does DPO improve on this?
DPO removes the need for the reward model altogether! This allows us to avoid training a costly separate reward model, and incidentally, DPO has been found to need a lot less data to work as well as PPO.
The Math behind DPO
The major leap stems from the KL constraint we placed on ourselves in equation 3. By adding this constraint, we can actually derive the ideal policy that maximizes the KL-constrained reward objective. The algebra is shown below:
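A condensed version of that algebra, following Appendix A.1 of the DPO paper [1], is sketched here. Z(x) is the partition function that normalizes the result into a valid probability distribution; it depends on the prompt x but not on the policy being optimized:

```latex
\begin{align*}
& \max_{\pi} \; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi(y \mid x)} \left[ r(x, y) \right]
  - \beta \, \mathbb{D}_{\mathrm{KL}} \left[ \pi(y \mid x) \,\|\, \pi_{\mathrm{ref}}(y \mid x) \right] \\
&= \max_{\pi} \; \mathbb{E}_{x \sim \mathcal{D}} \, \mathbb{E}_{y \sim \pi(y \mid x)}
  \left[ r(x, y) - \beta \log \frac{\pi(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} \right] \\
&= \min_{\pi} \; \mathbb{E}_{x \sim \mathcal{D}} \, \mathbb{E}_{y \sim \pi(y \mid x)}
  \left[ \log \frac{\pi(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} - \frac{1}{\beta} r(x, y) \right] \\
&= \min_{\pi} \; \mathbb{E}_{x \sim \mathcal{D}} \, \mathbb{E}_{y \sim \pi(y \mid x)}
  \left[ \log \frac{\pi(y \mid x)}
  {\frac{1}{Z(x)} \pi_{\mathrm{ref}}(y \mid x) \exp\left( \frac{1}{\beta} r(x, y) \right)}
  - \log Z(x) \right],
\end{align*}
\qquad
Z(x) = \sum_{y} \pi_{\mathrm{ref}}(y \mid x) \exp\left( \frac{1}{\beta} r(x, y) \right)
```

The last line is a KL divergence between π and the distribution in the denominator (plus a term that does not depend on π), so it is minimized when π equals that distribution.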
For our purposes, the most important point to take away is that we now have the below equation for a policy π r, such that the reward function r is easily solved for.
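For reference, this is equation 4 in the DPO paper [1], where Z(x) is the partition function that sums the numerator over all possible completions y:

```latex
\pi_r(y \mid x) =
  \frac{1}{Z(x)} \, \pi_{\mathrm{ref}}(y \mid x)
  \exp\left( \frac{1}{\beta} r(x, y) \right)
\tag{4}
```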
Naturally, we immediately solve for r:
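Taking the logarithm of the optimal policy and rearranging gives equation 5 in the DPO paper [1], with Z(x) again denoting the partition function:

```latex
r(x, y) =
  \beta \log \frac{\pi_r(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}
  + \beta \log Z(x)
\tag{5}
```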
Returning to our ideal probability distribution equation (equation 1), we can rewrite that so that each instance of r is replaced by equation 5.
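For reference, the result is equation 6 in the DPO paper [1]. Because the Bradley-Terry model only depends on the difference between two rewards, the β log Z(x) term appears in both rewards and cancels, leaving an expression purely in terms of the optimal policy π* and the reference policy:

```latex
p^*(y_1 \succ y_2 \mid x) =
  \frac{1}{1 + \exp\left(
    \beta \log \frac{\pi^*(y_2 \mid x)}{\pi_{\mathrm{ref}}(y_2 \mid x)}
    - \beta \log \frac{\pi^*(y_1 \mid x)}{\pi_{\mathrm{ref}}(y_1 \mid x)}
  \right)}
\tag{6}
```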
What this has shown is that you don’t need the reward model to optimize the policy to follow the ideal probability distribution of human preferences. Instead, you can work directly on the policy to improve it (hence the name Direct Preference Optimization). We are using the probabilities that your LLM generates for each token to help it fine-tune itself.
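This claim can be sanity-checked numerically on a toy problem. The sketch below assumes a made-up world with only three possible completions, invented rewards, an invented reference policy, and an arbitrary β. It builds the optimal policy from the rewards, then shows that the preference probability computed from policy log-ratios alone matches the one computed from the rewards.

```python
import math

beta = 0.5  # arbitrary illustrative value
rewards = {"a": 2.0, "b": 1.0, "c": -1.0}  # hypothetical rewards per completion
pi_ref = {"a": 0.2, "b": 0.5, "c": 0.3}    # hypothetical reference policy

# Optimal policy for the KL-constrained objective: pi_ref * exp(r / beta) / Z.
Z = sum(pi_ref[y] * math.exp(rewards[y] / beta) for y in rewards)
pi_opt = {y: pi_ref[y] * math.exp(rewards[y] / beta) / Z for y in rewards}

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def implicit_reward(y):
    """beta * log(pi_opt / pi_ref): the true reward minus the constant beta * log(Z)."""
    return beta * math.log(pi_opt[y] / pi_ref[y])

# Preference probability from the true rewards vs. from the policies alone.
p_from_rewards = sigmoid(rewards["a"] - rewards["b"])
p_from_policy = sigmoid(implicit_reward("a") - implicit_reward("b"))

print(p_from_rewards, p_from_policy)
```

The two probabilities agree because the constant involving Z is subtracted from both implicit rewards and drops out of the difference.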
To finish the derivation, we apply the same maximum likelihood approach that gave us the reward model loss in equation 2 to come up with the loss function for the policy.
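For reference, the resulting DPO loss is equation 7 in the DPO paper [1], with σ denoting the logistic (sigmoid) function:

```latex
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}}) =
  -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}
  \left[ \log \sigma\left(
    \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
    - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
  \right) \right]
\tag{7}
```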
That was a lot of algebra, but equation 7 is the most important one to understand, so I’ll break down the most important pieces. We now have an equation which compares the probabilities assigned by the old policy (π ref) and the new policy (π θ) to a winning completion (yw) and a losing completion (yl). When we compare these, we are optimizing so that the new policy raises the probability of yw, relative to the old policy, by more than it raises the probability of yl. This means the policy is getting better at giving winning responses than losing responses.
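The loss for a single preference pair can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the per-token log-probabilities are invented numbers, the helper names are hypothetical, and β = 0.1 is just a placeholder value. A real training loop would obtain the log-probabilities from the two models and differentiate through the loss.

```python
import math

def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))

def sequence_logprob(token_logprobs):
    """Log-probability of a completion: the sum of its per-token log-probabilities."""
    return sum(token_logprobs)

def dpo_loss(policy_w, policy_l, ref_w, ref_l, beta=0.1):
    """DPO loss for one pair, given sequence log-probs under the policy and the reference."""
    winner_logratio = policy_w - ref_w  # log(pi_theta(y_w|x) / pi_ref(y_w|x))
    loser_logratio = policy_l - ref_l   # log(pi_theta(y_l|x) / pi_ref(y_l|x))
    return -log_sigmoid(beta * (winner_logratio - loser_logratio))

# Made-up per-token log-probs under the reference model.
ref_w = sequence_logprob([-1.0, -2.0, -0.5])
ref_l = sequence_logprob([-1.5, -1.0, -1.0])

# Before training, the policy equals the reference, so both log-ratios are zero.
loss_start = dpo_loss(ref_w, ref_l, ref_w, ref_l)

# A policy that made the winner more likely and the loser less likely.
loss_better = dpo_loss(ref_w + 2.0, ref_l - 2.0, ref_w, ref_l)

print(loss_start, loss_better)
```

The loss starts at log 2 when the policy and the reference agree, and it falls as the policy shifts probability toward the winning completion and away from the losing one.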
Consequences
First, DPO does not require a reward model! You simply need high-quality data so that the model has a clear direction of what is good and bad, and it will improve.
Second, DPO adapts directly to new data. New preference pairs go straight into the policy's loss, so the model can be updated on them right away. Compared to PPO, where you have to retrain your reward model each time you have new data, this is a big win.
Third, DPO allows you to train a model to avoid certain topics just as much as it learns to give good answers for others. One way to conceptualize the new loss equation is as a signal that points our training in the right direction. By using both a good and a bad example, we are teaching the model to avoid certain responses as much as we tell it to go towards others. As a large part of fine-tuning involves the model ignoring certain subjects, this feature is very valuable.
Closing Thoughts
Understanding the consequences of DPO’s math makes me more optimistic about the future of LLMs.
DPO requires less data and compute than PPO, both of which are major contributors to the cost of making your own model. With this cost reduction, more people will be able to fine-tune their own models, potentially giving society access to more specialized LLMs.
Moreover, as DPO trains the policy directly on pairs of good and bad examples, rather than through the intermediate scores of a separate reward model, it is well suited to restricting behavior. This means that LLMs can be made far safer, another piece that will allow them to help out society.
With forces like DPO giving us access to better quality LLMs that can be more easily trained, it is an incredibly exciting time for this field.
[1] R. Rafailov, et al., Direct Preference Optimization: Your Language Model is Secretly a Reward Model (2023), arXiv
[2] A. Jiang, et al., Mixtral of Experts (2024), arXiv
Understanding Direct Preference Optimization was originally published in Towards Data Science on Medium, where people are continuing the conversation by highlighting and responding to this story.