PyTorch Layers to MAX Mapping Guide
- Authors: Brad Larson and Claude
- Date: June 23, 2025
Introduction
This guide provides mappings between common PyTorch layers used in Hugging Face
transformers and their equivalent MAX graph operations and layer
abstractions.
Table of Contents
- Overview
- Core Layer Mappings
- Graph Operations Mapping
- Implementation Examples
- Performance Optimization Tips
Overview
MAX provides two levels of abstraction for building neural networks:
- High-level layers (`max.nn`): PyTorch-compatible layer abstractions
- Low-level graph operations (`max.graph.ops`): Fine-grained tensor operations
Key Differences from PyTorch
- MAX uses explicit device placement
- Supports advanced quantization (Float8, GPTQ)
- Provides distributed/sharded variants of common layers
- Offers hardware-optimized kernels for specific operations
- MAX builds, compiles, and then executes explicit graphs, unlike PyTorch's default eager execution
Core Layer Mappings
1. Linear Layers
| HuggingFace/PyTorch | MAX Layer | MAX Graph Op | Notes |
|---|---|---|---|
| `nn.Linear` | `max.nn.Linear` | `ops.matmul` + `ops.add` | MAX supports quantization options |
| `nn.Linear` (no bias) | `max.nn.Linear(has_bias=False)` | `ops.matmul` | Use the `has_bias=False` parameter |
| Column Parallel Linear | `max.nn.ColumnParallelLinear` | - | For tensor parallelism |
| GPTQ Quantized Linear | `max.nn.GPTQLinear` | - | GPTQ quantization support |
Example:
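A minimal sketch of the `ops.matmul` + `ops.add` composition behind `nn.Linear`. The weight and bias arrive as graph inputs so the example stays self-contained; the `Graph`/`TensorType` argument names follow the MAX API docs linked below but should be treated as assumptions, and `max.nn.Linear` wraps this same pattern at the layer level.

```python
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

# y = x @ W + b, the graph-op equivalent of nn.Linear(512, 256).
x_type = TensorType(DType.float32, ["batch", 512], device=DeviceRef.CPU())
w_type = TensorType(DType.float32, [512, 256], device=DeviceRef.CPU())
b_type = TensorType(DType.float32, [256], device=DeviceRef.CPU())

with Graph("linear_example", input_types=(x_type, w_type, b_type)) as graph:
    x, weight, bias = graph.inputs
    # Drop the ops.add for the has_bias=False variant.
    graph.output(ops.add(ops.matmul(x, weight), bias))
```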
2. Embedding Layers
| HuggingFace/PyTorch | MAX Layer | MAX Graph Op | Notes |
|---|---|---|---|
| `nn.Embedding` | `max.nn.Embedding` | `ops.gather` | Token embedding lookup |
| Vocab Parallel Embedding | `max.nn.VocabParallelEmbedding` | - | For distributed vocabularies |
Example:
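A minimal sketch of the gather-based lookup behind `nn.Embedding` / `max.nn.Embedding`. The embedding table is passed in as a graph input, and the argument order of `ops.gather` is an assumption to verify against the ops reference.

```python
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

# nn.Embedding as ops.gather: look up one row of the table per token id.
ids_type = TensorType(DType.int64, ["batch", "seq"], device=DeviceRef.CPU())
table_type = TensorType(DType.float32, [32000, 512], device=DeviceRef.CPU())  # vocab x hidden

with Graph("embedding_example", input_types=(ids_type, table_type)) as graph:
    token_ids, table = graph.inputs
    graph.output(ops.gather(table, token_ids, axis=0))
```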
3. Normalization Layers
| HuggingFace/PyTorch | MAX Layer | MAX Graph Op | Notes |
|---|---|---|---|
| `nn.LayerNorm` | `max.nn.LayerNorm` | `ops.layer_norm` | Epsilon parameter available |
| RMSNorm (custom) | `max.nn.RMSNorm` | Custom implementation | Used in Llama, Gemma |
| `nn.GroupNorm` | `max.nn.GroupNorm` | Custom implementation | Group-wise normalization |
| Distributed RMSNorm | `max.nn.DistributedRMSNorm` | - | For tensor parallelism |
Example:
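A minimal layer-norm sketch using `ops.layer_norm`; gamma and beta arrive as graph inputs, and the argument order and `epsilon` keyword are assumptions. `max.nn.LayerNorm` and `max.nn.RMSNorm` wrap this at the layer level.

```python
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

# LayerNorm over the last axis with learned scale (gamma) and shift (beta).
x_type = TensorType(DType.float32, ["batch", "seq", 512], device=DeviceRef.CPU())
g_type = TensorType(DType.float32, [512], device=DeviceRef.CPU())
b_type = TensorType(DType.float32, [512], device=DeviceRef.CPU())

with Graph("layer_norm_example", input_types=(x_type, g_type, b_type)) as graph:
    x, gamma, beta = graph.inputs
    graph.output(ops.layer_norm(x, gamma, beta, epsilon=1e-5))
```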
4. Attention Mechanisms
| HuggingFace/PyTorch | MAX Layer | MAX Graph Op | Notes |
|---|---|---|---|
| `nn.MultiheadAttention` | `max.nn.MultiheadAttention` | Multiple ops | Full attention implementation |
| Attention with RoPE | `max.nn.AttentionWithRope` | - | Rotary position embeddings |
| Distributed Attention | `max.nn.TensorParallelAttentionWithRope` | - | Multi-GPU attention |
| Quantized Attention | `max.nn.GPTQAttentionWithRope` | - | GPTQ quantized attention |
Attention Implementation with Graph Ops:
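A single-head scaled dot-product attention sketch built from the ops mapped in this guide (`ops.matmul`, `ops.transpose`, `ops.softmax`). Masking and head splitting are deferred to the multi-head example later in this guide; scalar promotion in `ops.div` is an assumption.

```python
import math

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

# attention(Q, K, V) = softmax(Q K^T / sqrt(d)) V, one head, no mask.
HEAD_DIM = 64
qkv_type = TensorType(DType.float32, ["batch", "seq", HEAD_DIM], device=DeviceRef.CPU())

with Graph("sdpa_example", input_types=(qkv_type, qkv_type, qkv_type)) as graph:
    q, k, v = graph.inputs
    scores = ops.matmul(q, ops.transpose(k, 1, 2))
    # The scalar divisor is assumed to be promoted to a constant by the ops API.
    weights = ops.softmax(ops.div(scores, math.sqrt(HEAD_DIM)))
    graph.output(ops.matmul(weights, v))
```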
5. Activation Functions
| HuggingFace/PyTorch | MAX Layer | MAX Graph Op | Notes |
|---|---|---|---|
| `F.gelu` | - | `ops.gelu` | Supports approximation modes |
| `F.silu` / SwiGLU | - | `ops.silu` | Sigmoid Linear Unit |
| `F.sigmoid` | - | `ops.sigmoid` | Sigmoid activation |
| `F.tanh` | - | `ops.tanh` | Hyperbolic tangent |
| `F.relu` | - | `ops.maximum(x, 0)` | ReLU via maximum |
Example:
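A short sketch of the SwiGLU gating used in Llama-style MLPs, combining `ops.silu` with `ops.mul` from the table above; shapes are illustrative.

```python
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

# SwiGLU gating: silu(gate) * up, the F.silu-based pattern from the table.
h_type = TensorType(DType.float32, ["batch", "seq", 2048], device=DeviceRef.CPU())

with Graph("swiglu_example", input_types=(h_type, h_type)) as graph:
    gate, up = graph.inputs
    graph.output(ops.mul(ops.silu(gate), up))
```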
6. Positional Embeddings
| HuggingFace/PyTorch | MAX Layer | MAX Graph Op | Notes |
|---|---|---|---|
| Rotary Embeddings | `max.nn.RotaryEmbedding` | Custom ops | RoPE implementation |
| Sinusoidal PE | - | `ops.sin`, `ops.cos` | Build with trig ops |
| Learnable PE | `max.nn.Embedding` | - | Use embedding layer |
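For the sinusoidal case, a minimal sketch is below. The angle matrix (pos / 10000^(2i/d)) is passed in precomputed to keep the example short, and the `axis` keyword of `ops.concat` is an assumption.

```python
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

# Sinusoidal positional encodings from ops.sin / ops.cos.
angle_type = TensorType(DType.float32, ["seq", 256], device=DeviceRef.CPU())

with Graph("sinusoidal_pe_example", input_types=(angle_type,)) as graph:
    angles = graph.inputs[0]
    # Interleaving is skipped; sin/cos halves are concatenated along the feature axis.
    graph.output(ops.concat([ops.sin(angles), ops.cos(angles)], axis=-1))
```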
7. Pooling and Reduction
| HuggingFace/PyTorch | MAX Layer | MAX Graph Op | Notes |
|---|---|---|---|
| `F.adaptive_avg_pool1d` | - | `ops.mean` | Use with appropriate axis |
| `torch.mean` | - | `ops.mean` | Reduction operation |
| `torch.max` | - | `ops.max` | Maximum reduction |
| `torch.sum` | - | `ops.sum` | Sum reduction |
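A short sketch of mean pooling over the sequence axis, the most common use of `ops.mean` when porting `torch.mean` or adaptive average pooling; the `axis` keyword is an assumption.

```python
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

# Mean-pool token representations over the sequence axis (axis=1).
x_type = TensorType(DType.float32, ["batch", "seq", 512], device=DeviceRef.CPU())

with Graph("mean_pool_example", input_types=(x_type,)) as graph:
    hidden = graph.inputs[0]
    graph.output(ops.mean(hidden, axis=1))
```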
Graph Operations Mapping
Tensor Manipulation
| PyTorch Operation | MAX Graph Operation | Notes |
|---|---|---|
| `torch.reshape` | `ops.reshape` | Shape inference with -1 |
| `torch.transpose` | `ops.transpose` | Swap two dimensions |
| `torch.permute` | `ops.permute` | Reorder all dimensions |
| `torch.squeeze` | `ops.squeeze` | Remove dimensions of size 1 |
| `torch.unsqueeze` | `ops.unsqueeze` | Add dimension of size 1 |
| `torch.cat` | `ops.concat` | Concatenate along axis |
| `torch.stack` | `ops.stack` | Stack along new axis |
| `torch.split` | `ops.split` | Split into chunks |
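A sketch combining several of these ops: splitting a hidden state into attention heads with `ops.reshape` and `ops.permute`. Static shapes keep it self-contained, and the positional shape arguments are assumptions.

```python
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

# Split a (batch, seq, hidden) tensor into (batch, heads, seq, head_dim).
x_type = TensorType(DType.float32, [2, 16, 512], device=DeviceRef.CPU())

with Graph("head_split_example", input_types=(x_type,)) as graph:
    x = graph.inputs[0]
    heads = ops.reshape(x, [2, 16, 8, 64])     # hidden -> (heads, head_dim)
    heads = ops.permute(heads, [0, 2, 1, 3])   # move heads before seq
    graph.output(heads)
```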
Mathematical Operations
| PyTorch Operation | MAX Graph Operation | Notes |
|---|---|---|
| `@` / `torch.matmul` | `ops.matmul` | Matrix multiplication |
| `+` | `ops.add` | Element-wise addition |
| `-` | `ops.sub` | Element-wise subtraction |
| `*` | `ops.mul` | Element-wise multiplication |
| `/` | `ops.div` | Element-wise division |
| `torch.exp` | `ops.exp` | Exponential |
| `torch.log` | `ops.log` | Natural logarithm |
| `torch.sqrt` | `ops.sqrt` | Square root |
| `torch.pow` | `ops.pow` | Power operation |
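These element-wise ops compose the same way as their PyTorch counterparts. A small sketch computing the RMSNorm denominator, sqrt(mean(x * x) + eps), from the mappings above; the `axis` keyword on `ops.mean` and scalar promotion of the epsilon are assumptions.

```python
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

# sqrt(mean(x * x) + eps): the RMSNorm denominator, built from element-wise ops.
x_type = TensorType(DType.float32, ["batch", "seq", 512], device=DeviceRef.CPU())

with Graph("math_ops_example", input_types=(x_type,)) as graph:
    x = graph.inputs[0]
    mean_square = ops.mean(ops.mul(x, x), axis=-1)
    graph.output(ops.sqrt(ops.add(mean_square, 1e-6)))
```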
Indexing and Selection
| PyTorch Operation | MAX Graph Operation | Notes |
|---|---|---|
| `tensor[...]` | `ops.slice_tensor` | Advanced slicing |
| `torch.gather` | `ops.gather` | Gather along dimension |
| `torch.scatter` | `ops.scatter` | Scatter values |
| `torch.where` | `ops.where` | Conditional selection |
| `torch.topk` | `ops.top_k` | Top-k values and indices |
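The selection ops combine naturally for sampling-style post-processing. A small sketch of top-k filtering of next-token logits is below; the `(values, indices)` return convention and the `k`/`axis` keywords of `ops.top_k` are assumptions, so verify them against the ops reference.

```python
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

# Top-k filtering of logits with ops.top_k (the torch.topk mapping).
logits_type = TensorType(DType.float32, ["batch", 32000], device=DeviceRef.CPU())

with Graph("top_k_example", input_types=(logits_type,)) as graph:
    logits = graph.inputs[0]
    values, indices = ops.top_k(logits, k=50, axis=-1)
    graph.output(values, indices)
```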
Implementation Examples
1. Transformer Block in MAX
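A sketch of a pre-norm transformer block composed entirely from the graph ops mapped above (`layer_norm`, `matmul`, `softmax`, `silu`). Masking, KV caching, and multi-head splitting are omitted, and the weight names in the `w` dictionary are hypothetical; the `max.nn` attention and norm layers from the tables are the higher-level route.

```python
import math

from max.graph import ops

def transformer_block(x, w: dict, head_dim: int = 64, eps: float = 1e-5):
    """Pre-norm block: x + SelfAttention(LN(x)), then x + SwiGLU(LN(x)).

    `w` maps weight names to values already in the graph; shapes follow the
    usual (hidden, hidden) / (hidden, ffn) conventions and are assumptions.
    """
    # --- attention sub-block with residual connection (single head, no mask) ---
    h = ops.layer_norm(x, w["ln1_gamma"], w["ln1_beta"], epsilon=eps)
    q = ops.matmul(h, w["wq"])
    k = ops.matmul(h, w["wk"])
    v = ops.matmul(h, w["wv"])
    # Scalar operands are assumed to be promoted to constants by the ops API.
    scores = ops.softmax(ops.div(ops.matmul(q, ops.transpose(k, 1, 2)),
                                 math.sqrt(head_dim)))
    x = ops.add(x, ops.matmul(ops.matmul(scores, v), w["wo"]))
    # --- SwiGLU feed-forward sub-block with residual connection ---
    h = ops.layer_norm(x, w["ln2_gamma"], w["ln2_beta"], epsilon=eps)
    h = ops.mul(ops.silu(ops.matmul(h, w["w_gate"])), ops.matmul(h, w["w_up"]))
    return ops.add(x, ops.matmul(h, w["w_down"]))
```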
2. Multi-Head Attention with Graph Ops
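A multi-head attention sketch using only graph ops: project, split heads with `reshape`/`permute`, run scaled dot-product attention per head, merge, and project out. Static batch/seq sizes keep the reshapes simple; masking and KV caching are omitted, and the function signature is illustrative rather than a MAX API.

```python
import math

from max.graph import ops

def multi_head_attention(x, wq, wk, wv, wo,
                         batch: int, seq: int,
                         num_heads: int = 8, head_dim: int = 64):
    """Multi-head self-attention composed from graph ops."""

    def split_heads(t):
        t = ops.reshape(t, [batch, seq, num_heads, head_dim])
        return ops.permute(t, [0, 2, 1, 3])           # batch, heads, seq, head_dim

    q = split_heads(ops.matmul(x, wq))
    k = split_heads(ops.matmul(x, wk))
    v = split_heads(ops.matmul(x, wv))

    scores = ops.matmul(q, ops.transpose(k, 2, 3))    # batch, heads, seq, seq
    # Scalar divisor assumed to be promoted to a constant by the ops API.
    weights = ops.softmax(ops.div(scores, math.sqrt(head_dim)))

    out = ops.matmul(weights, v)                      # batch, heads, seq, head_dim
    out = ops.permute(out, [0, 2, 1, 3])
    out = ops.reshape(out, [batch, seq, num_heads * head_dim])
    return ops.matmul(out, wo)
```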
3. Feed-Forward Network with Quantization
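A feed-forward sketch built from `max.nn.Linear` layers; per the linear-layer table, swapping the projections for `max.nn.GPTQLinear` gives the quantized variant. The dtype/device constructor arguments are assumptions (only `has_bias=False` is taken from the mapping above), so check the nn docs before relying on them.

```python
from max import nn
from max.dtype import DType
from max.graph import DeviceRef, ops

class SwiGLUFeedForward:
    """SwiGLU feed-forward block composed from max.nn.Linear layers.

    For quantization, the table above maps these projections onto
    max.nn.GPTQLinear; constructor arguments here are assumptions.
    """

    def __init__(self, hidden: int, ffn: int,
                 dtype=DType.float32, device=DeviceRef.CPU()):
        self.gate = nn.Linear(hidden, ffn, dtype, device, has_bias=False)
        self.up = nn.Linear(hidden, ffn, dtype, device, has_bias=False)
        self.down = nn.Linear(ffn, hidden, dtype, device, has_bias=False)

    def __call__(self, x):
        # silu(gate(x)) * up(x), projected back down to the hidden size.
        return self.down(ops.mul(ops.silu(self.gate(x)), self.up(x)))
```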
Performance Optimization Tips
1. Use Hardware-Specific Optimizations
2. Leverage Quantization
3. Use Fused Operations
4. Optimize Memory Layout
5. Batch Operations
6. Use Distributed Variants for Large Models
Common Patterns and Best Practices
1. Residual Connections
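Residual connections are plain `ops.add` around any sub-block, matching PyTorch's `x + block(x)`. A minimal sketch:

```python
from max.graph import ops

def residual(x, sublayer):
    """Residual connection: y = x + sublayer(x), for any attention or FFN callable."""
    return ops.add(x, sublayer(x))
```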
2. Attention Masking
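A minimal sketch of additive attention masking: masked positions carry a large negative bias so `ops.softmax` drives their weights to zero. The mask layout (0.0 for visible positions, roughly -1e9 for masked ones) is the usual convention and an assumption here.

```python
from max.graph import ops

def apply_attention_mask(scores, mask):
    """Add the mask to the raw scores, then normalize with softmax."""
    return ops.softmax(ops.add(scores, mask))
```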
3. Positional Encoding Integration
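Learned or sinusoidal positional encodings are typically added to the token embeddings before the first block; RoPE differs in that `max.nn.RotaryEmbedding` rotates queries and keys inside attention instead. A minimal sketch of the additive case:

```python
from max.graph import ops

def add_positional_encoding(token_embeddings, positional_encodings):
    """Element-wise sum of token and position embeddings (learned or sinusoidal)."""
    return ops.add(token_embeddings, positional_encodings)
```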
References
For the latest updates and additional operations, refer to:
- MAX Python API docs: https://docs.modular.com/max/api/python/
- MAX Graph Operations: https://docs.modular.com/max/graph/ops/
- MAX Neural Network Layers: https://docs.modular.com/max/api/python/nn/