Updates for M3 (#10)
corsix committed Jan 15, 2024
Parent 4acaa99, commit e159758
Showing 7 changed files with 95 additions and 70 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -1,12 +1,12 @@
-Contemporary M1 / M2 machines from Apple have (at least) four different ways for low-level programmers to perform heavy computations:
+Contemporary M1 / M2 / M3 machines from Apple have (at least) four different ways for low-level programmers to perform heavy computations:
1. Standard ARMv8 SIMD/NEON vector instructions on CPU cores (128 bits wide, issue [up to four per cycle on Firestorm](https://dougallj.github.io/applecpu/firestorm-simd.html))
2. Apple's undocumented AMX instructions, issued from CPU, executed on a special accelerator execution unit
3. The Neural Engine (called ANE or NPU)
4. The GPU (e.g. [Metal Compute Shaders](https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu))

This repository is all about the 2<sup>nd</sup> of those: Apple's AMX instructions. Note that these instructions are neither documented nor supported by Apple. As a source of potential great confusion, Apple's AMX instructions are completely distinct from [Intel's AMX instructions](https://en.wikipedia.org/wiki/Advanced_Matrix_Extensions), though both are intended for issuing matrix multiply operations from a CPU.

-The research was done on an Apple M1 Max (2021), with follow-up work on an M2 (2023). Older or newer chips might have different AMX instructions. [Some sources](https://nod.ai/comparing-apple-m1-with-amx2-m1-with-neon/) report that the M1 contains version 2 of the AMX instructions, which seems plausible (possibly everything using 7-bit writemasks comes from version 1, and everything using 9-bit writemasks is new in version 2). The M1 to M2 transition [adds bf16 support, along with a few other tweaks](https://github.com/corsix/amx/issues/5#issuecomment-1464639729).
+The research was done on an Apple M1 Max (2021), with follow-up work on an M2 (2023), and additional follow-up work on an M3 (2023). Older or newer chips might have different AMX instructions. [Some sources](https://nod.ai/comparing-apple-m1-with-amx2-m1-with-neon/) report that the M1 contains version 2 of the AMX instructions, which seems plausible (possibly everything using 7-bit writemasks comes from version 1, and everything using 9-bit writemasks is new in version 2). The M1 to M2 transition [adds bf16 support, along with a few other tweaks](https://github.com/corsix/amx/issues/5#issuecomment-1464639729). The M2 to M3 transition [adds one extra mode to each of `ldx`, `ldy`, and `matint`](https://github.com/corsix/amx/issues/10).

A good one-image summary of AMX is the following figure from [abandoned patent US20180074824A1](https://patents.google.com/patent/US20180074824A1/en). Consider a 32x32 grid of compute units, where each unit can perform 16-bit multiply-accumulate, or a 2x2 subgrid of units can perform 32-bit multiply-accumulate, or a 4x4 subgrid can perform 64-bit multiply-accumulate. To feed this grid, there is a pool of X registers each containing 32 16-bit elements (or 16 32-bit elements, or 8 64-bit elements) and a pool of Y registers similarly containing 32 16-bit elements (or 16 32-bit elements, or 8 64-bit elements). A single instruction can perform a full outer product: multiply every element of an X register with every element of a Y register, and accumulate with the Z element in the corresponding position.
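In plain C terms, one such 16-bit outer-product-accumulate behaves roughly like the sketch below; the ordinary arrays and the function name are illustrative only, not AMX syntax:

```c
#include <stdint.h>

// Illustrative sketch only: x and y stand in for one X and one Y register
// (32 lanes of 16 bits each); z stands in for the 32x32 grid of 16-bit
// accumulators fed by a single outer-product instruction.
static void outer_product_accumulate_u16(const uint16_t x[32],
                                         const uint16_t y[32],
                                         uint16_t z[32][32]) {
  for (int j = 0; j < 32; j++) {
    for (int i = 0; i < 32; i++) {
      // multiply every X element by every Y element, accumulate in place
      z[j][i] += (uint16_t)((uint32_t)x[i] * y[j]);
    }
  }
}
```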

1 change: 1 addition & 0 deletions emulate.h
@@ -24,6 +24,7 @@ typedef __attribute__((aligned(128))) struct amx_state {
extern uint32_t AMX_VER;
#define AMX_VER_M1 1
#define AMX_VER_M2 2
+#define AMX_VER_M3 3

// Common helpers:

21 changes: 13 additions & 8 deletions ldst.c
@@ -1,18 +1,23 @@
#include "emulate.h"
#include <stdio.h>

-#define LDST_PAIR (1ull << 62)
-#define LDST_PAIR_MEANS_FOUR (1ull << 60)
+#define LDST_MULTIPLE (1ull << 62)
+#define LDST_NON_CONSECUTIVE (1ull << 61)
+#define LDST_MULTIPLE_MEANS_FOUR (1ull << 60)

static void ld_common(amx_reg* regs, uint64_t operand, uint32_t regmask) {
uint32_t rn = (operand >> 56) & regmask;
const uint8_t* src = (uint8_t*)((operand << 8) >> 8);
memcpy(regs + rn, src, 64);
-if (operand & LDST_PAIR) {
-memcpy(regs + ((rn + 1) & regmask), src + 64, 64);
-if ((AMX_VER >= AMX_VER_M2) && (operand & LDST_PAIR_MEANS_FOUR) && (regmask <= 15)) {
-memcpy(regs + ((rn + 2) & regmask), src + 128, 64);
-memcpy(regs + ((rn + 3) & regmask), src + 192, 64);
+if (operand & LDST_MULTIPLE) {
+uint32_t rs = 1;
+if ((AMX_VER >= AMX_VER_M3) && (operand & LDST_NON_CONSECUTIVE) && (regmask <= 15)) {
+rs = (operand & LDST_MULTIPLE_MEANS_FOUR) ? 2 : 4;
+}
+memcpy(regs + ((rn + rs) & regmask), src + 64, 64);
+if ((AMX_VER >= AMX_VER_M2) && (operand & LDST_MULTIPLE_MEANS_FOUR) && (regmask <= 15)) {
+memcpy(regs + ((rn + rs*2) & regmask), src + 128, 64);
+memcpy(regs + ((rn + rs*3) & regmask), src + 192, 64);
}
}
}
@@ -21,7 +26,7 @@ static void st_common(const amx_reg* regs, uint64_t operand, uint32_t regmask) {
uint32_t rn = (operand >> 56) & regmask;
uint8_t* dst = (uint8_t*)((operand << 8) >> 8);
memcpy(dst, regs + rn, 64);
-if (operand & LDST_PAIR) {
+if (operand & LDST_MULTIPLE) {
memcpy(dst + 64, regs + ((rn + 1) & regmask), 64);
}
}
20 changes: 12 additions & 8 deletions ldst.md
@@ -25,8 +25,8 @@ For `ldx` / `ldy`:
|---:|---:|---|
|63|1|Ignored|
|62|1|Load multiple registers (`1`) or single register (`0`)|
-|61|1|Ignored|
-|60|1|On M1: Ignored ("multiple" always means two registers)<br/>On M2: "Multiple" means four registers (`1`) or two registers (`0`)|
+|61|1|On M1/M2: Ignored (loads are always to consecutive registers)<br/>On M3: Load to non-consecutive registers (`1`) or to consecutive registers (`0`)|
+|60|1|On M1: Ignored ("multiple" always means two registers)<br/>On M2/M3: "Multiple" means four registers (`1`) or two registers (`0`)|
|59|1|Ignored|
|56|3|X / Y register index|
|0|56|Pointer|
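As a rough sketch of how the fields above combine into a 64-bit operand (the helper name and signature below are illustrative, not part of the repository):

```c
#include <stdint.h>

// Illustrative only: pack the ldx / ldy fields from the table above.
static uint64_t make_ldxy_operand(const void* ptr, unsigned reg,
                                  int multiple, int four, int non_consecutive) {
  return ((uint64_t)(uintptr_t)ptr & ((1ull << 56) - 1)) // bits 0-55: pointer
       | ((uint64_t)(reg & 7) << 56)                     // bits 56-58: X / Y register index
       | ((uint64_t)(four ? 1 : 0) << 60)                // bit 60: "multiple" means four (M2/M3)
       | ((uint64_t)(non_consecutive ? 1 : 0) << 61)     // bit 61: non-consecutive registers (M3)
       | ((uint64_t)(multiple ? 1 : 0) << 62);           // bit 62: load multiple registers
}
```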
@@ -61,7 +61,7 @@ For `ldzi` / `stzi`:

## Description

-Move 64 bytes of data between memory (does not have to be aligned) and an AMX register, or move 128 bytes of data between memory (must be aligned to 128 bytes) and an adjacent pair of AMX registers. On M2, can also move 256 bytes of data from memory to four consecutive X or Y registers.
+Move 64 bytes of data between memory (does not have to be aligned) and an AMX register, or move 128 bytes of data between memory (must be aligned to 128 bytes) and an adjacent pair of AMX registers. On M2/M3, can also move 256 bytes of data from memory to four consecutive X or Y registers. On M3, can move 128 or 256 bytes of data from memory to non-consecutive X or Y registers: if bit 61 is set, 128 bytes are moved to registers `n` and `(n+4)%8`, or 256 bytes are moved to registers `n`, `(n+2)%8`, `(n+4)%8`, `(n+6)%8`.

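For example, with bit 61 set and a starting register index of 1, a 128-byte load writes registers 1 and 5, while a 256-byte load writes registers 1, 3, 5, and 7.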
The `ldzi` and `stzi` instructions manipulate _half_ of a _pair_ of Z registers. Viewing the 64 bytes of memory and the 64 bytes of every Z register as vectors of i32 / u32 / f32, the mapping between memory and Z is:

@@ -87,11 +87,15 @@ void ld_common(amx_reg* regs, uint64_t operand, uint32_t regmask) {
uint32_t rn = (operand >> 56) & regmask;
const uint8_t* src = (uint8_t*)((operand << 8) >> 8);
memcpy(regs + rn, src, 64);
-if (operand & LDST_PAIR) {
-memcpy(regs + ((rn + 1) & regmask), src + 64, 64);
-if ((AMX_VER >= AMX_VER_M2) && (operand & LDST_PAIR_MEANS_FOUR) && (regmask <= 15)) {
-memcpy(regs + ((rn + 2) & regmask), src + 128, 64);
-memcpy(regs + ((rn + 3) & regmask), src + 192, 64);
+if (operand & LDST_MULTIPLE) {
+uint32_t rs = 1;
+if ((AMX_VER >= AMX_VER_M3) && (operand & LDST_NON_CONSECUTIVE) && (regmask <= 15)) {
+rs = (operand & LDST_MULTIPLE_MEANS_FOUR) ? 2 : 4;
+}
+memcpy(regs + ((rn + rs) & regmask), src + 64, 64);
+if ((AMX_VER >= AMX_VER_M2) && (operand & LDST_MULTIPLE_MEANS_FOUR) && (regmask <= 15)) {
+memcpy(regs + ((rn + rs*2) & regmask), src + 128, 64);
+memcpy(regs + ((rn + rs*3) & regmask), src + 192, 64);
}
}
}
52 changes: 27 additions & 25 deletions matint.c
@@ -27,7 +27,7 @@ void emulate_AMX_MATINT(amx_state* state, uint64_t operand) {
int alumode = (operand & MATINT_INDEXED_LOAD) ? (operand & (1ull << 54)) >> 51 : (operand >> 47) & 0x3f;
uint32_t shift = (operand >> 58) & 0x1f;

-uint32_t xybits = 0, zbits, satbits;
+uint32_t xbits = 0, ybits = 0, zbits, satbits;
if (alumode == 4) {
switch ((operand >> 42) & 0xf) {
case 3: zbits = 32; satbits = 16; break;
@@ -37,27 +37,29 @@
default: zbits = 16; satbits = 16; break;
}
} else if (alumode == 5 || alumode == 6) {
-xybits = 16; zbits = 16;
+xbits = ybits = 16; zbits = 16;
shift = 15;
} else if (alumode == 8) {
switch ((operand >> 42) & 0xf) {
-case 10: xybits = 8; zbits = 32; break;
-default: xybits = 8; zbits = 16; break;
+case 10: xbits = ybits = 8; zbits = 32; break;
+case 12: xbits = 8; if (AMX_VER >= AMX_VER_M3) { ybits = 16; zbits = 32; } else { ybits = 8; zbits = 16; } break;
+default: xbits = ybits = 8; zbits = 16; break;
}
} else if (alumode == 9) {
switch ((operand >> 42) & 0xf) {
-case 3: xybits = 16; zbits = 32; break;
-case 4: xybits = 32; zbits = 32; break;
-default: xybits = 16; zbits = 16; break;
+case 3: xbits = ybits = 16; zbits = 32; break;
+case 4: xbits = ybits = 32; zbits = 32; break;
+default: xbits = ybits = 16; zbits = 16; break;
}
-shift = 64 - xybits; // Not actually used as a shift
+shift = 64 - xbits; // Not actually used as a shift
} else {
switch ((operand >> 42) & 0xf) {
-case 3: xybits = 16; zbits = 32; break;
-default: xybits = 16; zbits = 16; break;
+case 3: xbits = ybits = 16; zbits = 32; break;
+default: xbits = ybits = 16; zbits = 16; break;
}
}
-uint32_t xybytes = xybits / 8;
+uint32_t xbytes = xbits / 8;
+uint32_t ybytes = ybits / 8;
uint32_t zbytes = zbits / 8;

if (alumode == 4) {
@@ -87,13 +89,13 @@ void emulate_AMX_MATINT(amx_state* state, uint64_t operand) {
uint32_t src_reg = (operand >> 49) & 7;
uint32_t ibits = (operand & MATINT_INDEXED_LOAD_4BIT) ? 4 : 2;
if (operand & MATINT_INDEXED_LOAD_Y) {
-load_xy_reg_indexed(y, state->y[src_reg].u8, ibits, xybits);
+load_xy_reg_indexed(y, state->y[src_reg].u8, ibits, ybits);
} else {
-load_xy_reg_indexed(x, state->x[src_reg].u8, ibits, xybits);
+load_xy_reg_indexed(x, state->x[src_reg].u8, ibits, xbits);
}
}
-xy_shuffle(x, (operand >> 29) & 3, xybytes);
-xy_shuffle(y, (operand >> 27) & 3, xybytes);
+xy_shuffle(x, (operand >> 29) & 3, xbytes);
+xy_shuffle(y, (operand >> 27) & 3, ybytes);

// z = z +/- (f(x, y) >> s) for f being * or + or weird xor/popcnt thing
// z = sat_i16(z +/- (f(x, y) >> 16)) for f being SQRDMLAH / SQRDMLSH
@@ -104,9 +106,9 @@
uint64_t x_enable, y_enable;
if (operand & MATINT_ENABLE_MASK_IS_Y) {
x_enable = ~(uint64_t)0;
-y_enable = parse_writemask(operand >> 32, xybytes, 9);
+y_enable = parse_writemask(operand >> 32, ybytes, 9);
} else {
-x_enable = parse_writemask(operand >> 32, xybytes, 9);
+x_enable = parse_writemask(operand >> 32, xbytes, 9);
y_enable = ~(uint64_t)0;
}
if (((operand >> (32+6)) & 7) == 0) {
@@ -116,18 +118,18 @@
}
}

-uint32_t xsignext = (operand & MATINT_SIGNED_X) ? (64 - xybits) : 0;
-uint32_t ysignext = (operand & MATINT_SIGNED_Y) ? (64 - xybits) : 0;
+uint32_t xsignext = (operand & MATINT_SIGNED_X) ? (64 - xbits) : 0;
+uint32_t ysignext = (operand & MATINT_SIGNED_Y) ? (64 - ybits) : 0;
uint32_t zsignext = 64 - zbits;
-uint32_t zmask = (zbytes / xybytes) - 1;
-uint32_t step = xybytes == 1 ? zbytes : xybytes;
+uint32_t zmask = (zbytes / xbytes) - 1;
+uint32_t step = xbytes == 1 ? zbytes : xbytes;
for (uint32_t j = 0; j < 64; j += step) {
if (!((y_enable >> j) & 1)) continue;
-for (uint32_t i = 0; i < 64; i += xybytes) {
+for (uint32_t i = 0; i < 64; i += xbytes) {
if (!((x_enable >> i) & 1)) continue;
-int64_t xv = load_int(x + i, xybytes, xsignext);
-int64_t yv = load_int(y + j, xybytes, ysignext);
-void* z = &state->z[bit_select(bit_select(j, z_row, xybytes - 1), i / xybytes, zmask)].u8[i & -zbytes];
+int64_t xv = load_int(x + i, xbytes, xsignext);
+int64_t yv = load_int(y + j, ybytes, ysignext);
+void* z = &state->z[bit_select(bit_select(j, z_row, xbytes - 1), i / xbytes, zmask)].u8[i & -zbytes];
int64_t zv = load_int(z, zbytes, zsignext);
int64_t result = vecint_alu(xv, yv, zv, alumode, shift) & omask;
store_int(z, zbytes, result);
62 changes: 37 additions & 25 deletions matint.md
@@ -90,6 +90,7 @@ When ALU mode = 8, lane width modes:
|X|Y|Z|42|
|---|---|---|---|
|i8 or u8|i8 or u8 (only every fourth lane is used, said lanes are used four times each)|i32 or u32 (all rows)|`10`|
+|i8 or u8|i16 or u16 (only even lanes used, said lanes are used four times each)|i32 or u32 (all rows)|`12` (M3 only)|
|i8 or u8|i8 or u8 (only even lanes used, said lanes are used twice each)|i16 or u16 (all rows)|anything else|

When ALU mode = 9, lane width modes:
@@ -134,6 +135,15 @@ For 16-bit Z, each 2 by 2 block of bytes ends up looking like:
<tr><td>Y<sub>1</sub></td><td colspan="2">Z<sub>1,0:1</sub> += X<sub>1</sub> × Y<sub>0</sub></td>
</table>

+When 47=8, M3 also supports 8-bit X, 16-bit Y, 32-bit Z, with each 4 by 4 block of bytes looking like:
+
+<table><tr><td/><td>X<sub>0</sub></td><td>X<sub>1</sub></td><td>X<sub>2</sub></td><td>X<sub>3</sub></td></tr>
+<tr><td>Y<sub>0</sub></td><td colspan="4">Z<sub>0,0:3</sub> += X<sub>0</sub> × Y<sub>0:1</sub></tr>
+<tr><td>Y<sub>1</sub></td><td colspan="4">Z<sub>1,0:3</sub> += X<sub>1</sub> × Y<sub>0:1</sub></td>
+<tr><td>Y<sub>2</sub></td><td colspan="4">Z<sub>2,0:3</sub> += X<sub>2</sub> × Y<sub>0:1</sub></tr>
+<tr><td>Y<sub>3</sub></td><td colspan="4">Z<sub>3,0:3</sub> += X<sub>3</sub> × Y<sub>0:1</sub></td>
+</table>

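In other words, within each 4 by 4 block the four 8-bit X lanes are all multiplied by the single 16-bit Y lane occupying the block's first two bytes (so only every other Y lane is used, four times each), and each product accumulates into the 32-bit Z lane of the corresponding row.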
## Emulation code

See [matint.c](matint.c), and [vecint.c](vecint.c) for the shared ALU.
@@ -153,7 +163,7 @@ void emulate_AMX_MATINT(amx_state* state, uint64_t operand) {
int alumode = (operand & MATINT_INDEXED_LOAD) ? (operand & (1ull << 54)) >> 51 : (operand >> 47) & 0x3f;
uint32_t shift = (operand >> 58) & 0x1f;

-uint32_t xybits = 0, zbits, satbits;
+uint32_t xbits = 0, ybits = 0, zbits, satbits;
if (alumode == 4) {
switch ((operand >> 42) & 0xf) {
case 3: zbits = 32; satbits = 16; break;
@@ -163,27 +173,29 @@
default: zbits = 16; satbits = 16; break;
}
} else if (alumode == 5 || alumode == 6) {
-xybits = 16; zbits = 16;
+xbits = ybits = 16; zbits = 16;
shift = 15;
} else if (alumode == 8) {
switch ((operand >> 42) & 0xf) {
-case 10: xybits = 8; zbits = 32; break;
-default: xybits = 8; zbits = 16; break;
+case 10: xbits = ybits = 8; zbits = 32; break;
+case 12: xbits = 8; if (AMX_VER >= AMX_VER_M3) { ybits = 16; zbits = 32; } else { ybits = 8; zbits = 16; } break;
+default: xbits = ybits = 8; zbits = 16; break;
}
} else if (alumode == 9) {
switch ((operand >> 42) & 0xf) {
-case 3: xybits = 16; zbits = 32; break;
-case 4: xybits = 32; zbits = 32; break;
-default: xybits = 16; zbits = 16; break;
+case 3: xbits = ybits = 16; zbits = 32; break;
+case 4: xbits = ybits = 32; zbits = 32; break;
+default: xbits = ybits = 16; zbits = 16; break;
}
-shift = 64 - xybits; // Not actually used as a shift
+shift = 64 - xbits; // Not actually used as a shift
} else {
switch ((operand >> 42) & 0xf) {
-case 3: xybits = 16; zbits = 32; break;
-default: xybits = 16; zbits = 16; break;
+case 3: xbits = ybits = 16; zbits = 32; break;
+default: xbits = ybits = 16; zbits = 16; break;
}
}
-uint32_t xybytes = xybits / 8;
+uint32_t xbytes = xbits / 8;
+uint32_t ybytes = ybits / 8;
uint32_t zbytes = zbits / 8;

if (alumode == 4) {
@@ -201,20 +213,20 @@ void emulate_AMX_MATINT(amx_state* state, uint64_t operand) {
uint32_t src_reg = (operand >> 49) & 7;
uint32_t ibits = (operand & MATINT_INDEXED_LOAD_4BIT) ? 4 : 2;
if (operand & MATINT_INDEXED_LOAD_Y) {
-load_xy_reg_indexed(y, state->y[src_reg].u8, ibits, xybits);
+load_xy_reg_indexed(y, state->y[src_reg].u8, ibits, ybits);
} else {
-load_xy_reg_indexed(x, state->x[src_reg].u8, ibits, xybits);
+load_xy_reg_indexed(x, state->x[src_reg].u8, ibits, xbits);
}
}
-xy_shuffle(x, (operand >> 29) & 3, xybytes);
-xy_shuffle(y, (operand >> 27) & 3, xybytes);
+xy_shuffle(x, (operand >> 29) & 3, xbytes);
+xy_shuffle(y, (operand >> 27) & 3, ybytes);

uint64_t x_enable, y_enable;
if (operand & MATINT_ENABLE_MASK_IS_Y) {
x_enable = ~(uint64_t)0;
-y_enable = parse_writemask(operand >> 32, xybytes, 9);
+y_enable = parse_writemask(operand >> 32, ybytes, 9);
} else {
-x_enable = parse_writemask(operand >> 32, xybytes, 9);
+x_enable = parse_writemask(operand >> 32, xbytes, 9);
y_enable = ~(uint64_t)0;
}
if (((operand >> (32+6)) & 7) == 0) {
@@ -224,18 +236,18 @@
}
}

-uint32_t xsignext = (operand & MATINT_SIGNED_X) ? (64 - xybits) : 0;
-uint32_t ysignext = (operand & MATINT_SIGNED_Y) ? (64 - xybits) : 0;
+uint32_t xsignext = (operand & MATINT_SIGNED_X) ? (64 - xbits) : 0;
+uint32_t ysignext = (operand & MATINT_SIGNED_Y) ? (64 - ybits) : 0;
uint32_t zsignext = 64 - zbits;
-uint32_t zmask = (zbytes / xybytes) - 1;
-uint32_t step = xybytes == 1 ? zbytes : xybytes;
+uint32_t zmask = (zbytes / xbytes) - 1;
+uint32_t step = xbytes == 1 ? zbytes : xbytes;
for (uint32_t j = 0; j < 64; j += step) {
if (!((y_enable >> j) & 1)) continue;
-for (uint32_t i = 0; i < 64; i += xybytes) {
+for (uint32_t i = 0; i < 64; i += xbytes) {
if (!((x_enable >> i) & 1)) continue;
-int64_t xv = load_int(x + i, xybytes, xsignext);
-int64_t yv = load_int(y + j, xybytes, ysignext);
-void* z = &state->z[bit_select(bit_select(j, z_row, xybytes - 1), i / xybytes, zmask)].u8[i & -zbytes];
+int64_t xv = load_int(x + i, xbytes, xsignext);
+int64_t yv = load_int(y + j, ybytes, ysignext);
+void* z = &state->z[bit_select(bit_select(j, z_row, xbytes - 1), i / xbytes, zmask)].u8[i & -zbytes];
int64_t zv = load_int(z, zbytes, zsignext);
int64_t result = vecint_alu(xv, yv, zv, alumode, shift) & omask;
store_int(z, zbytes, result);
5 changes: 3 additions & 2 deletions test.c
@@ -165,12 +165,13 @@ uint32_t AMX_VER;

static uint32_t detect_amx_hardware_version() {
__attribute__((aligned(256))) uint8_t buf[256];
+buf[64] = 2;
buf[128] = 1;
AMX_SET(); // Set x[0:8] to zero
-AMX_LDX(PTR_ROW_FLAGS(buf, 16, 1)); // On M1: copy buf[0:128] to x[0:2], on M2: copy buf[0:256] to x[0:4]
+AMX_LDX(PTR_ROW_FLAGS(buf, 48, 1)); // On M1: copy buf[0:128] to x[0,1], on M2: copy buf[0:256] to x[0,1,2,3], on M3: copy buf[0:256] to x[0,2,4,6]
AMX_STX(PTR_ROW_FLAGS(buf, 2, 0)); // Copy x[2] to buf[0:64]
AMX_CLR();
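// After the store, buf[0] is 0 on M1 (x[2] was never written), 1 on M2 (x[2] received buf[128:192]),
// or 2 on M3 (x[2] received buf[64:128]), so the return below maps it to AMX_VER_M1 / M2 / M3.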
-return buf[0] == 1 ? AMX_VER_M2 : AMX_VER_M1;
+return 1 + buf[0];
}

int main() {
