WGMMA Programming
Hopper (H100) introduced a new tensor core instruction called Warp Group MMA (WGMMA). The regular Matrix Multiply-Accumulate (MMA) tensor core instructions are warp-scoped and map onto a single subcore of a Streaming Multiprocessor (SM); an SM is divided into 4 subcores. Instead of using a single warp, a WGMMA instruction uses 4 warps (i.e. 128 threads, a warpgroup) and maps onto an entire SM. It also has much bigger tile sizes than the regular MMA instructions.
- WGMMA executes the operation `D = A*B + C`, where `C` can be disabled to only do the matrix multiplication of `A` & `B`.
- For WGMMA instructions the matrix tile `A` can be in registers or shared memory, but `B` has to be in shared memory. It's important to know how values are arranged in shared memory for WGMMA to correctly compute the operation.
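For reference, the PTX form of the instruction used in this post (the bf16 case) is sketched below; operand names follow the PTX ISA, and the descriptors and flags are explained in the rest of the post.

```
wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16
    {d0, d1, d2, d3},  // per-thread accumulator registers holding D
    desc-a,            // 64-bit shared memory matrix descriptor for A
    desc-b,            // 64-bit shared memory matrix descriptor for B
    scale-d,           // predicate: 0 disables C, giving D = A*B
    imm-scale-a,       // +1/-1 scaling of A
    imm-scale-b,       // +1/-1 scaling of B
    imm-trans-a,       // transpose flag for A
    imm-trans-b;       // transpose flag for B
```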
Let's look at the example `m64n8k16` (bf16) with both `A` & `B` in shared memory.
Shared memory has the advantage that the Tensor Memory Accelerator (TMA) can be used to move data from global to shared memory.
So the `A` matrix is of the shape `64x16` while `B` is of the shape `16x8`:
Here `8x8` is referred to as the core matrix. Each core matrix has 8 rows (or columns), with each row being 16 bytes, i.e. 8 bf16 elements. So our matrix `A` consists of `8x2 ⇒ 16` core matrices.
Let's consider the example where matrix `A` is row major and its elements run from 0 to 1023 (increasing by 1).
Logically this would be represented like this:

But shared memory for `A` is a 1-D contiguous space in memory.
Elements 0 to 511 would look like this:

Elements 512 to 1023 would look like this:
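To make this ordering concrete, here is a small host-side sketch that produces the same 1-D order from a row-major `64x16` tile. The exact loop nesting is an assumption read off the figures: core matrices are stored back to back, with the two core matrices along K of each 8-row band adjacent.

```cuda
#include <cuda_bf16.h>

// Rearrange a row-major 64x16 bf16 tile into the contiguous core-matrix
// order assumed above (nesting: M-band, then K, then row, then column).
void rearrange_A(const __nv_bfloat16* A, __nv_bfloat16* A_core_order) {
    int idx = 0;
    for (int mb = 0; mb < 8; ++mb)            // 8 core-matrix bands along M
        for (int kb = 0; kb < 2; ++kb)        // 2 core matrices along K
            for (int r = 0; r < 8; ++r)       // 8 rows per core matrix
                for (int c = 0; c < 8; ++c)   // 8 bf16 = 16 bytes per row
                    A_core_order[idx++] = A[(mb * 8 + r) * 16 + (kb * 8 + c)];
}
```

With `A[i] = i`, the first 512 outputs cover rows 0 to 31 (values 0 to 511), matching the split in the figures above.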
Now let's look at how the `B` matrix is arranged in shared memory. Let's assume the default case where the `B` matrix is in column-major order and it also contains values from 0 to 127 (NT in BLAS notation). Here are the various representations:

So the `B` matrix is of the shape `16x8`, which gives 2 core matrices:

In shared memory it is laid out like this:
Here 0 is the first element in shared memory, 8 is the next, then 16, then 24; after 56 it goes to 1. After 63 it moves on to the next core matrix, which starts at 64 and repeats the same pattern.
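The same ordering can be sketched in code. Here the logical value at row `r`, column `c` is `8r + c` (as in the figures), and `B` is stored column major; the loop nesting below is an assumption that reproduces the 0, 8, 16, … sequence described above.

```cuda
#include <cuda_bf16.h>

// Rearrange a column-major 16x8 bf16 tile into contiguous core-matrix order.
// B_colmajor[c * 16 + r] holds the element at row r, column c.
void rearrange_B(const __nv_bfloat16* B_colmajor, __nv_bfloat16* B_core_order) {
    int idx = 0;
    for (int kb = 0; kb < 2; ++kb)          // 2 core matrices along K (rows)
        for (int c = 0; c < 8; ++c)         // 8 columns
            for (int r = 0; r < 8; ++r)     // 8 rows per core matrix
                B_core_order[idx++] = B_colmajor[c * 16 + (kb * 8 + r)];
}
```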
The result will be stored in the registers. It will be of the form:

Here `T0` is thread 0, storing the first 2 final values in its first 2 registers.
Every thread will have 4 registers.
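In code, that mapping can be sketched as follows, assuming the standard wgmma accumulator fragment layout: warp `w` of the warpgroup owns rows `16w` to `16w+15`, and within a warp the layout matches the familiar `mma.sync` m16n8 accumulator fragment. `store_d` is a hypothetical helper name.

```cuda
// Hypothetical helper: scatter one thread's 4 accumulator registers into the
// 64x8 output tile D (row major). tid is the thread index in the warpgroup.
__device__ void store_d(float* D, const float d[4], int tid) {
    int warp = tid / 32;
    int lane = tid % 32;
    int row  = warp * 16 + lane / 4;   // row holding d[0] and d[1]
    int col  = 2 * (lane % 4);         // column of d[0] (and d[2])
    D[row * 8 + col]           = d[0];
    D[row * 8 + col + 1]       = d[1];
    D[(row + 8) * 8 + col]     = d[2]; // second pair sits 8 rows below
    D[(row + 8) * 8 + col + 1] = d[3];
}
```

Thread 0 (warp 0, lane 0) writes `d[0]` and `d[1]` to `D[0]` and `D[1]`, i.e. the first 2 final values.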
Here is a CUDA implementation of this example, as a minimal self-contained sketch.
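The 64-bit matrix descriptor encoding below follows the PTX ISA; the leading/stride byte offsets (128 and 256 bytes) and the no-swizzle mode are assumptions that must match how `A` and `B` were staged into shared memory (e.g. the core-matrix rearrangements sketched above).

```cuda
#include <cuda_bf16.h>
#include <cstdint>

using bf16 = __nv_bfloat16;

// Build a shared memory matrix descriptor (PTX ISA format, no swizzle).
__device__ uint64_t make_desc(const bf16* smem_ptr, uint32_t lbo, uint32_t sbo) {
    uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    uint64_t desc = 0;
    desc |= (uint64_t)((addr >> 4) & 0x3FFF);        // bits  0-13: start address
    desc |= (uint64_t)((lbo  >> 4) & 0x3FFF) << 16;  // bits 16-29: leading byte offset
    desc |= (uint64_t)((sbo  >> 4) & 0x3FFF) << 32;  // bits 32-45: stride byte offset
    return desc;                                     // bits 62-63 (swizzle) = 0: none
}

// One warpgroup (128 threads) computes D(64x8) = A(64x16) * B(16x8).
// A and B must already be in the core-matrix order described above.
__global__ void wgmma_m64n8k16(const bf16* A, const bf16* B, float* D) {
    __shared__ alignas(128) bf16 a_smem[64 * 16];
    __shared__ alignas(128) bf16 b_smem[16 * 8];

    int tid = threadIdx.x;
    for (int i = tid; i < 64 * 16; i += 128) a_smem[i] = A[i];
    for (int i = tid; i < 16 * 8;  i += 128) b_smem[i] = B[i];
    __syncthreads();
    // Make the plain smem writes visible to the async (wgmma) proxy.
    asm volatile("fence.proxy.async.shared::cta;\n" ::: "memory");

    // LBO/SBO: 128 bytes to the next core matrix along K, 256 along M
    // (assumed values for the hand-rolled layout above).
    uint64_t desc_a = make_desc(a_smem, 128, 256);
    uint64_t desc_b = make_desc(b_smem, 128, 256);

    float d0 = 0.f, d1 = 0.f, d2 = 0.f, d3 = 0.f;

    asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
    asm volatile(
        "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, %6, 0;\n"
        // f32 += bf16 * bf16; the last two immediates are trans-a, trans-b
        "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 "
        "{%0, %1, %2, %3}, %4, %5, p, 1, 1, 0, 0;\n"
        "}\n"
        : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
        : "l"(desc_a), "l"(desc_b), "r"(0) /* scale-d = 0: C disabled */
        : "memory");
    asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
    asm volatile("wgmma.wait_group.sync.aligned 0;\n" ::: "memory");

    // Write D using the fragment layout from the previous section.
    int warp = tid / 32, lane = tid % 32;
    int row = warp * 16 + lane / 4, col = 2 * (lane % 4);
    D[row * 8 + col]           = d0;
    D[row * 8 + col + 1]       = d1;
    D[(row + 8) * 8 + col]     = d2;
    D[(row + 8) * 8 + col + 1] = d3;
}
```

A single-tile launch would be `wgmma_m64n8k16<<<1, 128>>>(dA, dB, dD);`, compiled with `-arch=sm_90a`.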
Now let's consider the case when `B` is also row major, i.e. transposed (denoted by `B'`). Two things need to change for this to work correctly. First, the WGMMA instruction takes 2 flags (`imm-trans-a` and `imm-trans-b` in PTX); by default they are set to 0, which indicates `A` is row major and `B` is column major. If we want to make `B` row major, its transpose flag has to be set to 1. The next thing to change is how the `B` matrix values are mapped onto shared memory. For the example discussed above it becomes `B = [0 1 2 3 4 ….]`: the values simply map in order onto the thread ids.
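In the inline asm sketch above, that flag is the last immediate of the instruction; a minimal change would be:

```cuda
// Same wgmma call as in the kernel above, with imm-trans-b = 1 so that the
// tensor core treats B as row major (B'); the smem layout must match.
asm volatile(
    "{\n"
    ".reg .pred p;\n"
    "setp.ne.b32 p, %6, 0;\n"
    "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 "
    "{%0, %1, %2, %3}, %4, %5, p, 1, 1, 0, 1;\n"  // trans-b flipped to 1
    "}\n"
    : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
    : "l"(desc_a), "l"(desc_b), "r"(0)
    : "memory");
```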