Understanding AI Applications in Bio for Machine Learning Engineers
BERT (released in 2018) and AlphaFold 2 (unveiled in 2020) were both developed in the cradle of Google’s deeply lined pockets, albeit by different departments: Google AI and DeepMind. They represented huge leaps forward in state-of-the-art models for natural language processing (NLP) and biology, respectively. For BERT, this meant topping the leaderboard on benchmarks like GLUE (General Language Understanding Evaluation) and SQuAD (Stanford Question Answering Dataset). For AlphaFold 2 (hereafter just referred to as AlphaFold), it meant achieving near-experimental accuracy in predicting 3D protein structures. In both cases, these advancements were largely attributed to the transformer architecture and its self-attention mechanism.
I expect most machine learning engineers have a cursory understanding of how BERT (Bidirectional Encoder Representations from Transformers) works with language, but only a vague, metaphorical understanding of how the same architecture is applied to the field of biology. The purpose of this article is to explain the concepts behind the development and success of AlphaFold through the lens of how they compare and contrast with BERT.
Forewarning: I am a machine learning engineer and not a biologist, just a curious person.
BERT Primer
Before diving into protein folding, let’s refresh our understanding of BERT. At a high level, BERT is trained by masked token prediction and next-sentence prediction.
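To make masked-token prediction concrete, here is a minimal sketch using the Hugging Face transformers library (my choice for illustration; it is not what the original BERT release shipped with): we mask a word and ask a pretrained BERT to fill it in.

```python
# Minimal sketch of masked-token prediction, assuming the Hugging Face
# `transformers` library is installed (pip install transformers torch).
from transformers import pipeline

# Load a pretrained BERT behind a fill-mask pipeline.
unmasker = pipeline("fill-mask", model="bert-base-uncased")

# BERT ranks candidate tokens for the [MASK] position using both the
# left and right context of the sentence.
for prediction in unmasker("I need to deposit money in the [MASK]."):
    print(prediction["token_str"], round(prediction["score"], 3))
```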
BERT falls into the sequence model family. Sequence models are a class of machine learning models designed to handle and make sense of sequential data, where the order of the elements matters. Members of the family include Recurrent Neural Networks (RNNs), Long Short-Term Memory networks (LSTMs), and Transformers. As a Transformer model (like its more famous relative, GPT), a key unlock for BERT was that training could be parallelized. RNNs and LSTMs process sequences one step at a time, which slows down training and limits the applicable hardware. Transformer models use the self-attention mechanism, which processes the entire sequence in parallel and allows training to leverage modern GPUs and TPUs, which are optimized for parallel computing.
Processing the entire sequence at once not only decreased training time but also improved embeddings by modeling the contextual relationships between words. This allows the model to better understand dependencies, regardless of their position in the sequence. A classic example illustrates this concept: “I went fishing by the river bank” and “I need to deposit money in the bank.” To readers, bank clearly represents two distinct concepts, but previous models struggled to differentiate them. The self-attention mechanism in transformers enables the model to capture these nuanced differences. For a deeper dive into this topic, I recommend watching this Illustrated Guide to Transformers Neural Network: A step by step explanation.
One reason RNNs and LSTMs struggle is that they are unidirectional, i.e., they process a sentence from left to right. So if the sentence were rewritten as “At the bank, I need to deposit money,” the word “money” would no longer clarify the meaning of “bank.” The self-attention mechanism eliminates this fragility by allowing each word in the sentence to “attend” to every other word, both before and after it, making the model “bidirectional.”
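To ground the idea of every word attending to every other word, here is a stripped-down sketch of scaled dot-product self-attention (single head, no learned query/key/value projections, so it omits most of what a real transformer layer adds):

```python
import numpy as np

def self_attention(X):
    """Single-head scaled dot-product attention with no learned weights.

    X: (seq_len, d) array, one embedding per token. Each output row mixes
    information from every position, before and after it, which is what
    makes the resulting representation bidirectional.
    """
    d = X.shape[-1]
    scores = X @ X.T / np.sqrt(d)                    # token-to-token similarity
    scores -= scores.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # softmax over the whole sequence
    return weights @ X                               # context-mixed embeddings

# Example: 7 tokens with 16-dimensional embeddings.
tokens = np.random.default_rng(0).normal(size=(7, 16))
print(self_attention(tokens).shape)  # (7, 16)
```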
AlphaFold and BERT Comparison
Now that we’ve reviewed the basics of BERT, let’s compare it to AlphaFold. Like BERT, AlphaFold is a sequence model. However, instead of processing words in sentences, AlphaFold’s inputs are amino acid sequences and multiple sequence alignments (MSAs), and its output/prediction is the 3D structure of the protein.
Let’s review what these inputs and outputs are before learning more about how they are modeled.
First input: amino acid sequences
Amino acid sequences are embedded into high-dimensional vectors, similar to how text is embedded in language models like BERT.
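As a toy illustration (the names and the encoding here are my own, and AlphaFold’s real featurization is richer), an amino acid sequence can be turned into integer ids over the 20 standard residues and then into an embedding matrix, much like token embeddings in BERT:

```python
import numpy as np

# The 20 standard amino acids, each mapped to an integer id.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
AA_TO_ID = {aa: i for i, aa in enumerate(AMINO_ACIDS)}

def embed_sequence(sequence, embedding_table):
    """Look up one vector per residue, analogous to token embeddings in BERT."""
    ids = np.array([AA_TO_ID[aa] for aa in sequence])
    return embedding_table[ids]                    # shape: (sequence_length, embedding_dim)

# A randomly initialized table standing in for learned embeddings.
table = np.random.default_rng(0).normal(size=(len(AMINO_ACIDS), 64))
print(embed_sequence("MKTAYIAK", table).shape)     # (8, 64)
```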
Reminder from your high school biology class: the specific sequence of amino acids that make up a protein is determined by mRNA, which is transcribed from the instructions in DNA. As the amino acids are linked together, they interact with one another through various chemical bonds and forces, causing the protein to fold into a unique three-dimensional structure. This folded structure is crucial for the protein’s function, as its shape determines how it interacts with other molecules and performs its biological roles. Because the 3D structure is so important for determining the protein’s function, “protein folding” has been a central research problem for the last half-century.
Before AlphaFold, the only reliable way to determine how an amino acid sequence would fold was experimental validation through techniques like X-ray crystallography, nuclear magnetic resonance (NMR) spectroscopy, and cryo-electron microscopy (cryo-EM). Though accurate, these methods are time-consuming, labor-intensive, and expensive.
So what is an MSA (multiple sequence alignment) and why is it another input into the model?
Second input: Multiple sequence alignments, represented as matrices in the model.
Amino acid sequences contain the necessary instructions to build a protein, but they also include less important or more variable regions. Comparing this to language, I think of these less important regions as the “stop words” of protein folding instructions. To determine which regions of the sequence are the analogous stop words, MSAs are constructed from homologous (evolutionarily related) sequences of proteins with similar functions, arranged in a matrix where the target sequence is the first row.
Similar regions of the sequences are thought to be “evolutionarily conserved” (parts of the sequence that stay the same). Regions that are highly conserved across species tend to be structurally or functionally important (like active sites in enzymes). My imperfect metaphor here is to think about lining up sentences from Romance languages to identify shared important words. However, this metaphor doesn’t fully explain why MSAs are so important for predicting the 3D structure. Conserved regions are so critical because they allow us to detect co-evolution between amino acids. If two residues tend to mutate in a coordinated way across different sequences, it often means they are physically close in the 3D structure and interact with each other to maintain protein stability. This kind of evolutionary relationship is difficult to infer from a single amino acid sequence but becomes clear when analyzing an MSA.
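As a rough illustration of what an MSA looks like as a matrix and how conservation can be read off its columns (a toy example with made-up sequences, not AlphaFold’s actual featurization, which learns from the MSA directly):

```python
import numpy as np
from collections import Counter

# A toy MSA: each row is a homologous sequence, the first row is the target.
# The sequences below are invented purely for illustration.
msa = np.array([list(s) for s in [
    "MKTAYIA",   # target sequence
    "MKSAYIA",
    "MRTAYLA",
    "MKTGYIA",
]])

def column_conservation(msa):
    """Fraction of sequences sharing the most common residue in each column."""
    n_seqs, n_cols = msa.shape
    return np.array([Counter(msa[:, j]).most_common(1)[0][1] / n_seqs
                     for j in range(n_cols)])

print(column_conservation(msa))   # columns near 1.0 are highly conserved
```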
Here is another place where the comparison between natural language processing and protein folding diverges: MSAs must be constructed, and researchers often manually curate them for optimal results. Biologists use tools like BLAST (Basic Local Alignment Search Tool) to search databases with their target sequence and find “homologs,” or similar sequences. If you’re studying humans, this could mean finding sequences from other mammals, vertebrates, or more distant organisms. The sequences are then manually selected, considering things like comparable lengths and similar functions, because including too many sequences with divergent functions degrades the quality of the MSA. This is a HUGE difference from how training data is collected for natural language models, which are trained on huge swaths of data hoovered up from anywhere and everywhere. Biology models, by contrast, need highly skilled and conscientious dataset composers.
What is being predicted/output?
In BERT, the prediction target is the masked token or the next-sentence relationship. For AlphaFold, the target is the 3D structure of the protein, represented as the 3D coordinates of its atoms, which define the spatial arrangement of amino acids in the folded protein. Each set of 3D coordinates is collected experimentally, reviewed, and stored in the Protein Data Bank (PDB). Recently solved structures serve as a validation set for evaluation.
How are the inputs and outputs tied together?
Both the target sequence and the MSA are processed through a series of transformer blocks, using the self-attention mechanism to generate embeddings. The MSA embedding captures evolutionary relationships, while the target sequence embedding captures local context. These contextual embeddings are then fed into downstream layers to predict pairwise interactions between amino acids, ultimately inferring the protein’s 3D structure.
Within each sequence, the pairwise residue representation (the relationship or interaction between two amino acids within a protein sequence) predicts spatial distances and orientations between residues, which are critical for modeling how distant parts of the protein come into proximity when folded. The self-attention mechanism allows the model to account for both local and long-range dependencies within the sequence and the MSA. This is important because, when a sequence is folded, residues that are far apart in the sequence may end up close to each other spatially.
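A crude sketch of how per-residue embeddings can be expanded into a pairwise representation, from which a downstream head could predict inter-residue distances (loosely inspired by AlphaFold’s pair representation; the real model builds and updates it with far more machinery):

```python
import numpy as np

def pairwise_features(residue_embeddings):
    """Build an (L, L, 2d) pair representation by pairing the embedding of
    residue i with the embedding of residue j for every (i, j)."""
    L, d = residue_embeddings.shape
    left = np.repeat(residue_embeddings[:, None, :], L, axis=1)    # (L, L, d)
    right = np.repeat(residue_embeddings[None, :, :], L, axis=0)   # (L, L, d)
    return np.concatenate([left, right], axis=-1)                  # (L, L, 2d)

# A downstream head (not shown) would map each (i, j) feature to a predicted
# distance, or distance distribution, between residues i and j.
residues = np.random.default_rng(0).normal(size=(8, 32))
print(pairwise_features(residues).shape)  # (8, 8, 64)
```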
The loss function for AlphaFold is considerably more complex than BERT’s. BERT faces no spatial or geometric constraints, and its loss function is much simpler because it only needs to predict missing words or sentence relationships. In contrast, AlphaFold’s loss function involves multiple aspects of protein structure (distance distributions, torsion angles, 3D coordinates, etc.), and the model optimizes for both geometric and spatial predictions. This component-heavy loss function ensures that AlphaFold accurately captures the physical properties and interactions that define the protein’s final structure.
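For intuition only, here is a toy multi-term loss combining a distance term and a torsion-angle term. It is not AlphaFold’s actual loss (which centers on FAPE plus several auxiliary heads); it just illustrates how geometric components can be weighted together:

```python
import numpy as np

def toy_structure_loss(pred_dist, true_dist, pred_angles, true_angles,
                       w_dist=1.0, w_angle=0.5):
    """Illustrative composite loss: squared error on pairwise distances plus a
    periodicity-aware error on torsion angles, combined with fixed weights."""
    dist_term = np.mean((pred_dist - true_dist) ** 2)
    angle_term = np.mean(1.0 - np.cos(pred_angles - true_angles))  # angles live on a circle
    return w_dist * dist_term + w_angle * angle_term
```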
While there is essentially no meaningful post-processing required for BERT’s predictions, AlphaFold’s predicted 3D coordinates undergo energy minimization and geometric refinement based on the physical principles of proteins. These steps ensure that predicted structures are physically viable and biologically functional.
Conclusion
- AlphaFold and BERT both benefit from the transformer architecture and the self-attention mechanism, which improve contextual embeddings and enable faster training on GPUs and TPUs.
- AlphaFold has a much more complex data preparation process than BERT. Curating MSAs from experimentally derived data is harder than vacuuming up a large corpus of text!
- AlphaFold’s loss function must account for spatial and geometric constraints, making it much more complex than BERT’s.
- AlphaFold predictions require post-processing to confirm that the predicted structure is physically viable, whereas BERT predictions do not.
Thank you for reading this far! I’m a big believer in cross-functional learning and I believe as machine learning engineers we can learn more by challenging ourselves to learn outside our immediate domains. I hope to continue this series on Understanding AI Applications in Bio for Machine Learning Engineers throughout my maternity leave. ❤