When optimizing deep learning models in PyTorch, two powerful tools come up frequently: vmap and checkpointing.
vmap enables vectorized computation, making batch operations cleaner and more efficient without manually handling batch dimensions. Checkpointing, provided by torch.utils.checkpoint.checkpoint, trades compute for memory: intermediate activations are discarded during the forward pass and recomputed during the backward pass.
But combining the two leads to a frustrating roadblock: “Saved Tensor Hook Errors.” If you’ve hit this issue while trying to train a large model efficiently, you’re not alone.
Let’s break down why this happens, how to reproduce it, and practical alternatives to work around it.
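Before diving into the failure, here is a minimal sketch of each tool used on its own (the small Linear layer and shapes are placeholders, not a real model); each one works fine in isolation:
import torch
from torch.func import vmap
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(4, 4)
x = torch.randn(8, 4)

# vmap: apply a per-sample function across the leading batch dimension.
per_sample_out = vmap(lambda sample: layer(sample).relu())(x)

# checkpoint: drop intermediate activations in forward and recompute them in backward.
loss = checkpoint(lambda inp: layer(inp).relu().sum(), x, use_reentrant=False)
loss.backward()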
Understanding the “Saved Tensor Hook Errors”
Running vmap over a function that uses checkpoint triggers cryptic errors like:
RuntimeError: vmap: We do not yet support calling functionalize() on operations that involve Saved Tensor Hooks.
Why does this happen? It boils down to how Saved Tensor Hooks work in PyTorch.
What Are Saved Tensor Hooks?
Checkpointing relies on saved tensor hooks (torch.autograd.graph.saved_tensors_hooks) to control how intermediate activations are stored and retrieved for backpropagation. These hooks plug into autograd and change what happens when a tensor is saved for, or used during, the backward pass. Here is a simpler gradient hook that illustrates the general idea of intercepting autograd:
import torch

def hook_fn(grad):
    print("Hook activated!")
    return grad * 2

x = torch.tensor(2.0, requires_grad=True)
x.register_hook(hook_fn)

y = x ** 2
y.backward()
# Output:
# Hook activated!
The hook doubles the gradient during the backward pass.
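While register_hook above intercepts gradients, checkpointing relies on the pack/unpack pair of saved tensor hooks, which fire whenever autograd stores a tensor for the backward pass. A minimal sketch using torch.autograd.graph.saved_tensors_hooks (the print statements only show when the hooks run):
import torch

def pack(tensor):
    print("packing a tensor of shape", tuple(tensor.shape))
    return tensor  # checkpointing would instead record how to recompute it later

def unpack(packed):
    print("unpacking for the backward pass")
    return packed

x = torch.randn(3, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = (x * x).sum()  # x is saved for backward, so pack() fires here
y.backward()           # unpack() fires when the saved value is needed
The non-reentrant checkpoint implementation installs hooks of exactly this kind, and they are what vmap trips over.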
Feature | Role in Checkpointing |
---|---|
Saved Tensor Hooks | Store activations and re-materialize them when needed |
vmap | Applies functions batch-wise while preserving function behavior |
torch.utils.checkpoint.checkpoint | Trades storing activations for recomputation |
The problem? vmap transforms break when encountering these hooks.
Error Breakdown: “_NoopSaveInputs Issue”
A key error message when combining checkpointing with vmap is:
RuntimeError: attempting to vmap over _NoopSaveInputs, but it has no batching rule.
Here, _NoopSaveInputs is an internal autograd function that non-reentrant checkpointing uses to capture the inputs it needs for recomputation. Unfortunately, vmap has no batching rule registered for it, leading to failure.
Error Breakdown: “torch.func transforms don’t yet support saved tensor hooks”
Another common failure shows up when generate_vmap_rule = True is in play. Note that this is a class attribute of custom torch.autograd.Function subclasses that asks torch.func to derive a batching rule automatically; it is not an argument that checkpoint accepts:
RuntimeError: torch.func transforms don't yet support saved tensor hooks
This happens because function transforms (like vmap) operate on a functionalized version of the model, while checkpoint hooks rely on stored tensor state, and PyTorch does not currently bridge the two mechanisms.
Reproducing the Error
Let’s go hands-on with an example that triggers the issue.
Setting Up PyTorch
Ensure you have PyTorch installed:
pip install torch
Some errors are version-dependent, so using PyTorch >= 2.0 is recommended.
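A quick way to confirm which version you are on (the exact error text can differ between releases):
import torch
print(torch.__version__)  # torch.func requires a 2.x release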
Minimal Reproducible Example
import torch
from torch.utils.checkpoint import checkpoint
from torch.func import vmap

class SimpleTransformer(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.linear(x).relu()

def compute_loss(model, x):
    return model(x).sum()

model = SimpleTransformer(128)
x = torch.randn(10, 128)

# Apply vmap over a checkpointed forward pass.
# use_reentrant=False selects the non-reentrant implementation, which relies on
# saved tensor hooks and therefore clashes with vmap.
batched_loss = vmap(lambda sample: checkpoint(compute_loss, model, sample, use_reentrant=False))(x)
Expected output: this fails with one of the saved tensor hook errors shown above rather than producing a batched loss.
Why Simple Fixes Don’t Work
Approach 1: Using generate_vmap_rule = True
A tempting fix is to pass generate_vmap_rule=True to checkpoint, e.g. checkpoint(compute_loss, model, x, generate_vmap_rule=True). This is not valid: checkpoint has no such parameter; generate_vmap_rule is a class attribute of custom torch.autograd.Function subclasses that asks torch.func to derive a batching rule automatically. Even wiring it up correctly (see the sketch below) still fails, because PyTorch's function transforms do not support saved tensor hooks.
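Here is a minimal sketch of that correct wiring, assuming the torch.func-style custom Function API (a forward without ctx plus a setup_context staticmethod); the CheckpointedLoss name and the placeholder backward are illustrative, and model, compute_loss, and x are reused from the reproducer above:
class CheckpointedLoss(torch.autograd.Function):
    # Class attribute that asks torch.func to auto-generate a batching rule.
    generate_vmap_rule = True

    @staticmethod
    def forward(sample):
        # The non-reentrant checkpoint installs saved tensor hooks here.
        return checkpoint(compute_loss, model, sample, use_reentrant=False)

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing is saved; this sketch only exists to show the failure

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # placeholder; a real implementation would recompute here

# Expected to fail with the saved tensor hook error on current releases.
batched_loss = vmap(CheckpointedLoss.apply)(x)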
Approach 2: Defining a Custom vmap Rule
Technically, a custom batching rule could be written, but this is highly complex and requires deep knowledge of PyTorch internals. There’s no official support for this yet.
Workarounds That Actually Work
Option 1: Removing Checkpointing (Not Recommended)
Disabling checkpointing eliminates the issue:
batched_loss = vmap(lambda sample: compute_loss(model, sample))(x)  # no checkpoint, so no hooks to trip over
However, this increases memory usage significantly, which defeats the point of using checkpointing in the first place.
Option 2: Manual Gradient Accumulation
A better workaround is gradient accumulation, which sidesteps vmap while preserving checkpoint efficiency.
batch_size = 5
num_chunks = x.shape[0] // batch_size
loss_sum = 0.0

for i in range(num_chunks):
    loss = compute_loss(model, x[i * batch_size : (i + 1) * batch_size])
    (loss / batch_size).backward()  # gradients accumulate across chunks in .grad
    loss_sum += loss.item()
This avoids vmap breaking the saved tensor hooks.
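Because vmap is out of the picture, the per-chunk forward pass is free to keep using checkpointing. A sketch under that assumption (compute_loss_ckpt is a hypothetical helper; model, x, batch_size, and num_chunks come from the snippet above):
from torch.utils.checkpoint import checkpoint

def compute_loss_ckpt(model, chunk):
    # Checkpoint the chunk's forward pass; activations are recomputed during backward.
    return checkpoint(model, chunk, use_reentrant=False).sum()

for i in range(num_chunks):
    loss = compute_loss_ckpt(model, x[i * batch_size : (i + 1) * batch_size])
    (loss / batch_size).backward()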
Option 3: Custom Autograd Function (Advanced)
For experts, writing a custom autograd function can sometimes allow finer control.
class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, model, x):
        ctx.model = model
        ctx.save_for_backward(x)
        with torch.no_grad():
            return model(x)  # no autograd graph is built here

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        with torch.enable_grad():
            x_detached = x.detach().requires_grad_(True)
            output = ctx.model(x_detached)  # recompute the forward pass
        grad_input, = torch.autograd.grad(output, x_detached, grad_output)
        # No gradient for the model argument; parameter gradients are omitted here.
        return None, grad_input
This approach is tricky to get right and adds debugging complexity.
Option 4: Using torch.compile
If you are on PyTorch 2.0+, torch.compile can optimize execution instead of vmap:
torch.compile(model)(x)
It fuses and optimizes the ordinary batched execution path, which can remove the need for manual batching with vmap altogether (though it does not provide vmap's per-sample semantics, such as per-sample gradients).
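For completeness, a hedged sketch of combining torch.compile with non-reentrant checkpointing (the Sequential model and shapes are placeholders, and depending on the PyTorch version the checkpointed region may be captured by the compiler or simply cause a graph break):
import torch
from torch.utils.checkpoint import checkpoint

model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.ReLU())
x = torch.randn(10, 128)

def forward_with_ckpt(inp):
    # Non-reentrant checkpointing inside the compiled function.
    return checkpoint(model, inp, use_reentrant=False).sum()

compiled = torch.compile(forward_with_ckpt)
loss = compiled(x)
loss.backward()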
Implementation: Gradient Accumulation Example
For a complete working solution that avoids the issue, try:
import torch

def compute_loss(model, x):
    return model(x).sum()

model = SimpleTransformer(128)  # reuses the class defined in the reproducer above
x = torch.randn(10, 128)

# Manual gradient accumulation
batch_size = 5
num_chunks = x.shape[0] // batch_size
optimizer = torch.optim.Adam(model.parameters())

optimizer.zero_grad()
for i in range(num_chunks):
    loss = compute_loss(model, x[i * batch_size : (i + 1) * batch_size])
    (loss / batch_size).backward()  # gradients accumulate across chunks
optimizer.step()  # single update after all chunks have contributed
This method ensures compatibility without disrupting checkpointing.
Closing Remarks
The incompatibility between vmap and checkpointing in PyTorch stems from the way saved tensor hooks interact with function transformations.
While there’s no one-size-fits-all fix, solutions like manual batching with gradient accumulation provide practical paths forward.
For those working with cutting-edge PyTorch workflows, tracking open issues and community discussions can help. Hopefully, future versions will bring native vmap support for checkpointing.
Got a better workaround? Share in the discussion forums! 🚀
Appendix: Solution Comparison
Solution | Memory Efficient | Performance | Complexity |
---|---|---|---|
Remove Checkpoint | ❌ No | ✅ Fast | ✅ Easy |
Gradient Accumulation | ✅ Yes | ⚖️ Balanced | ✅ Easy |
Custom Autograd | ✅ Yes | ✅ Fast | ❌ Hard |
torch.compile | ✅ Yes | ✅ Fast | ⚖️ Medium |