Hit the road to super-fast AI/ML development
One of the most critical decisions you will need to make in the development of AI models is the choice of a machine learning development framework. Over the years, many libraries have vied for the lucrative title of “AI developer’s framework of choice”. (Remember Caffe and Theano?) For several years TensorFlow — with its emphasis on high-performing, graph-based computation — appeared to be the runaway leader (as estimated by the author based on mentions in academic papers and the strength of community support). Roughly around the turn of the decade, PyTorch — with its user-friendly Pythonic interface — seemed to have become the undisputed queen. However, in recent years a new entrant has quickly grown in popularity and can no longer be ignored. With its sights on the coveted crown, JAX aims to maximize the performance of AI model training and inference without compromising the user experience.
In this post we will assess this new framework, demonstrate its use, and share some of our own perspectives on its advantages and drawbacks. Importantly, this post is not intended to be a JAX tutorial. To learn about JAX you are kindly referred to the official documentation and the many online tutorials on ML development with JAX (e.g., here). Although our focus will be on AI model training, it should be noted that JAX has many additional applications in the AI/ML landscape and beyond. There are several high-level ML libraries built on top of JAX. In this post we will use Flax which, as of the time of this writing, appears to be the most popular.
Thanks to Ohad Klein and Yitzhak Levi for their contributions to this post.
JAX Under the Hood — XLA Compilation
Let’s get this out in the open straight away: no disrespect to JAX, but much of its real power comes from its use of XLA compilation. The phenomenal runtime performance demonstrated with JAX comes from the hardware-specific optimizations enabled by XLA, and many of the features and functionalities often associated with JAX, such as just-in-time (JIT) compilation and the “functional programming” paradigm, are actually derived from XLA. In fact, XLA compilation is hardly unique to JAX; both TensorFlow and PyTorch support options for using XLA. However, contrary to other popular frameworks, JAX was designed from the ground up to use XLA. This allows the design and implementation of its JIT, automatic differentiation (grad), vectorization (vmap), parallelization (pmap), sharding (shard_map), and other features (all of which deserve very much respect) to be tightly coupled with the underlying XLA library. (For contrast, see this interesting post for a history of the “functionalization” of PyTorch.)
As discussed in a previous post on the topic, the XLA JIT compiler performs a full analysis of the computation graph associated with the model, fuses successive tensor operations into single kernels, removes redundant graph components, and outputs machine code that is optimized for the underlying accelerator. This results in fewer overall machine-level operations (FLOPs) per training step, reduced host-to-accelerator communication overhead (e.g., fewer kernels that need to be loaded onto the accelerator), a reduced memory footprint, increased utilization of the dedicated accelerator engines, and more.
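To make the programming model concrete, here is a minimal sketch (our own illustration, not part of the benchmark below) that JIT-compiles the gradient of a simple function; XLA traces the computation and fuses the chain of tensor operations into a small number of device kernels:

import jax
import jax.numpy as jnp

# a toy loss composed of several successive tensor operations;
# XLA fuses these into a small number of device kernels
def loss_fn(w, x, y):
    pred = jnp.tanh(x @ w)
    return jnp.mean((pred - y) ** 2)

# compose transformations: differentiate, then JIT-compile
grad_fn = jax.jit(jax.grad(loss_fn))

w = jnp.zeros((8, 1))
x = jnp.ones((16, 8))
y = jnp.ones((16, 1))
grads = grad_fn(w, x, y)  # first call triggers tracing and XLA compilation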
In addition to the runtime performance optimization, another important feature of XLA is its pluggable infrastructure, which enables extending its support to additional AI accelerators. XLA is a part of the OpenXLA project and is being developed collaboratively by multiple players in the field of ML.
At the same time, as detailed in our previous post, the reliance on XLA also implies some limitations and potential pitfalls. In particular, many AI models, including ones with dynamic tensor shapes, may not run optimally in XLA. Special care needs to be taken to avoid graph breaks and graph recompilations. You should also consider the implications on the debuggability of your code.
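The following minimal sketch (again, our own toy example) illustrates the recompilation pitfall: a JIT-compiled function is specialized to the shapes of its inputs, so every new input shape triggers a fresh trace and XLA compilation:

import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    return (x - x.mean()) / (x.std() + 1e-6)

normalize(jnp.ones((32, 128)))  # traced and compiled for shape (32, 128)
normalize(jnp.ones((32, 128)))  # reuses the cached executable
normalize(jnp.ones((32, 131)))  # new shape -> recompilation
# models with dynamic tensor shapes (e.g., variable-length sequences)
# can end up recompiling on many steps, erasing the runtime gains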
JAX In Action — Toy Example
In this section we will demonstrate how to train a toy AI model in JAX on a (single) GPU and compare it with PyTorch. Nowadays there are a number of high-level ML development platforms that include backends for multiple ML frameworks. This allows for comparing the performance of JAX with other frameworks. In this section we will use HuggingFace’s Transformers library, which includes PyTorch and JAX implementations of many common Transformer-backed models. More specifically, we will define a Vision Transformer (ViT) backed classification model using the ViTForImageClassification and FlaxViTForImageClassification modules for the PyTorch and JAX implementations, respectively. The code block below contains the model definition:
import torch
import jax, flax, optax
import jax.numpy as jnp


def get_model(use_jax=False):
    from transformers import ViTConfig

    if use_jax:
        from transformers import FlaxViTForImageClassification as ViTModel
    else:
        from transformers import ViTForImageClassification as ViTModel

    vit_config = ViTConfig(
        num_labels=1000,
        _attn_implementation='eager'  # this disables flash attention
    )
    return ViTModel(vit_config)
Note that we have chosen to disable the use of flash attention because this optimization is implemented for the PyTorch model only (as of the time of this writing).
Since our interest in this post is in runtime performance, we will train our model on a randomly generated dataset. We take advantage of the fact that JAX supports the use of PyTorch dataloaders:
def get_data_loader(batch_size, use_jax=False):
    from torch.utils.data import Dataset, DataLoader, default_collate

    # create a dataset of random image and label data
    class FakeDataset(Dataset):
        def __len__(self):
            return 1000000

        def __getitem__(self, index):
            if use_jax:  # use NHWC format
                rand_image = torch.randn([224, 224, 3], dtype=torch.float32)
            else:  # use NCHW format
                rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
            label = torch.tensor(data=[index % 1000], dtype=torch.int64)
            return rand_image, label

    if use_jax:  # convert torch tensors to numpy arrays
        def numpy_collate(batch):
            from jax.tree_util import tree_map
            import jax.numpy as jnp
            return tree_map(jnp.asarray, default_collate(batch))
        collate_fn = numpy_collate
    else:
        collate_fn = default_collate

    ds = FakeDataset()
    dl = DataLoader(ds, batch_size=batch_size,
                    collate_fn=collate_fn)
    return dl
Next, we define our PyTorch and JAX training loops. The JAX training loop relies on a Flax TrainState object and its definition follows the basic tutorial for training ML models in Flax:
@jax.jit
def train_step_jax(train_state, batch):
    with jax.default_matmul_precision('tensorfloat32'):
        def forward(params):
            logits = train_state.apply_fn({'params': params}, batch[0])
            # use integer class labels, matching the CrossEntropyLoss
            # used in the PyTorch training step below
            loss = optax.softmax_cross_entropy_with_integer_labels(
                logits=logits.logits, labels=batch[1].squeeze(-1)).mean()
            return loss

        grad_fn = jax.grad(forward)
        grads = grad_fn(train_state.params)
        train_state = train_state.apply_gradients(grads=grads)
    return train_state


def train_step_torch(batch, model, optimizer, loss_fn, device):
    inputs = batch[0].to(device=device, non_blocking=True)
    label = batch[1].squeeze(-1).to(device=device, non_blocking=True)
    outputs = model(inputs)
    loss = loss_fn(outputs.logits, label)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
Let’s now put everything together. In the script below we have included controls for applying the graph-based JIT compilation options of PyTorch: torch.compile and torch_xla:
def train(batch_size, mode, compile_model):
    print(f"Mode: {mode}\n"
          f"Batch size: {batch_size}\n"
          f"Compile model: {compile_model}")

    # init model and data loader
    use_jax = mode == 'jax'
    use_torch_xla = mode == 'torch_xla'
    model = get_model(use_jax)
    train_loader = get_data_loader(batch_size, use_jax)

    if use_jax:
        # init jax settings
        from flax.training import train_state
        params = model.module.init(jax.random.key(0),
                                   jnp.ones([1, 224, 224, 3]))['params']
        optimizer = optax.sgd(learning_rate=1e-3)
        state = train_state.TrainState.create(apply_fn=model.module.apply,
                                              params=params, tx=optimizer)
    else:
        if use_torch_xla:
            import torch_xla
            import torch_xla.core.xla_model as xm
            import torch_xla.distributed.parallel_loader as pl
            torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
                use_full_mat_mul_precision=False)

            device = xm.xla_device()
            backend = 'openxla'

            # wrap data loader
            train_loader = pl.MpDeviceLoader(train_loader, device)
        else:
            device = torch.device('cuda')
            backend = 'inductor'

        model = model.to(device)
        if compile_model:
            model = torch.compile(model, backend=backend)
        model.train()
        optimizer = torch.optim.SGD(model.parameters())
        loss_fn = torch.nn.CrossEntropyLoss()

    import time
    t0 = time.perf_counter()
    summ = 0
    count = 0

    for step, data in enumerate(train_loader):
        if use_jax:
            state = train_step_jax(state, data)
        else:
            train_step_torch(data, model, optimizer, loss_fn, device)

        # capture step time
        batch_time = time.perf_counter() - t0
        if step > 10:  # skip first steps
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step > 50:
            break

    print(f'average step time: {summ / count}')


if __name__ == '__main__':
    import argparse

    torch.set_float32_matmul_precision('high')

    parser = argparse.ArgumentParser(description='Toy Training Script.')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--mode', choices=['pytorch', 'jax', 'torch_xla'],
                        default='jax',
                        help='choose training mode')
    parser.add_argument('--compile-model', action='store_true', default=False,
                        help='whether to apply torch.compile to the model')
    args = parser.parse_args()
    train(**vars(args))
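Assuming the script above is saved as train.py (the file name is our own choice), each of the three modes can be run with, for example, python train.py --mode jax, python train.py --mode pytorch --compile-model, or python train.py --mode torch_xla.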
An Important Note on Benchmark Comparisons
When analyzing benchmark comparisons, it is of the utmost importance that we be meticulous and critical about how they were conducted. This is especially true in the case of AI model development, where a decision based on inaccurate data could have extremely expensive repercussions. When comparing the runtime performance of model training, there are a number of factors that can have a dominating effect on the measurements, including floating-point precision, matrix multiplication (matmul) precision, data loading methods, the use of flash/fused attention, etc. For example, if the default matmul precision is float32 in PyTorch and tensorfloat32 in JAX, we cannot learn much from their performance comparison. These settings can be controlled via APIs such as jax.default_matmul_precision and torch.set_float32_matmul_precision. In our script we have attempted to isolate these kinds of potential issues, but do not offer any guarantee that we have, in fact, succeeded.
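For illustration only, the snippet below shows one way (a sketch of the kind of alignment used in our script) to pin both frameworks to comparable TF32 matmul behavior before taking measurements:

import torch
import jax

# allow TF32 matmuls in PyTorch (roughly comparable to
# 'tensorfloat32' matmul precision in JAX)
torch.set_float32_matmul_precision('high')

# scope JAX matmuls to tensorfloat32 precision
with jax.default_matmul_precision('tensorfloat32'):
    # ... run the JAX training steps being measured here ...
    pass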
Results
We ran our training script on two Google Cloud VMs: a g2-standard-16 VM (with a single NVIDIA L4 GPU) and an a2-highgpu-1g VM (with a single NVIDIA A100 GPU). In each case we used a dedicated deep learning VM image (common-cu121-v20240514-ubuntu-2204-py310) with installations of PyTorch (2.3.0), PyTorch/XLA (2.3.0), JAX (0.4.28), Flax (0.8.4), Optax (0.2.2), and HuggingFace’s Transformers library (4.41.1). Please see the official documentation for appropriate installation of JAX and PyTorch/XLA for GPU.
The tables below capture the runtime results of a number of experiments. Please keep in mind that the comparative performance is likely to change drastically based on the model architecture and runtime environment. In addition, it is quite possible that a few small tweaks to the code could also have had a measurable impact on the results.
Although JAX appears to have demonstrated far superior performance to the alternatives on the L4 GPU, it came out neck and neck with PyTorch/XLA on the A100. This is not surprising given their common XLA backend; any XLA (HLO) graph generated by JAX should (at least in theory) be achievable by PyTorch/XLA as well. The torch.compile option underwhelmed on both platforms. This is somewhat expected given our choice of full-precision floats. As noted in a previous post, the true value of torch.compile is seen when using Automatic Mixed Precision (AMP).
For additional information on the performance comparison between JAX and PyTorch, be sure to check out the more comprehensive benchmark reports compiled by HuggingFace, Google, or MLCommons.
So Why Use JAX?
A commonly stated motivation for training in JAX is the potential runtime performance optimization enabled by JIT compilation. But, given the new (PyTorch/XLA) and even newer (torch.compile) JIT compilation options in PyTorch, this claim could easily be challenged. In fact, considering the huge community of PyTorch developers and the numerous features that are natively supported in PyTorch but not in JAX/Flax (e.g., automatic mixed precision and advanced attention layers, as of the time of this writing), one could make a strong argument against taking the time to learn JAX. However, it is our opinion that modern-day AI development teams must acquaint themselves with JAX and the opportunities that it offers. This is especially true for teams that are (like us) obsessive about utilizing the very latest and greatest available model training methodologies. On top of the potential performance benefits, here are some additional motivating factors:
Designed for XLA
Contrary to PyTorch, which underwent after-the-fact “functionalization” in the form of PyTorch/XLA, JAX was designed for XLA from the ground up. This implies that certain sequences that may appear difficult or messy in PyTorch/XLA can be done elegantly in JAX. A good example of this is mixing JIT and non-JIT functions in your training sequence (see the sketch below): totally straightforward in JAX, but it may require some creativity in PyTorch/XLA.
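Here is a minimal sketch of what we mean (a toy example of ours, not taken from the benchmark script): a JIT-compiled train step interleaved freely with ordinary, non-JIT host-side code:

import jax
import jax.numpy as jnp

@jax.jit
def train_step(params, x, y):
    # compiled: executes as a single fused XLA program
    grads = jax.grad(lambda p: jnp.mean((x @ p - y) ** 2))(params)
    return params - 1e-3 * grads

def log_metrics(params, step):
    # eager: plain Python, no compilation involved
    norm = float(jnp.linalg.norm(params))
    print(f'step {step}, param norm {norm:.4f}')

params = jnp.zeros((8, 1))
x, y = jnp.ones((16, 8)), jnp.ones((16, 1))
for step in range(5):
    params = train_step(params, x, y)  # JIT-compiled
    log_metrics(params, step)          # non-JIT, interleaved freely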
As noted above, PyTorch/XLA and TensorFlow could, in theory, generate an XLA (HLO) graph that is identical to the one created by JAX (and therefore be equally performant). In practice, however, the quality of the resulting graph comes down to the manner in which the framework-level implementation is translated into XLA; a more optimal translation will ultimately result in better runtime performance. Given that it is native to XLA, JAX could have the advantage over other frameworks.
Support for XLA Devices
The XLA-friendliness of JAX makes it especially compelling to developers of dedicated AI accelerators, such as the Google Cloud TPU, Intel Gaudi, and AWS Trainium chips, which are often exposed as “XLA devices”. Teams that train on TPU, in particular, are likely to find the support ecosystem for JAX to be more advanced than that for PyTorch/XLA.
Advanced Features
In recent years, a number of advanced training features have been released in JAX well before their counterparts in other frameworks. SPMD, for example, an advanced technique for device parallelism offering state-of-the-art model sharding opportunities, was introduced in JAX a couple of years ago and is only recently being carried over to PyTorch. Another example is Pallas, which (at long last) enables building custom kernels for XLA devices.
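To give a flavor of the SPMD-style sharding APIs, here is a minimal sketch based on jax.sharding (our own toy example; the mesh layout and tensor shapes are arbitrary and this is not a production recipe):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# build a 1D mesh over all available devices
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=('data',))

# shard a batch along the 'data' axis of the mesh
batch = jnp.ones((len(devices) * 8, 128))
sharded_batch = jax.device_put(batch, NamedSharding(mesh, P('data', None)))

# jit-compiled computations on sharded inputs are partitioned
# across the mesh by the XLA SPMD partitioner
@jax.jit
def forward(x):
    return jnp.tanh(x @ jnp.ones((128, 64)))

out = forward(sharded_batch)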
Open Source Models
As a consequence of the increasing popularity of the JAX framework, more and more open-source AI models are being released in JAX. Some classic examples are Google’s open-sourced MaxText (LLM) and AlphaFold v2 (protein-structure prediction) models. To take full advantage of such models, you will need to either learn JAX or undertake the non-trivial task of porting them to another framework.
It is our strong belief that these considerations warrant the inclusion of JAX in any ML development toolkit.
Summary
In this post we have explored the up-and-coming JAX ML development framework. We described its reliance on the XLA compiler and demonstrated its use in a toy example. Although JAX is often noted for its speedy runtime execution, the PyTorch JIT compilation APIs (torch.compile and PyTorch/XLA) offer similar potential for performance optimization. The relative performance of each option will depend greatly on the details of the model and the runtime environment.
Importantly, each ML development framework may offer unique features (such as SPMD auto-sharding in JAX and SDPA attention in PyTorch, as of the time of this writing) that can have a decisive impact on comparative runtime performance. Thus, the best choice of framework may depend on the degree to which your model can benefit from these features.
In conclusion, as we have emphasized in many of our previous posts, staying relevant in the constantly evolving landscape of ML development requires us to stay abreast of the most up-to-date tools and techniques, including the JAX ML development framework.