Quick notes before we begin:
- some optimizations improve runtime [Runtime], some improve memory usage [Memory], and some trade one for the other (e.g. better [Runtime], worse [Memory]). This isn't a checklist, but rather a reminder of the options for different situations. Apply what you need.
- items are in rough order of application - i.e. try the small-scale things before the large-scale ones!
- there are other optimizations if you're willing to modify your architecture.
- always profile before you optimize!
Any model size
Model compilation & kernel fusion
[Runtime] Once you're done debugging your model and are ready to run something for longer, compile it into a graph with torch.compile. Graph-mode execution drastically reduces data-transfer overheads and can fuse multiple GPU kernels so that they execute together while the data is still on the chip. Pay attention to the mode argument for maximum performance.
Try to compile the largest module that works, ideally your whole model, or smaller parts if something fails. When checkpointing, save the weights of the underlying base model: it shares parameters with the compiled model, but the compiled wrapper changes the state_dict naming schema.
model = torch.compile(model)
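A rough sketch of both points - the "max-autotune" mode value and the _orig_mod attribute reflect recent PyTorch versions, so double-check against yours, and the checkpoint filename is just an example:

import torch

# Spend extra compile time searching for the fastest kernels
model = torch.compile(model, mode="max-autotune")

# The compiled wrapper prefixes state_dict keys with "_orig_mod.",
# so save the underlying module's weights instead
torch.save(model._orig_mod.state_dict(), "model_weights.pt")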
Nice numbers for tensor blocks
[Runtime] [Memory] Try to stick to multiples of 8 (or even better, 64) so that the GPU kernels don't have to handle the edge cases of odd leftover pieces. Seriously, just walk through your model config and look for any numbers that aren't multiples of 8.
Karpathy showed how padding the GPT-2 vocabulary size from 50257 to 50304 improved performance by about 4%, despite adding a few extra parameters.
One sneaky source of odd numbers is an uneven final batch from the dataloader. Avoid this by dropping the last batch - a tiny final batch can also lead to bad gradient updates.
loader = torch.utils.data.DataLoader(dataset, drop_last=True)
Fused adam optimizer
[Runtime] If you're using Adam/AdamW, you can enable fusing of the optimizer step (which still isn't the default).
optim = torch.optim.AdamW(model.parameters(), fused=True)
Reduced precision
[Runtime] [Memory] If you have a GPU with Ampere architecture or newer (A100, H100, 4090, etc.), use the bfloat16 datatype. bfloat16 is a half-precision format that doesn't require special gradient scaling and is easy to add to your forward pass and loss calculation:
# Enables autocasting for the forward pass (model + loss)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    output = model(input)
    loss = loss_fn(output, target)
If you have an older GPU, you can still use half precision (float16), but it gets trickier with gradient scalers and special clipping. Read about it here: https://pytorch.org/docs/stable/amp.html
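A minimal float16 sketch with a gradient scaler, roughly following those docs (it assumes the usual model/optimizer/loader setup, and the clipping threshold is arbitrary):

import torch

scaler = torch.cuda.amp.GradScaler()

for input, target in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        output = model(input)
        loss = loss_fn(output, target)
    # Scale the loss so float16 gradients don't underflow
    scaler.scale(loss).backward()
    # Unscale before clipping so the threshold is in real units
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()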
If you really want float32 precision and you have an Ampere or newer GPU, enable TensorFloat-32 (TF32) matmul and convolution operations. This gives near-float32 results, but much faster.
# The flag below controls whether to allow TF32 on matmul.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN.
torch.backends.cudnn.allow_tf32 = True
You can also just set matmuls to take advantage of reduced precision. This is equivalent to the torch.backends.cuda.matmul.allow_tf32 = True line above.
torch.set_float32_matmul_precision('high')
Inference tricks
There are lots of tricks for fast inference that don't apply to training, but most go beyond the scope of this article. However, if you do model inference in-the-loop alongside training, they're often very helpful! You can apply them to evaluation, or to other setups, like reinforcement learning, that use model inference online.
Other resources include Lilian Weng's blogpost on inference optimization, which goes into some more detail, and this blogpost from character.ai, which has some nice tricks for inference optimization of LLMs.
Inference mode
[Runtime] I often forget the more recent addition of torch.inference_mode() to torch - it's like torch.no_grad(), but it improves performance with some extra small optimizations.
with torch.inference_mode():
    output = model(inputs)
Model quantization
[Runtime] [Memory] Quantization is quite a meaty topic on its own, but it can provide drastic memory reductions and speed boosts at the cost of some accuracy. If you plan on quantizing the model down the line then, depending on the method, you may want to use quantization-aware training so that the weights fit better into the quantized dtype.
Some methods include GPTQ, AWQ, or HQQ.
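Those methods come with their own libraries, but as a minimal taste of the idea, PyTorch's built-in post-training dynamic quantization is a one-liner (a sketch; note this particular flavour targets CPU inference and only quantizes the layer types you list):

import torch

# Quantize the weights of all Linear layers to int8 after training
quantized_model = torch.ao.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)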
KV-caching
[Runtime] When generating tokens auto-regressively from a transformer model, we recalculate the attention keys and values for all the previous tokens again and again. Instead, cache those intermediate keys and values and only append the new K and V vectors for the latest token.
In llama-3, they cache the values in the attention block:
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        # Module setup
        # ...

        # KV Cache Setup
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        # Calculate new q, k, v vectors
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        # Update KV cache
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # Rest of self attention calculation
        # ...
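For context, here's a rough sketch of how such a cache gets used at generation time - only the newest tokens are fed through the model, since earlier keys and values are already cached. This is a hypothetical greedy loop; prompt_tokens and max_seq_len are placeholders, and the model signature follows the llama-3 code above:

import torch

tokens = prompt_tokens  # (bsz, prompt_len) tensor of token ids
prev_pos = 0
for cur_pos in range(prompt_tokens.shape[1], max_seq_len):
    # Only pass the tokens the cache hasn't seen yet
    logits = model(tokens[:, prev_pos:cur_pos], start_pos=prev_pos)
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    tokens = torch.cat([tokens, next_token], dim=1)
    prev_pos = cur_pos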
Data optimizations
You've been profiling this whole time right? And you only optimized the model because it's the bottleneck right? Good!
Modern GPUs are crazy fast, and it's quite hard to feed them enough data to keep them computing at maximum throughput. So if you look at your profiling graphs, you may realize that far more time is spent on data loading, transforming, and general handling than on the optimization itself.
More data loading workers
[Runtime] Yep, pretty simple. If your dataloader is too slow, load multiple samples in parallel.
loader = torch.utils.data.DataLoader(dataset, num_workers=os.cpu_count())
Pinned memory
[Runtime] In a similar vein, you can pin the dataloader's memory so that batches can be transferred to the GPU much faster. You can even move data to the GPU asynchronously.
loader = torch.utils.data.DataLoader(dataset, pin_memory=True)
for batch in loader:
    batch = batch.to('cuda', non_blocking=True)
The docs do warn though:
If you overuse pinned memory, it can cause serious problems when running low on RAM, and you should be aware that pinning is often an expensive operation.
Faster data transforms
[Runtime] If you're augmenting your dataset with random transforms, these can often take a lot of time relative to a single gradient step. You can write faster transform code, apply transforms to whole batches at once, or move them onto the GPU.
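As a small sketch of the idea (assuming simple image normalization is the expensive transform), apply it to whole batches on the GPU instead of per-sample on the CPU:

import torch

mean = torch.tensor([0.485, 0.456, 0.406], device="cuda").view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225], device="cuda").view(1, 3, 1, 1)

for batch in loader:
    batch = batch.to("cuda", non_blocking=True)
    # One vectorized normalization for the whole batch on the GPU
    batch = (batch - mean) / std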
Store your data in the training format
[Runtime] If you run any preprocessing on your data, whether tokenization, type conversion, or normalization, save the data in the easiest-to-consume format for training.
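For example, a sketch in the spirit of NanoGPT's data prep: tokenize once, dump the token ids to a flat binary file, and memory-map it at training time (tokenize, raw_text, block_size, and batch_size are placeholders):

import numpy as np
import torch

# One-off preprocessing: tokenize and dump to a flat uint16 file
ids = np.array(tokenize(raw_text), dtype=np.uint16)
ids.tofile("train.bin")

# Training time: memory-map the file and slice batches straight out of it
data = np.memmap("train.bin", dtype=np.uint16, mode="r")
ix = np.random.randint(0, len(data) - block_size, size=(batch_size,))
x = torch.stack([torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in ix])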
Multiple train steps
[Runtime] If your algorithm allows it, do multiple passes over your data, shuffling the mini-batches. You may need some more regularization or some algorithmic adjustments, but it can be a good sample-efficiency improvement - especially if you have little data or data generation is slow.
Faster data generation
[Runtime] In any application where you generate data on-the-fly, like environment rollouts in reinforcement learning, the data generation is often much slower than training because of its slow, sequential sampling. Again, profiling is your friend here, but eventually you may want to rewrite the generation in a faster language, like C++, or parallelize either individual computations or the entire thing. This is quite painful to do, but may be your last option.
100M+ Models (1 GPU)
Gradient Accumulation
[Runtime] [Memory] If you're running low on memory but still want a bigger batch size, you can accumulate the gradients of multiple mini-batches into one update. Runtime will take a hit, since each mini-batch's backward pass runs sequentially, but the resulting gradient update is equivalent to that of the larger batch size.
for micro_step in range(gradient_accumulation_steps):
    loss = model(X, Y)
    # scale the loss to account for gradient accumulation
    loss = loss / gradient_accumulation_steps
    loss.backward()
# step the optimizer after accumulating the gradients
optimizer.step()
# flush the gradients as soon as we can, no need for this memory anymore
optimizer.zero_grad(set_to_none=True)
Custom GPU kernels
[Runtime] [Memory] Custom GPU kernels can push performance beyond what torch.compile can manage; however, they take much more effort to implement. I'd start with prebuilt options before writing your own, like automatic flash attention:
torch.nn.functional.scaled_dot_product_attention(query, key, value)
Unsloth has triton kernels for RoPE, cross entropy loss, RMSNorm, LoRA and more.
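For context, this is roughly the manual attention that the fused call above replaces (a sketch; tensors assumed to be shaped (batch, heads, seq, head_dim), with no mask or dropout):

import math
import torch.nn.functional as F

# Manual attention: materializes the full (seq x seq) score matrix in memory
scores = (query @ key.transpose(-2, -1)) / math.sqrt(query.shape[-1])
out_slow = scores.softmax(dim=-1) @ value

# Fused kernel: picks a flash / memory-efficient implementation when it can
out_fast = F.scaled_dot_product_attention(query, key, value)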
1B+ Models (1 node, multiple GPUs)
Distributed Data Parallel (DDP)
[Memory] DDP puts a copy of your model on every available GPU and synchronizes (averages) the gradients across all the model copies. At a minimum, your model must fit on a single GPU with a batch size of 1.
You will also need to launch with torchrun, and disable gradient syncing during intermediate gradient accumulation steps (see the sketch below the code).
model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[ddp_local_rank]
)
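One way to disable that syncing on the intermediate accumulation steps is DDP's no_sync() context manager. A sketch, reusing the gradient accumulation loop from earlier:

for micro_step in range(gradient_accumulation_steps):
    loss = model(X, Y) / gradient_accumulation_steps
    if micro_step < gradient_accumulation_steps - 1:
        # Skip the gradient all-reduce on intermediate micro-steps
        with model.no_sync():
            loss.backward()
    else:
        # Sync gradients across GPUs only on the final micro-step
        loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)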
Low rank finetuning (LoRA, QLoRA, DoRA)
[Runtime] [Memory] If you're only finetuning a pretrained model, you can often get away with training only a small percentage of the weights. For example, you might finetune only a new linear "head" for a downstream task.
LoRA adds a trainable low-rank update to the original, frozen weights for finetuning. QLoRA additionally quantizes those frozen base weights (e.g. to 4-bit) to cut memory during finetuning. DoRA improves on LoRA by decomposing the weight update into separate magnitude and direction components for better optimization.
Answer.ai managed to finetune a 70B llama-3 model on only two 24GB 3090 GPUs with these methods.
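A minimal sketch with the Hugging Face peft library (one popular LoRA implementation, not the only one; the target module names depend on your model's layer naming):

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,                                 # rank of the low-rank update
    lora_alpha=32,                        # scaling factor
    target_modules=["q_proj", "v_proj"],  # which linear layers get adapters
    lora_dropout=0.05,
)
model = get_peft_model(model, config)
model.print_trainable_parameters()  # usually well under 1% of the weights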
10B+ Models (multiple nodes, multiple GPUs each)
When it comes to very large models, most optimizations aim to reduce the memory footprint just so that we can do any training in the first place.
Efficient optimizers
[Runtime] [Memory] If you're willing to stray from The Best Optimizer™, Adam, you can save on those awfully painful storage costs for the first and second moments. Look at ZeRO, AdaFactor, or Lion - a few of these are also implemented in reduced precision in bitsandbytes too.
Then again, you may just want to try the reduced-precision 8-bit Adam optimizer instead.
import bitsandbytes as bnb
adam = bnb.optim.Adam8bit(...)
# Use 8-bit stable embedding layer too
bnb.nn.StableEmbedding(...)
Fully Sharded Data Parallel (FSDP)
[Runtime] [Memory] When a single batch doesn't fit on your GPU, you'll need to split the model itself across multiple GPUs. This adds quite a bit of communication overhead between GPUs, since one GPU's input depends on another's output, but we can mitigate it by pipelining the implementation well.
FSDP shards the model's parameters, gradients, and optimizer states across GPUs, and it can be combined with parallelism along other axes: splitting parts of the model onto different devices (model parallel), splitting a layer's tensors onto multiple devices (tensor parallel), and pipelining it all together for efficiency (pipeline parallel). Lilian Weng's blogpost How to Train Really Large Models on Many GPUs goes into much more detail on each method.
There are lots of options and combinations here, so check out this full FSDP training example in 180 LOC: https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/fsdp_tp_example.py
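The basic wrapping itself is short. A sketch, assuming the process group has already been initialized under torchrun and using a size-based auto-wrap policy with an arbitrary threshold:

import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

model = FSDP(
    model,
    # Shard any submodule with more than ~100M parameters as its own unit
    auto_wrap_policy=functools.partial(
        size_based_auto_wrap_policy, min_num_params=100_000_000
    ),
    device_id=torch.cuda.current_device(),
)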
CPU offloading
[Runtime] [Memory] Offload parts of the model that aren't being used in the current computation to the CPU. This can drastically increase the model size your GPU can handle, but it comes with a big communication overhead.
Currently in FSDP, CPU offloading puts all parameters and gradients on the CPU when they're not in use.
from torch.distributed.fsdp import CPUOffload

model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))
Activation/gradient checkpointing
[Runtime] [Memory] Instead of storing all of the model's activations for the backward pass, we can store only a subset, checkpointing every so often (every 4 layers is a common choice). Then, when we need to calculate the gradients, we recompute the missing activations from the nearest checkpoint.
You can checkpoint any module's state quite easily, but try to checkpoint at stages where the activations are fast to recompute and (ideally) not too large in memory.
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(*blocks)  # blocks: a list of nn.Module layers
# Checkpoint each block as its own segment
output = checkpoint_sequential(model, len(blocks), inputs)
Deepspeed
[Runtime] [Memory] Deepspeed is an FSDP alternative from Microsoft that supports partial offloading, sparse attention, 1-bit optimizers, and a bunch of other stuff. It takes care of a lot of the details, but it can be a bit opaque about how it all fits together.
model, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    args=args,
)
# Loss calculation
loss = model(inputs)
# Backward pass
model.backward(loss)
model.step()
Examples
- NanoGPT: a single file training script for GPT-2 with most of these optimizations.
- Finetuning llama-3 70B on 2 3090 GPUs with FSDP and QLoRA by Answer.ai
- Full FSDP training pipeline example in 180 LOC
- Training a 1 trillion parameter model with FSDP from the PyTorch blog
- PyTorch tuning guide
- Deepspeed megatron tutorial
Future Additions
Things are changing fast in AI. If there is something new, or something I missed, please get in touch! My email is on the home page.