Accelerating AI/ML Model Training with Custom Operators — Part 3
This is the third part of a series of posts on the topic of building custom operators for optimizing AI/ML workloads. In our previous post, we demonstrated the simplicity and accessibility of Triton. Named for the Greek god of the sea, Triton empowers Python developers to increase their control over the GPU and optimize its use for the specific workload at hand. In this post we move one step down the lineage of Greek mythology to Triton’s daughter, Pallas, and discuss her namesake, the JAX extension for writing custom kernels for GPU and TPU.
One of the most important features of NVIDIA GPUs, and a significant factor in their rise to prominence, is their programmability. A key ingredient of the GPU offering is the availability of frameworks for creating General-Purpose GPU (GPGPU) operators, such as CUDA and Triton.
In previous posts (e.g., here) we discussed the opportunity for running ML workloads on Google TPUs and the potential for a meaningful increase in price performance and a reduction in training costs. One of the disadvantages that we noted at the time was the absence of tools for creating custom operators. As a result, models requiring unique operators that were either unsupported by the underlying ML framework (e.g., TensorFlow/XLA) or implemented in a suboptimal manner would underperform on TPU compared to GPU. This development gap was particularly noticeable over the past few years with the frequent introduction of newer and faster solutions for computing attention on GPU. Enabled by GPU kernel development frameworks, these led to a significant improvement in the efficiency of transformer models.
On TPUs, on the other hand, the lack of appropriate tooling prevented similar innovation, and transformer models were stuck with the attention mechanisms that were supported by the official software stack. Fortunately, with the advent of Pallas, this gap has been addressed. Built as an extension to JAX and with dedicated support for PyTorch/XLA, Pallas enables the creation of custom kernels for GPU and TPU. For its GPU support, Pallas utilizes Triton, and for its TPU support, it uses a library called Mosaic. Although we will focus on custom kernels for TPU, it is worth noting that when developing in JAX, GPU kernel customization with Pallas offers some advantages over Triton (e.g., see here).
Our intention in this post is to draw attention to Pallas and demonstrate its potential. Please do not view this post as a replacement for the official Pallas documentation. The examples we will share were chosen only for demonstrative purposes. We have made no effort to optimize these or verify their robustness, durability, or accuracy.
Importantly, at the time of this writing Pallas is an experimental feature and still under active development. The samples we share (which are based on JAX version 0.4.32 and PyTorch version 2.4.1) may become outdated by the time you read this. Be sure to use the most up-to-date APIs and resources available for your Pallas development.
Many thanks to Yitzhak Levi for his contributions to this post.
Environment Setup
For the experiments described below we use the following environment setup commands:
# create TPU node
gcloud alpha compute tpus queued-resources create v5litepod-1-resource \
  --node-id v5litepod \
  --project <project-id> \
  --zone us-central1-a \
  --accelerator-type v5litepod-1 \
  --runtime-version v2-alpha-tpuv5-lite \
  --valid-until-duration 1d \
  --service-account <service-account>
# check TPU node status (wait for state to be ACTIVE)
gcloud alpha compute tpus queued-resources describe v5litepod-1-resource \
  --project <project-id> \
  --zone us-central1-a
# SSH to TPU node
gcloud alpha compute tpus tpu-vm ssh v5litepod \
  --project <project-id> \
  --zone us-central1-a
# install dependencies
pip install torch_xla[tpu] \
  -f https://storage.googleapis.com/libtpu-releases/index.html
pip install torch_xla[pallas]
pip install timm
# run tests
python train.py
# exit ssh
exit
# delete TPU node
gcloud alpha compute tpus queued-resources delete v5litepod-1-resource \
  --project <project-id> \
  --zone us-central1-a --force --quiet
Pallas Kernels for TPU
In the toy example of our first post in this series, we distinguished between two different ways in which custom kernel development can potentially boost performance. The first is by combining (fusing) together multiple operations in a manner that reduces the overhead of: 1) loading multiple individual kernels, and 2) reading and writing intermediate values (e.g., see PyTorch’s tutorial on multiply-add fusion). The second is by meticulously applying the resources of the underlying accelerator in a manner that optimizes the function at hand. We briefly discuss these two opportunities as they pertain to developing custom TPU kernels and make note of the limitations of the Pallas support.
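As a toy illustration of the first opportunity, the plain-Python sketch below counts the element-level memory reads and writes of an unfused versus a fused multiply-add (out = a * b + c). It assumes a simplified model in which every kernel reads its inputs from, and writes its outputs to, main memory, ignoring caches; the function names are hypothetical.

```python
def unfused_multiply_add(a, b, c):
    """Two separate 'kernels': tmp = a * b, then out = tmp + c."""
    traffic = 0
    tmp = []
    for x, y in zip(a, b):      # kernel 1: read a and b, write tmp
        tmp.append(x * y)
        traffic += 3
    out = []
    for t, z in zip(tmp, c):    # kernel 2: read tmp and c, write out
        out.append(t + z)
        traffic += 3
    return out, traffic, 2      # result, memory accesses, kernel launches


def fused_multiply_add(a, b, c):
    """One 'kernel': read a, b, c and write out; the product stays local."""
    traffic = 0
    out = []
    for x, y, z in zip(a, b, c):
        out.append(x * y + z)
        traffic += 4
    return out, traffic, 1


a, b, c = [1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [0.5, 0.5, 0.5]
print(unfused_multiply_add(a, b, c))  # ([4.5, 10.5, 18.5], 18, 2)
print(fused_multiply_add(a, b, c))    # ([4.5, 10.5, 18.5], 12, 1)
```

The fused variant returns the same values with one launch instead of two and 12 rather than 18 memory accesses, because the intermediate product never makes a round trip to memory.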
Operator Fusion on TPU
The TPU is an XLA (Accelerated Linear Algebra) device, i.e., it runs code that has been generated by the XLA compiler. When training an AI model in a framework such as JAX or PyTorch/XLA, the training step is first transformed into an intermediate representation (IR) graph. This computation graph is then fed to the XLA compiler, which converts it into machine code that can run on the TPU. Unlike eager execution mode, in which operations are executed individually, this mode of running models enables XLA to identify and implement opportunities for operator fusion during compilation. In fact, operator fusion is the XLA compiler’s most important optimization. Naturally, no compiler is perfect, and we are certain to come across additional opportunities for fusion through custom kernels. But, generally speaking, we might expect the opportunity for boosting runtime performance in this manner to be lower than in the case of eager execution.
Optimizing TPU Utilization
Creating optimal kernels for TPU requires a comprehensive and intimate understanding of the TPU system architecture. Importantly, TPUs are very different from GPUs: expertise in GPUs and CUDA does not immediately carry over to TPU development. For example, while GPUs contain a large number of processors and draw their strength from their ability to perform massive parallelization, TPUs are primarily sequential, with dedicated engines for running highly vectorized operations and support for asynchronous scheduling and memory loading.
The differences between the underlying architectures of the GPU and TPU can have significant implications for how custom kernels should be designed. Mastering TPU kernel development requires 1) appropriate overlapping of memory and compute operations via pipelining, 2) knowing how to balance the use of the scalar, vector (VPU), and matrix (MXU) compute units and their associated scalar and vector registers (SREG and VREG) and on-chip memories (SMEM and VMEM), 3) a comprehension of the costs of different low-level operations, 4) appropriate megacore configuration (on supporting TPU generations), 5) a grasp of the different types of TPU topologies and their implications for how to support distributed computing, and more.
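The first item on this list, overlapping memory and compute, can be illustrated with a back-of-the-envelope timing model. The plain-Python sketch below assumes that each block requires a fixed load time and a fixed compute time and that, with double buffering, the load of block i+1 proceeds while block i is being computed. Real TPU pipelines are more involved, and the function names are hypothetical.

```python
def sequential_time(load, compute, n_blocks):
    """No overlap: each block is loaded, then computed, one after the other."""
    return n_blocks * (load + compute)


def pipelined_time(load, compute, n_blocks):
    """Double buffering: the next load overlaps with the current compute.

    The first load and the last compute cannot be hidden; every step in
    between costs whichever of the two overlapping operations is slower.
    """
    if n_blocks == 0:
        return 0.0
    return load + (n_blocks - 1) * max(load, compute) + compute


print(sequential_time(2.0, 3.0, 10), pipelined_time(2.0, 3.0, 10))  # 50.0 32.0
```

In this model the steady-state cost per block drops from load + compute to max(load, compute), which is why well-pipelined kernels can hide most of their memory traffic behind computation.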
Framework Limitations
While the ability to create custom operators in Python using JAX functions and APIs greatly increases the simplicity and accessibility of Pallas kernel development, it also limits its expressivity. Additionally, (as of the time of this writing) there are some JAX APIs that are not supported by Pallas on TPU (e.g., see here). As a result, you may approach Pallas with the intention of implementing a particular operation only to discover that the framework does not support the APIs that you need. This is in contrast to frameworks such as CUDA, which enable a great deal of flexibility when developing custom kernels (for GPU).
The matrix multiplication tutorial in the Pallas documentation provides an excellent introduction to Pallas kernel development, highlighting the potential for operator fusion and customization alongside the challenges involved in optimizing performance (e.g., appropriate tuning of the input block size). The tutorial clearly illustrates that realizing the full potential of the TPU requires a certain degree of specialization. However, as we intend to demonstrate, even the novice ML developer can benefit from Pallas kernels.
Integrating the Use of Existing Pallas Kernels
To benefit from custom Pallas kernels you do not necessarily need to know how to build them. In our first example we demonstrate how you can leverage existing Pallas kernels from dedicated public repositories.
Example — Flash Attention in Torch/XLA
The JAX GitHub repository includes implementations of a number of Pallas kernels, including flash attention. Here we will demonstrate its use in a Torch/XLA Vision Transformer (ViT) model. Although Pallas kernels are developed in JAX, they can be adopted into Torch/XLA, e.g., via the make_kernel_from_pallas utility (see the documentation for details). In the case of flash attention, Torch/XLA already provides this adaptation.
In the following code block we define a stripped down version of the classic timm attention block with an option to define the underlying attention operator in the constructor. We will use this option to compare the performance of the flash attention Pallas kernel to its alternatives.
# general imports
import os, time, functools
# torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch_xla.core.xla_model as xm
# custom kernel import
from torch_xla.experimental.custom_kernel import flash_attention
# timm imports
from timm.layers import Mlp
from timm.models.vision_transformer import VisionTransformer
class TPUAttentionBlock(nn.Module):
    def __init__(
        self,
        dim: int = 768,
        num_heads: int = 12,
        attn_fn = None,
        **kwargs
    ) -> None:
        super().__init__()
        self.attn_fn = attn_fn
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=dim * 4,
        )
    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        x = self.norm1(x_in)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        if self.attn_fn is None:
            # native attention; note: no 1/sqrt(head_dim) scaling is applied
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            x = attn @ v
        else:
            x = self.attn_fn(q, k, v)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = x + x_in
        x = x + self.mlp(self.norm2(x))
        return x
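Note that the native branch computes softmax(QKᵀ)V without the customary 1/√head_dim factor, which is why the SDPA variant further down is called with scale=1.0. The plain-Python, single-head reference below (a hypothetical helper, not part of timm or PyTorch) spells out what that branch computes, with the scale exposed as a parameter; it can be handy for spot-checking an attn_fn on tiny inputs.

```python
import math


def reference_attention(q, k, v, scale=1.0):
    """softmax(scale * q @ k^T) @ v for one head, using nested lists.

    q and k are lists of N vectors of length d; v is a list of N vectors.
    """
    out = []
    for qi in q:
        # attention logits of this query against every key
        logits = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        # numerically stable softmax: subtract the max before exponentiating
        m = max(logits)
        weights = [math.exp(x - m) for x in logits]
        total = sum(weights)
        weights = [w / total for w in weights]
        # weighted sum of the value vectors
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out


# equal logits -> equal weights -> the average of the two value vectors
print(reference_attention([[0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]],
                          [[1.0, 2.0], [3.0, 4.0]]))  # [[2.0, 3.0]]
```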
In the following block we train a simple ViT-backed classification model using the input dataset and attention function (attn_fn) of choice.
def train(dataset, attn_fn=None):
    device = xm.xla_device()
    train_loader = DataLoader(
        dataset,
        batch_size=128,
        num_workers=os.cpu_count(),
        pin_memory=True
    )
    # configure the VisionTransformer in a manner that complies with the
    # Pallas flash_attention kernel constraints
    model = VisionTransformer(
        block_fn=functools.partial(TPUAttentionBlock, attn_fn=attn_fn),
        img_size=256,
        class_token=False,
        global_pool="avg"
    )
    optimizer = torch.optim.SGD(model.parameters())
    loss_fn = torch.nn.CrossEntropyLoss()
    # copy the model to the TPU
    model = model.to(device)
    model.train()
    t0 = time.perf_counter()
    summ = 0
    count = 0
    for step, data in enumerate(train_loader):
        # copy data to TPU
        inputs = data[0].to(device=device, non_blocking=True)
        label = data[1].to(device=device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast('xla', dtype=torch.bfloat16):
            output = model(inputs)
            loss = loss_fn(output, label)
        loss.backward()
        optimizer.step()
        xm.mark_step()
        # capture step time
        batch_time = time.perf_counter() - t0
        if step > 20:  # skip first (warmup/compilation) steps
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step > 100:
            break
    print(f'average step time: {summ / count}')
Note the specific configuration we chose for the VisionTransformer. This is to comply with certain restrictions (as of the time of this writing) of the custom flash attention kernel (e.g., on tensor shapes).
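One way to see why these settings matter: with timm’s default patch size of 16, a 256×256 image yields (256/16)² = 256 patch tokens, and disabling the class token keeps the sequence length at 256 rather than 257 (with no class token to read out, global_pool="avg" averages the patch tokens instead). The plain-Python sketch below works through this arithmetic. The helper names are hypothetical, and the block size of 128 in tiles_evenly is an assumption for illustration, not the kernel’s documented requirement.

```python
def vit_seq_len(img_size, patch_size=16, class_token=True):
    """Number of tokens that a ViT feeds to its attention blocks."""
    assert img_size % patch_size == 0, "image must tile evenly into patches"
    num_patches = (img_size // patch_size) ** 2
    return num_patches + (1 if class_token else 0)


def tiles_evenly(seq_len, block_size=128):
    """Hypothetical check: does the sequence split into whole kernel blocks?"""
    return seq_len % block_size == 0


for cls in (True, False):
    n = vit_seq_len(256, class_token=cls)
    print(f"class_token={cls}: seq_len={n}, tiles evenly: {tiles_evenly(n)}")
# class_token=True: seq_len=257, tiles evenly: False
# class_token=False: seq_len=256, tiles evenly: True
```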
Finally, we define a dataset and compare the runtimes of training with three different attention routines: 1. using native PyTorch functions, 2. using PyTorch’s built-in SDPA function, and 3. using the custom Pallas operator:
# use random data
class FakeDataset(Dataset):
    def __len__(self):
        return 1000000
    def __getitem__(self, index):
        rand_image = torch.randn([3, 256, 256], dtype=torch.float32)
        # timm's VisionTransformer defaults to 1000 output classes,
        # so keep the labels within [0, 1000)
        label = torch.tensor(data=index % 1000, dtype=torch.int64)
        return rand_image, label
ds = FakeDataset()
print('PyTorch native')
train(ds, attn_fn=None)
print('PyTorch SDPA')
train(ds, attn_fn=functools.partial(F.scaled_dot_product_attention, scale=1.0))
print('Pallas flash_attention')
train(ds, attn_fn=flash_attention)
The comparative results are captured in the table below:
Although our Pallas kernel clearly underperforms when compared to its alternatives, we should not be discouraged:
- It is likely that these results could be improved with appropriate tuning.
- These results are specific to the model and runtime environment that we chose. The Pallas kernel may exhibit wholly different comparative results in other use cases.
- The real power of Pallas is in the ability to create and adjust low level operators to our specific needs. Although runtime performance is important, a 23% performance penalty (as in our example) may be a small price to pay for this flexibility. Moreover, the opportunity for customization may open up possibilities for optimizations that are not supported by the native framework operations.
Enhancing Existing Kernels
Oftentimes it may be easier to tweak an existing Pallas kernel to your specific needs, rather than creating one from scratch. This is especially recommended if the kernel has already been optimized, as performance tuning can be tedious and time-consuming. The official matrix multiplication tutorial includes a few examples of how to extend and enhance an existing kernel. Here we undertake one of the suggested exercises: we implement int8 matrix multiplication and assess its performance advantage over its bfloat16 alternative.
Example — Int8 Matrix Multiplication
In the code block below we implement an int8 version of the matrix multiplication example.
import functools, timeit
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
# set to True to develop/debug on CPU
interpret = False
def matmul_kernel_int8(x_ref, y_ref, z_ref, acc_ref, *, nsteps):
    # zero the int32 accumulator on the first step of the k-axis
    @pl.when(pl.program_id(2) == 0)
    def _():
        acc_ref[...] = jnp.zeros_like(acc_ref)
    # accumulate the int8 x int8 block product in int32
    acc_ref[...] += jnp.dot(
        x_ref[...], y_ref[...], preferred_element_type=jnp.int32
    )
    # write the accumulated result on the last step of the k-axis
    @pl.when(pl.program_id(2) == nsteps - 1)
    def _():
        z_ref[...] = acc_ref[...]
@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn'])
def matmul_int8(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
):
    m, k = x.shape
    _, n = y.shape
    return pl.pallas_call(
        functools.partial(matmul_kernel_int8, nsteps=k // bk),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec(block_shape=(bm, bk),
                             index_map=lambda i, j, k: (i, k)),
                pl.BlockSpec(block_shape=(bk, bn),
                             index_map=lambda i, j, k: (k, j)),
            ],
            out_specs=pl.BlockSpec(block_shape=(bm, bn),
                                   index_map=lambda i, j, k: (i, j)),
            scratch_shapes=[pltpu.VMEM((bm, bn), jnp.int32)],
            grid=(m // bm, n // bn, k // bk),
        ),
        out_shape=jax.ShapeDtypeStruct((m, n), jnp.int32),
        compiler_params=dict(mosaic=dict(
            dimension_semantics=("parallel", "parallel", "arbitrary"))),
        interpret=interpret
    )(x, y)
Note our use of an int32 accumulation matrix for addressing the possibility of overflow. Also note our use of the interpret flag for debugging of Pallas kernels on CPU (as recommended here).
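The control flow of matmul_kernel_int8 can be mimicked in plain Python to see what the grid and the int32 scratch accumulator do: for each output block (i, j), the kernel runs once per k-step, zeroes the accumulator on the first step, adds one partial product per step, and writes the result on the last step. The sketch below is a hypothetical CPU emulation (not Pallas code, and with tiny shapes) that follows the same schedule. It also computes the worst-case accumulator magnitude for the 8192-long contraction used in the benchmark: int8 values lie in [-128, 127], so a single product can reach 128 × 128 = 16384.

```python
INT32_MAX = 2**31 - 1


def emulate_blocked_matmul(x, y, bm, bk, bn):
    """Emulate the (m // bm, n // bn, k // bk) grid of matmul_kernel_int8.

    x is an m x k and y a k x n nested list of ints; all block sizes are
    assumed to divide the matrix dimensions evenly.
    """
    m, k, n = len(x), len(x[0]), len(y[0])
    nsteps = k // bk
    z = [[0] * n for _ in range(m)]
    for i in range(m // bm):          # grid axis 0 ("parallel")
        for j in range(n // bn):      # grid axis 1 ("parallel")
            acc = None
            for s in range(nsteps):   # grid axis 2 ("arbitrary"), in order
                if s == 0:            # mirrors pl.when(program_id(2) == 0)
                    acc = [[0] * bn for _ in range(bm)]
                # acc += x_block @ y_block (Python ints never overflow)
                for r in range(bm):
                    for c in range(bn):
                        acc[r][c] += sum(
                            x[i * bm + r][s * bk + t] * y[s * bk + t][j * bn + c]
                            for t in range(bk))
                if s == nsteps - 1:   # mirrors pl.when(program_id(2) == nsteps - 1)
                    for r in range(bm):
                        for c in range(bn):
                            z[i * bm + r][j * bn + c] = acc[r][c]
    return z


def worst_case_accumulator(k):
    """Largest |sum| of k int8 x int8 products: k * (-128) * (-128)."""
    return k * 128 * 128


bound = worst_case_accumulator(8192)
print(bound, bound <= INT32_MAX, bound <= 2**15 - 1)  # 134217728 True False
```

The bound of 2^27 comfortably fits in int32 but would overflow a 16-bit (let alone 8-bit) accumulator. Because the accumulator is carried across k-steps, the third grid axis is marked "arbitrary" rather than "parallel", so that (as we understand the Mosaic semantics) it is not split across cores.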
To assess our kernel, we introduce a slight modification to the benchmarking utilities defined in the tutorial and compare the runtime results to both the bfloat16 Pallas matmul kernel (the tutorial’s matmul function) and the built-in JAX matmul API:
def benchmark(f, ntrials: int = 100):
    def run(*args, **kwargs):
        # Compile function first
        jax.block_until_ready(f(*args, **kwargs))
        # Time function
        res = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
                            number=ntrials
                            )
        time = res / ntrials
        # print(f"Time: {time}")
        return time
    return run
def analyze_matmul(m: int, k: int, n: int, dtype: jnp.dtype,
                   mm_func):
    x = jnp.ones((m, k), dtype=dtype)
    y = jnp.ones((k, n), dtype=dtype)
    time = benchmark(mm_func)(x, y)
    print("Matmul time: ", time)
    mm_ops = 2 * m * k * n / time
    # peak v5e throughput: 394e12 OP/s for int8, 197e12 FLOP/s for bfloat16
    v5e_ops = 394e12 if dtype == jnp.int8 else 197e12
    print(f"OP/s utilization: {mm_ops / v5e_ops * 100:.4f}%")
    print()
print("bfloat16 Pallas matmul")
# matmul is the bfloat16 Pallas kernel from the official matmul tutorial
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
print("int8 Pallas matmul")
mm = functools.partial(matmul_int8, bm=512, bk=1024, bn=1024)
analyze_matmul(8192, 8192, 8192, jnp.int8, mm)
print("XLA int8 matmul")
mm = functools.partial(jnp.matmul, preferred_element_type=jnp.int32)
analyze_matmul(8192, 8192, 8192, jnp.int8, mm)
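The utilization figure printed by analyze_matmul divides the achieved operation rate (2·m·k·n operations, one multiply and one add per term, per measured second) by the peak rate hard-coded in the script: 197e12 for bfloat16 and 394e12 for int8. Inverting that formula gives the ideal runtime for the 8192³ problem, a useful yardstick when reading the measured times. The helper names below are hypothetical; the peak values are the ones used in the script.

```python
def matmul_ops(m, k, n):
    """Operation count of an (m x k) @ (k x n) matmul: one mul + one add per term."""
    return 2 * m * k * n


def utilization(m, k, n, seconds, peak_ops):
    """Fraction of peak throughput achieved, as computed in analyze_matmul."""
    return matmul_ops(m, k, n) / seconds / peak_ops


def ideal_seconds(m, k, n, peak_ops):
    """Runtime if the accelerator sustained its peak rate the whole time."""
    return matmul_ops(m, k, n) / peak_ops


for name, peak in (("bfloat16", 197e12), ("int8", 394e12)):
    t = ideal_seconds(8192, 8192, 8192, peak)
    print(f"{name}: ideal time {t * 1e3:.2f} ms")
# bfloat16: ideal time 5.58 ms
# int8: ideal time 2.79 ms
```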
The results of our experiment are captured in the table below:
By using int8 matrices (rather than bfloat16 matrices) on TPU v5e, we can boost the runtime performance of our custom matrix multiplication kernel by 71%. However, as in the case of the bfloat16 example, additional tuning is required to match the performance of the built-in matmul operator. The potential for improvement is highlighted by the drop in system utilization when compared to bfloat16.
Creating a Kernel from Scratch
While leveraging existing kernels can be greatly beneficial, it is unlikely to solve all of your problems. Inevitably, you may need to implement an operation that is either unsupported on TPU or exhibits suboptimal performance. Here we demonstrate the creation of a relatively simple pixel-wise kernel. For the sake of continuity, we choose the same Generalized Intersection Over Union (GIOU) operation as in our previous posts.
Example — A GIOU Pallas Kernel
In the code block below we define a Pallas kernel that implements GIOU on pairs of batches of bounding boxes, each of dimension BxNx4 (where we denote the batch size by B and the number of boxes per sample by N). The function returns a tensor of scores of dimension BxN. We unstack the four box coordinates of each input into separate BxN tensors and choose a block size of 128 on both the batch axis and the boxes axis, i.e., we divide each of the eight coordinate tensors into blocks of 128x128 and pass them to our kernel function. The grid and BlockSpec index_map are defined accordingly.
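For the shapes used in the benchmark further down (B=1024, N=256), the grid works out as follows. The plain-Python sketch below (hypothetical helpers) enumerates what the index_map lambda i, j: (i, j) selects: program (i, j) reads rows 128·i to 128·i+127 and columns 128·j to 128·j+127 of each coordinate tensor. Because the grid is computed with floor division, B and N are assumed to be multiples of 128; otherwise trailing boxes would not be covered by any block.

```python
def giou_grid(batch, n_boxes, block=128):
    """Grid shape used by batch_giou: one program per (block x block) tile."""
    return batch // block, n_boxes // block


def block_slices(i, j, block=128):
    """Element ranges that index_map (i, j) -> (i, j) assigns to program (i, j)."""
    rows = range(i * block, (i + 1) * block)
    cols = range(j * block, (j + 1) * block)
    return rows, cols


def covered_fraction(batch, n_boxes, block=128):
    """Share of the B x N output that the floor-divided grid actually writes."""
    gi, gj = giou_grid(batch, n_boxes, block)
    return (gi * block) * (gj * block) / (batch * n_boxes)


print(giou_grid(1024, 256))          # (8, 2): 16 kernel invocations
rows, cols = block_slices(7, 1)
print(rows.start, rows.stop, cols.start, cols.stop)  # 896 1024 128 256
print(covered_fraction(1000, 256))   # 0.896: rows 896..999 would be missed
```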
import timeit
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
# set to True to develop/debug on CPU
interpret = False
# perform giou on a single block
def giou_kernel(preds_left_ref,
                preds_top_ref,
                preds_right_ref,
                preds_bottom_ref,
                targets_left_ref,
                targets_top_ref,
                targets_right_ref,
                targets_bottom_ref,
                output_ref):
    epsilon = 1e-5
    # copy tensors into local memory
    preds_left = preds_left_ref[...]
    preds_top = preds_top_ref[...]
    preds_right = preds_right_ref[...]
    preds_bottom = preds_bottom_ref[...]
    gt_left = targets_left_ref[...]
    gt_top = targets_top_ref[...]
    gt_right = targets_right_ref[...]
    gt_bottom = targets_bottom_ref[...]
    # Compute the area of each box
    area1 = (preds_right - preds_left) * (preds_bottom - preds_top)
    area2 = (gt_right - gt_left) * (gt_bottom - gt_top)
    # Compute the intersection
    left = jnp.maximum(preds_left, gt_left)
    top = jnp.maximum(preds_top, gt_top)
    right = jnp.minimum(preds_right, gt_right)
    bottom = jnp.minimum(preds_bottom, gt_bottom)
    # intersection width and height
    inter_w = jnp.maximum(right - left, 0)
    inter_h = jnp.maximum(bottom - top, 0)
    # intersection area
    inter_area = inter_w * inter_h
    # union of two boxes
    union_area = area1 + area2 - inter_area
    iou_val = inter_area / jnp.maximum(union_area, epsilon)
    # Compute the smallest enclosing box
    enclose_left = jnp.minimum(preds_left, gt_left)
    enclose_top = jnp.minimum(preds_top, gt_top)
    enclose_right = jnp.maximum(preds_right, gt_right)
    enclose_bottom = jnp.maximum(preds_bottom, gt_bottom)
    # enclosing box width and height
    enclose_w = jnp.maximum(enclose_right - enclose_left, 0)
    enclose_h = jnp.maximum(enclose_bottom - enclose_top, 0)
    # enclosing box area
    enclose_area = enclose_w * enclose_h
    # Compute GIOU
    delta_area = (enclose_area - union_area)
    enclose_area = jnp.maximum(enclose_area, epsilon)
    output_ref[...] = iou_val - delta_area / enclose_area
@jax.jit
def batch_giou(preds, targets):
    m, n, _ = preds.shape
    output = pl.pallas_call(
        giou_kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), preds.dtype),
        in_specs=[pl.BlockSpec(block_shape=(128, 128),
                               index_map=lambda i, j: (i, j))]*8,
        out_specs=pl.BlockSpec(block_shape=(128, 128),
                               index_map=lambda i, j: (i, j)),
        grid=(m // 128, n // 128),
        compiler_params=dict(mosaic=dict(
            dimension_semantics=("parallel", "parallel"))),
        interpret=interpret
    )(*jnp.unstack(preds, axis=-1), *jnp.unstack(targets, axis=-1))
    return output
Although the creation of a new TPU kernel is certainly cause for celebration (especially if it enables a previously blocked ML workload), our work is not done. A critical part of Pallas kernel development is tuning the operator (e.g., the block size) for optimal runtime performance. We omit this stage in the interest of brevity.
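A simple starting point for such tuning is a sweep over candidate block shapes. The sketch below only generates candidates; each one would then be passed to a (suitably parameterized) kernel and timed with a benchmark helper such as the one defined earlier. It assumes, as a rule of thumb we believe applies to 32-bit data on TPU, that a block’s last two dimensions should be multiples of 8 and 128 respectively (or span the full array dimension), and that each block should divide the array evenly. The function name, the element budget, and the exact constraint are assumptions; check the Pallas TPU documentation for the rules that apply to your JAX version and dtype.

```python
def candidate_blocks(rows, cols, max_elems=512 * 1024):
    """List (block_rows, block_cols) pairs that tile a rows x cols array.

    Assumed constraints: block_rows is a multiple of 8 (or equals rows),
    block_cols is a multiple of 128 (or equals cols), both divide the array
    evenly, and a block holds at most max_elems elements (an arbitrary
    stand-in for an on-chip memory budget).
    """
    def valid(size, full, multiple):
        return full % size == 0 and (size % multiple == 0 or size == full)

    row_opts = [s for s in range(1, rows + 1) if valid(s, rows, 8)]
    col_opts = [s for s in range(1, cols + 1) if valid(s, cols, 128)]
    return [(r, c) for r in row_opts for c in col_opts if r * c <= max_elems]


# block shapes for the 1024 x 256 coordinate tensors used in the benchmark
print(len(candidate_blocks(1024, 256, max_elems=128 * 128)))  # 9
```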
To assess the performance of our kernel, we compare it to the following native JAX GIOU implementation:
def batched_box_iou(boxes1, boxes2):
    epsilon = 1e-5
    # Compute areas of both sets of boxes
    area1 = (boxes1[..., 2]-boxes1[..., 0])*(boxes1[..., 3]-boxes1[..., 1])
    area2 = (boxes2[..., 2]-boxes2[..., 0])*(boxes2[..., 3]-boxes2[..., 1])
    # corners of intersection
    lt = jnp.maximum(boxes1[..., :2], boxes2[..., :2])
    rb = jnp.minimum(boxes1[..., 2:], boxes2[..., 2:])
    # width and height of intersection
    wh = jnp.clip(rb - lt, a_min=0)
    # area of the intersection
    inter = wh[..., 0] * wh[..., 1]
    # union of the two boxes
    union = area1 + area2 - inter
    iou = inter / jnp.clip(union, a_min=epsilon)
    # corners of enclosing box
    lti = jnp.minimum(boxes1[..., :2], boxes2[..., :2])
    rbi = jnp.maximum(boxes1[..., 2:], boxes2[..., 2:])
    # Width and height of the enclosing box
    whi = jnp.clip(rbi - lti, a_min=0)
    # Area of the enclosing box
    areai = jnp.clip(whi[..., 0] * whi[..., 1], a_min=epsilon)
    # Generalized IoU
    return iou - (areai - union) / areai
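When validating either implementation, a scalar reference for a single pair of boxes in (left, top, right, bottom) order is handy. The plain-Python function below is a hypothetical helper that follows the same formulas and the same epsilon as the two implementations above, so batch_giou and batched_box_iou should reproduce its values element-wise, up to floating-point error.

```python
def giou_pair(pred, target, epsilon=1e-5):
    """GIOU of two boxes, each given as (left, top, right, bottom)."""
    p_left, p_top, p_right, p_bottom = pred
    t_left, t_top, t_right, t_bottom = target
    area1 = (p_right - p_left) * (p_bottom - p_top)
    area2 = (t_right - t_left) * (t_bottom - t_top)
    # intersection rectangle, clamped to zero when the boxes do not overlap
    inter_w = max(min(p_right, t_right) - max(p_left, t_left), 0)
    inter_h = max(min(p_bottom, t_bottom) - max(p_top, t_top), 0)
    inter = inter_w * inter_h
    union = area1 + area2 - inter
    iou = inter / max(union, epsilon)
    # smallest box enclosing both inputs
    enc_w = max(max(p_right, t_right) - min(p_left, t_left), 0)
    enc_h = max(max(p_bottom, t_bottom) - min(p_top, t_top), 0)
    enclose = max(enc_w * enc_h, epsilon)
    return iou - (enclose - union) / enclose


print(giou_pair((0, 0, 2, 2), (0, 0, 2, 2)))  # 1.0 for identical boxes
print(giou_pair((0, 0, 1, 1), (2, 2, 3, 3)))  # -7/9 for these disjoint boxes
```

Unlike plain IoU, which is 0 for any pair of disjoint boxes, GIOU keeps decreasing as the boxes move apart, because the enclosing-box term grows.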
We generate two batches of randomly generated bounding boxes and measure the performance of our functions using the benchmark function defined above.
from jax import random
batch_size = 1024
n_boxes = 256
img_size = 256
boxes = []
for i in range(2):
    k1, k2 = random.split(random.key(i), 2)
    # Randomly generate box sizes and positions
    box_sizes = random.randint(k1, shape=(batch_size, n_boxes, 2), minval=1, maxval=img_size)
    top_left = random.randint(k2, shape=(batch_size, n_boxes, 2), minval=0, maxval=img_size - 1)
    bottom_right = jnp.clip(top_left + box_sizes, 0, img_size - 1)
    # Concatenate top-left and bottom-right coordinates
    rand_boxes = jnp.concatenate((top_left, bottom_right), axis=2)
    boxes.append(rand_boxes.astype(jnp.float32))
time = benchmark(batch_giou)(boxes[0], boxes[1])
print(f'Pallas kernel: {time}')
time = benchmark(batched_box_iou)(boxes[0], boxes[1])
print(f'JAX function: {time}')
time = benchmark(jax.jit(batched_box_iou))(boxes[0], boxes[1])
print(f'Jitted function: {time}')
The comparative results appear in the table below:
We can see that JIT-compiling our naive JAX implementation results in slightly better performance than our Pallas kernel. Once again, we can see that matching or surpassing the performance results of JIT compilation (and its inherent kernel fusion) would require fine-tuning of our custom kernel.
Utilizing the Sequential Nature of TPUs
While the ability to develop custom kernels for TPU offers great potential, our examples thus far have demonstrated that reaching optimal runtime performance can be challenging. One way to overcome this is to seek opportunities to utilize the unique properties of the TPU architecture. One example is the sequential nature of the TPU processor. Although deep learning workloads tend to rely on operations that are easily parallelizable (e.g., matrix multiplication), on occasion they require algorithms that are inherently sequential. These can pose a serious challenge for the SIMT (single instruction, multiple threads) model of GPUs and can sometimes have a disproportionate impact on runtime performance. In a sequel to this post, we demonstrate how we can implement sequential algorithms in a way that takes advantage of the TPU’s sequential processor and minimizes their performance penalty.
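Greedy non-maximum suppression (NMS), a common post-processing step in object detection, is one familiar example of such an algorithm. The plain-Python sketch below uses hypothetical helper names and is not necessarily the algorithm treated in the sequel; it shows where the sequential dependency comes from: whether a box is kept depends on which boxes were kept before it.

```python
def iou(a, b):
    """IoU of two boxes given as (left, top, right, bottom)."""
    w = max(min(a[2], b[2]) - max(a[0], b[0]), 0)
    h = max(min(a[3], b[3]) - max(a[1], b[1]), 0)
    inter = w * h
    union = ((a[2] - a[0]) * (a[3] - a[1])
             + (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, scores, threshold=0.5):
    """Keep the highest-scoring boxes, dropping any that overlap a kept box.

    Each iteration reads the 'kept' list produced by the previous ones, so
    the loop cannot simply be split into independent parallel pieces.
    """
    order = sorted(range(len(boxes)), key=lambda idx: scores[idx], reverse=True)
    kept = []
    for idx in order:
        if all(iou(boxes[idx], boxes[k]) <= threshold for k in kept):
            kept.append(idx)
    return kept


boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
print(greedy_nms(boxes, [0.9, 0.8, 0.7]))  # [0, 2]: box 1 overlaps box 0
```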
Summary
The introduction of Pallas marks an important milestone in the evolution of TPUs. By enabling customization of TPU operations it can potentially unlock new opportunities for TPU programmability, particularly in the world of ML. Our intention in this post was to demonstrate the accessibility of this powerful new feature. While our examples have indeed shown this, they have also highlighted the effort required to reach optimal runtime performance.
This post has merely scratched the surface of Pallas kernel development. Be sure to see the official documentation to learn more about automatic differentiation in Pallas, developing sparse kernels, and more.
The Rise of Pallas: Unlocking TPU Potential with Custom Kernels was originally published in Towards Data Science on Medium, where people are continuing the conversation by highlighting and responding to this story.