Anatomy of Triton
Dissecting Triton and writing a small LLM purely in it
· 12 min read
Disclaimer #
In this post, I want to focus on Triton itself and won’t go into the details of threads, blocks, grids, or warps. If you’re not familiar with these concepts, I recommend reading this post first - and maybe checking out this visualisation too.
Fun part #
Why Triton? #
If you’ve ever written GPU kernels in CUDA, you know the drill: think in threads, compute thread IDs, divide work manually, fight with memory coalescing, and hope the compiler doesn’t betray you. This SIMT (Single Instruction, Multiple Threads) model is powerful, but also incredibly low-level.
Triton was created to change that. Instead of thinking in threads, Triton lets you think in blocks of data. It introduces a block-based programming paradigm where you write code that conceptually operates on a whole vector/tile at once, and the compiler takes care of generating efficient, coalesced, parallel GPU instructions.
In CUDA you think like this:
“I am thread 5. I will load index 5.”
In Triton, it becomes:
“I am a block of pointers. I will load a block of data.”
When you write something like:
x = tl.load(ptr + offsets)
you’re not loading a single element - you’re loading an entire tensor tile directly into GPU registers. Triton’s compiler then auto-vectorizes, auto-tiles, and arranges memory accesses to match modern GPU hardware. It feels like magic and is a very powerful tool, unless you need manual memory management or scheduling.
What is a kernel? #
So what are we actually running on these threads/blocks? A kernel is simply a function that executes on the GPU - but unlike a normal function, it runs many times in parallel, once per thread or block (depending on the programming model). You write the logic for “one unit of work”, and the GPU launches thousands of these units at once.
Every Triton operation needs TWO parts:
THE KERNEL (@triton.jit decorated function):
- Runs ON the GPU
- Defines the computation for ONE program instance
- Has special restrictions (can’t use regular Python features)
- Uses tl.* functions (triton.language)
THE WRAPPER (regular Python function):
- Runs ON the CPU
- Sets up memory and launches the kernel
- Defines the grid (how many parallel instances to run)
- Passes parameters to the kernel
So a very basic kernel could look like this:
import torch
import triton
import triton.language as tl

@triton.jit                        # 1. Decorator tells Triton to compile this
def my_kernel(
    input_ptr,                     # 2. Pointers to GPU memory
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr       # 3. Compile-time constants (capitalized)
):
    pid = tl.program_id(0)         # 4. Which parallel instance am I?

    # 5. Calculate which elements THIS instance processes
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # 6. Load data (with bounds checking - block size can be larger than the array)
    mask = offsets < n_elements
    x = tl.load(input_ptr + offsets, mask=mask)

    # 7. Do computation
    y = x * 2

    # 8. Store results
    tl.store(output_ptr + offsets, y, mask=mask)
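That covers the kernel half; the wrapper half described above is just a regular Python function. A minimal sketch (my_kernel_wrapper is a name I’m making up here, mirroring the GELU wrapper shown later):

def my_kernel_wrapper(x):
    # Runs on the CPU: allocate the output and launch the kernel on the GPU.
    output = torch.empty_like(x)
    n_elements = x.numel()
    BLOCK_SIZE = 1024
    # One program instance per BLOCK_SIZE chunk of the input.
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    my_kernel[grid](x, output, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return output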
But let’s ditch toy examples and write something more interesting, like a toy LLM purely in Triton.
LLM in Triton #
Omitting embeddings, every Transformer is just a bunch of Transformer blocks. Each block has an attention layer (matmul + softmax), a feed-forward layer (matmul + GELU), and layer normalization.
GELU #
Let’s start with the easiest part - GELU.
GELU (Gaussian Error Linear Unit) is a non-linear activation function used in neural networks. It is commonly computed with the tanh approximation:
GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
@triton.jit
def gelu_kernel(
    x_ptr,                       # Input pointer
    y_ptr,                       # Output pointer
    n_elements,                  # Total number of elements
    BLOCK_SIZE: tl.constexpr     # Elements per thread block
):
    # 1. Figure out which block/thread we are
    pid = tl.program_id(0)       # ID of this parallel instance (0, 1, 2, ...)

    # 2. Calculate which elements THIS instance handles
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # 3. Load data (with bounds checking - block size can be larger than the array)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)

    # 4. Do computation
    y = 0.5 * x * (1.0 + tl.tanh(0.797885 * (x + 0.044715 * x * x * x)))

    # 5. Store results
    tl.store(y_ptr + offsets, y, mask=mask)


def triton_gelu(x):
    y = torch.empty_like(x)
    n_elements = x.numel()

    # Choose block size (power of 2, typically 256-1024)
    BLOCK_SIZE = 1024

    # Define grid: how many blocks do we need?
    # If n_elements=5000 and BLOCK_SIZE=1024, we need 5 blocks (5000/1024 = 4.88)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    # Launch kernel (BLOCK_SIZE goes in as a keyword so the grid lambda can read it from meta)
    gelu_kernel[grid](x, y, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return y
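A quick sanity check (a sketch; the shape and the comparison against PyTorch’s tanh-approximation GELU are my own choices) looks like this:

x = torch.randn(5000, device='cuda')
y_triton = triton_gelu(x)
y_torch = torch.nn.functional.gelu(x, approximate='tanh')
# Expect only a tiny numerical difference, since the kernel truncates sqrt(2/pi) to 0.797885.
print(torch.max(torch.abs(y_triton - y_torch)))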
Softmax #
@triton.jit
def softmax_kernel(
    output_ptr,
    input_ptr,
    n_rows,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
    # Each program processes ONE row
    row_idx = tl.program_id(0)

    # Calculate starting position for this row
    row_start_ptr = input_ptr + row_idx * n_cols

    # Generate offsets for columns in this row
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets

    # Load the row (with masking for columns beyond n_cols)
    mask = col_offsets < n_cols
    row = tl.load(input_ptrs, mask=mask, other=-float('inf'))

    # Step 1: Find max (for numerical stability)
    # Without this, exp(large_number) can overflow
    row_max = tl.max(row, axis=0)

    # Step 2: Subtract max and exponentiate
    row_shifted = row - row_max
    row_exp = tl.exp(row_shifted)

    # Step 3: Sum across the row
    row_sum = tl.sum(row_exp, axis=0)

    # Step 4: Normalize (divide by sum)
    row_softmax = row_exp / row_sum

    # Step 5: Store results
    output_row_start_ptr = output_ptr + row_idx * n_cols
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, row_softmax, mask=mask)
Softmax wrapper
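The wrapper itself is collapsed above; a minimal sketch (triton_softmax is my own name, assuming a contiguous 2D input with softmax along the last dimension) could look like this:

def triton_softmax(x):
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    # BLOCK_SIZE must cover an entire row, so round n_cols up to a power of two.
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    # One program instance per row.
    grid = (n_rows,)
    softmax_kernel[grid](y, x, n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE)
    return y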
Matmul #
This is the hardest part so far because:
- we need to process 2D blocks
- we need to accumulate partial results
- we need tiling for efficiency
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,                     # Matrix dimensions: A is MxK, B is KxN, C is MxN
    stride_am, stride_ak,        # Strides for A
    stride_bk, stride_bn,        # Strides for B
    stride_cm, stride_cn,        # Strides for C
    BLOCK_M: tl.constexpr,       # Block size for M dimension
    BLOCK_N: tl.constexpr,       # Block size for N dimension
    BLOCK_K: tl.constexpr        # Block size for K dimension (reduction)
):
    # 2D grid: each program computes a BLOCK_M x BLOCK_N tile of C
    pid_m = tl.program_id(0)     # Which row block
    pid_n = tl.program_id(1)     # Which column block

    # Generate offsets for the output block we're computing
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # Initialize pointers for A and B tiles
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    # Accumulator for the result
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Loop over K dimension in tiles
    for k in range(0, K, BLOCK_K):
        # Load tiles from A and B
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K - k), other=0.0)
        b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k) & (offs_n[None, :] < N), other=0.0)

        # Accumulate: C += A @ B for this tile
        acc += tl.dot(a, b)

        # Move to next tile in K dimension
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Convert accumulator to output type and store
    c = acc.to(tl.float16)
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, c, mask=mask)
Matmul wrapper
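As with softmax, the wrapper is collapsed above; a minimal sketch (triton_matmul is my own name, the block sizes are arbitrary, and the output is allocated as float16 because the kernel stores float16) could look like this:

def triton_matmul(a, b):
    M, K = a.shape
    K2, N = b.shape
    assert K == K2
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
    # 2D grid: one program instance per BLOCK_M x BLOCK_N tile of C.
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )
    return c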
Layernorm #
Hidden because it's very similar to softmax by nature: one program instance per row, a row-wise reduction (mean and variance instead of max and sum), then normalize, scale, and store.
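For completeness, here is my own sketch of what such a kernel might look like (it is not the hidden original; layernorm_kernel, the eps argument, and the weight/bias pointers are my assumptions):

@triton.jit
def layernorm_kernel(
    x_ptr, y_ptr,
    weight_ptr, bias_ptr,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr
):
    # One program instance per row, exactly like softmax.
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Load one row (masked-out columns read as 0 and are excluded from the stats).
    x = tl.load(x_ptr + row_idx * n_cols + col_offsets, mask=mask, other=0.0)

    # Row-wise mean and variance instead of max and sum.
    mean = tl.sum(x, axis=0) / n_cols
    diff = tl.where(mask, x - mean, 0.0)
    var = tl.sum(diff * diff, axis=0) / n_cols

    # Normalize, then scale and shift.
    x_hat = (x - mean) / tl.sqrt(var + eps)
    w = tl.load(weight_ptr + col_offsets, mask=mask, other=1.0)
    b = tl.load(bias_ptr + col_offsets, mask=mask, other=0.0)
    tl.store(y_ptr + row_idx * n_cols + col_offsets, x_hat * w + b, mask=mask)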
Finishing touches #
Now we can build a whole LLM out of those pieces.
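As a rough illustration of how the pieces compose (a sketch only: single head, no batching, no causal mask, float16 inputs assumed; triton_matmul, triton_softmax, and triton_layernorm are the hypothetical wrappers sketched above):

def transformer_block(x, wq, wk, wv, wo, w1, w2):
    # Attention: matmul + softmax + matmul, all through the Triton kernels above.
    q = triton_matmul(x, wq)
    k = triton_matmul(x, wk)
    v = triton_matmul(x, wv)
    scores = triton_matmul(q, k.t()) / (q.shape[-1] ** 0.5)  # strides handle the transpose
    attn = triton_softmax(scores)
    x = triton_layernorm(x + triton_matmul(triton_matmul(attn, v), wo))
    # Feed-forward: matmul + GELU + matmul, plus another residual and norm.
    return triton_layernorm(x + triton_matmul(triton_gelu(triton_matmul(x, w1)), w2))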
Hard part #
I hope you’ve enjoyed this and are curious to dive deeper into Triton - not just to write kernels, but to understand how it really works under the hood (especially since you could always just use pre-written high-performance kernels).
Compilation pipeline (what happens when you @triton.jit a function) #
Triton parses the decorated Python function into a frontend AST and builds Triton-IR (a machine-independent IR tailored to Triton’s block-level model).
Triton lowers Triton-IR into GPU-specific IR (Triton-GPU / MLIR dialects). Optimizations and hardware-aware transforms happen here (tiling, shared-memory buffering, software pipelining, etc.).
Lowering continues to LLVM IR, which is then used to generate PTX (NVIDIA) or other device code for the target backend (HIP/ROCm, MPS/Metal…).
PTX (or the device binary) is JIT/linked into a device binary (CUBIN) that the runtime can launch. Triton caches compiled variants (specialization).
Because of constexprs and meta parameters, one Python kernel can produce many specialized binaries.
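For example (a sketch reusing gelu_kernel from earlier; the point is only that each BLOCK_SIZE value gets its own compiled variant):

x = torch.randn(4096, device='cuda')
y = torch.empty_like(x)
n = x.numel()
# Same Python function, two specializations: one compiled binary per BLOCK_SIZE value.
gelu_kernel[(triton.cdiv(n, 256),)](x, y, n, BLOCK_SIZE=256)
gelu_kernel[(triton.cdiv(n, 1024),)](x, y, n, BLOCK_SIZE=1024)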
Execution model - kernels, grids, and how they map to GPU hardware #
Triton exposes the kernel at the block/program level: every program instance (the work unit you index with tl.program_id) is analogous to a CUDA thread block / CTA but with Triton’s own abstraction. Internally a program instance is executed cooperatively by a configurable number of threads (warps).
- num_warps controls how many warps (32-thread groups) are assigned per program instance, e.g. num_warps=8 means each program runs on 8 * 32 = 256 threads.
- num_stages tells the compiler how many pipeline stages to insert for software pipelining (helpful for tiled GEMMs); see the launch sketch right after this list.
- The grid you pass to kernel[grid] determines how many program instances to launch along each program axis; tl.program_id / tl.num_programs give the per-instance index/count.
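Both knobs are passed as extra keyword arguments at launch time. A sketch reusing the matmul kernel (the values are arbitrary, and a, b, c, M, N, K are assumed to be set up as in the matmul wrapper):

grid = (triton.cdiv(M, 64), triton.cdiv(N, 64))
matmul_kernel[grid](
    a, b, c,
    M, N, K,
    a.stride(0), a.stride(1),
    b.stride(0), b.stride(1),
    c.stride(0), c.stride(1),
    BLOCK_M=64, BLOCK_N=64, BLOCK_K=32,
    num_warps=8,    # 8 * 32 = 256 threads cooperate on each C tile
    num_stages=3,   # software-pipeline the K-loop loads with compute
)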
This is the magic I talked about before: Triton’s block-level model frees you from thinking about individual threads - you reason in blocks/tiles and let the compiler parallelize across threads/warps inside each block. This is what makes Triton kernels compact to write while still being efficient.
Memory model & synchronization #
Triton kernels use the usual GPU memory hierarchy: global memory (HBM), shared memory (SMEM, software-managed), and registers. tl.load/tl.store generate the memory accesses; the compiler decides placement/tiling. Use masks to make loads/stores predicated (bounds-safe) - these become predicate instructions or masked loads/stores in the generated PTX.
- Shared memory / barriers: Triton exposes barrier primitives (e.g. triton.language.debug_barrier) and the compiler lowers higher-level sync to device barrier ops; some backends also provide lower-level mbarrier ops for fine control. Full global / block-to-block synchronization is limited - use atomics or communication primitives designed for inter-CTA communication if you need cross-block coordination (and beware of portability/ordering caveats).
- Predicated/masked IO: masks avoid out-of-bounds accesses and are emitted as predicated accesses - cheaper than explicit branches in many cases.
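For cross-block coordination, the simplest tool is an atomic. A sketch (sum_kernel and its arguments are my own; each program instance reduces its tile, then atomically adds the partial result into a single global accumulator):

@triton.jit
def sum_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # Reduce within this program instance...
    partial = tl.sum(x, axis=0)
    # ...then combine across all instances through global memory.
    tl.atomic_add(out_ptr, partial)

# Host side: out = torch.zeros(1, device='cuda'); sum_kernel[(triton.cdiv(n, 1024),)](x, out, n, BLOCK_SIZE=1024)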
Key compiler optimization knobs you should know #
- tl.constexpr (compile-time constants): every constexpr in the signature produces a specialization. Use it for block/tile sizes and loop bounds to let the compiler generate tighter code.
- num_warps (threads per program) and num_stages (software pipelining): important for occupancy/latency hiding on newer SM architectures; increasing num_warps increases intra-program parallelism, while num_stages pipelines loads with compute for tiled reductions/GEMMs.
- Autotuning: @triton.autotune + triton.Config lets you supply multiple tile/warp/stage configurations and benchmarks them to pick the best for the current shape/device. This is usually easier than hand-tuning every kernel (see the sketch after this list).
- Tensor ops & tl.dot/tl.matmul: these intrinsics allow the compiler to target Tensor Cores / MMA instructions when available - use them for inner products and small GEMM tiles to get huge speedups.
- Software pipelining / multi-buffering: the compiler can overlap loads/stores with compute across stages; num_stages selects the pipeline depth. Useful in matmul/attention kernels.
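For example, the GELU kernel from earlier could be autotuned over a few block sizes like this (a sketch; gelu_kernel_autotuned and the configs are my own choices, and the key list tells Triton when to re-benchmark):

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 256}, num_warps=2),
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 4096}, num_warps=8),
    ],
    key=['n_elements'],   # re-run the benchmark when the input size changes
)
@triton.jit
def gelu_kernel_autotuned(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = 0.5 * x * (1.0 + tl.tanh(0.797885 * (x + 0.044715 * x * x * x)))
    tl.store(y_ptr + offsets, y, mask=mask)

# At launch, BLOCK_SIZE is NOT passed - the autotuner fills it in from the winning config.
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
gelu_kernel_autotuned[grid](x, y, n_elements)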
tl.debug_barrier #
Let’s stop here and synchronize. There’s plenty more to explore in Triton’s optimization toolbox (just tell me if you want Part 2). I’m also cooking up a matching breakdown for Cute/TileLang, so keep an eye out!