单核M1 CPU上实现FP32 1.5 TFlops算力?这是一份代码指南
选自jott.live
机器之心编译
编辑:悉闲、蛋酱
需要注意的是:如果你打算训练大型神经网络,那么可以忽略这篇文章的内容,因为它比 A100(TF32 张量核心算力约 156 TFlops)慢约 100 倍。
首先,这是在电池供电的 MacBook Air 2020 上仅使用单个 CPU 核心运行的结果;其次,每条指令的延迟约为 0.5 纳秒。
/*
 * Scalar reference for the 16x16 AMX kernel.
 *
 * Computes, for m, n in [0, 16):
 *   C[n * 16 + m] = sum over k < K of A[k * 16 + m] * B[k * 16 + n]
 * i.e. the sum of K outer products of the k-th 16-float slices of B and A.
 *
 * A, B: K * 16 floats each (read only).
 * C:    256 floats, fully overwritten (all zeros when K == 0).
 */
void reference_16x16xK (float *A, float *B, float *C, uint64_t K) {
  for (uint32_t m = 0; m < 16; ++m) {
    for (uint32_t n = 0; n < 16; ++n) {
      // Accumulate in a local: the compiler cannot otherwise rule out that
      // C aliases A or B, and would reload/store C[n * 16 + m] every step.
      float acc = 0.0f;
      // k must be as wide as K: a uint32_t counter never terminates for
      // K > UINT32_MAX, and k * 16 would wrap in 32-bit arithmetic.
      for (uint64_t k = 0; k < K; ++k) {
        acc += A[k * 16 + m] * B[k * 16 + n];
      }
      C[n * 16 + m] = acc;
    }
  }
}
/*
 * AMX version of reference_16x16xK: computes, for m, n in [0, 16),
 *   C[n * 16 + m] = sum over k < K of A[k * 16 + m] * B[k * 16 + n]
 *
 * A, B: K * 16 floats each. C: 256 floats, fully overwritten
 * (all zeros when K == 0).
 *
 * NOTE(review): assumes the caller has already enabled AMX (e.g. AMX_SET)
 * and that the AMX_* helpers are defined elsewhere — neither is shown here.
 */
void mm16x16xK (float *A, float *B, float *C, uint64_t K) {
  // With no k-steps the FMA never runs and the stores below would write
  // whatever Z already holds; an empty sum is zero, as in the reference.
  if (K == 0) {
    for (uint64_t i = 0; i < 16 * 16; ++i) {
      C[i] = 0.0f;
    }
    return;
  }
  // Bit 27 of an FMA operand ignores the Z input (Z = X * Y instead of
  // Z += X * Y, per corsix/amx fma.md); only set for k == 0.
  uint64_t reset_z = 1ull << 27;
  // k is as wide as K: a uint32_t counter never terminates for
  // K > UINT32_MAX, and k * 64 would be computed in 32-bit arithmetic.
  for (uint64_t k = 0; k < K; ++k) {
    // 64 bytes = 16 floats: the k-th slice of A into X register 0 and
    // the k-th slice of B into Y register 0.
    AMX_LDX ((uint64_t) A + k * 64);
    AMX_LDY ((uint64_t) B + k * 64);
    // A single outer product per k. In f32 matrix mode with Z row offset 0
    // it accumulates into Z rows 0, 4, ..., 60 (per corsix/amx fma.md),
    // which are the rows the store loop below reads.
    AMX_FMA32 (reset_z);
    reset_z = 0;
  }
  for (uint64_t i = 0; i < 16; ++i) {
    // Z register index goes in bits 56 and up of the store operand.
    const uint64_t z_register = (i * 4ull) << 56;
    // Z row 4 * i holds C row n = i (16 floats = 64 bytes).
    AMX_STZ (z_register | ((uint64_t) C + i * 64));
  }
}
/*
 * AMX kernel for a 32x32 f32 block: computes, for m, n in [0, 32),
 *   C[n * 32 + m] = sum over k < K of A[k * 32 + m] * B[k * 32 + n]
 *
 * A, B: K * 32 floats each. C: 32 * 32 floats, fully overwritten
 * (all zeros when K == 0).
 *
 * The result is built from four 16x16 f32 accumulators. With Z row offset
 * q (bits 20-25 of the FMA operand), FMA32 in matrix mode writes Z rows
 * 4 * j + q (per corsix/amx fma.md):
 *   q = 0: A[0..15]  x B[0..15]     q = 1: A[16..31] x B[0..15]
 *   q = 2: A[0..15]  x B[16..31]    q = 3: A[16..31] x B[16..31]
 * X byte offset is in bits 10-18, Y byte offset in bits 0-8.
 *
 * NOTE(review): assumes the caller has already enabled AMX (e.g. AMX_SET)
 * and that the AMX_* helpers are defined elsewhere — neither is shown here.
 */
void mm32x32xK (float* A, float* B, float* C, uint64_t K) {
  // With no k-steps no FMA runs and the stores below would write whatever
  // Z already holds; an empty sum is zero, as in the scalar reference.
  if (K == 0) {
    for (uint64_t i = 0; i < 32 * 32; ++i) {
      C[i] = 0.0f;
    }
    return;
  }
  // Flag to load/store a pair of consecutive registers (128 bytes).
  const uint64_t load_store_2 = 1ull << 62;
  const uint64_t load_store_width = 128; // in bytes
  // Bit 27 ignores the Z input (Z = X * Y); only set for k == 0.
  uint64_t reset_z = 1ull << 27;
  // k is as wide as K so the loop also terminates for K > UINT32_MAX.
  for (uint64_t k = 0; k < K; ++k) {
    // Rotate through X/Y register pairs 0-1, 2-3, 4-5, 6-7; presumably so a
    // new load does not overwrite registers still read by in-flight FMAs.
    uint64_t idx = k % 4;
    // Load 32 floats of A into X and of B into Y (register index idx * 2).
    AMX_LDX (load_store_2 | ((idx * 2) << 56) | ((uint64_t) A + k * load_store_width));
    AMX_LDY (load_store_2 | ((idx * 2) << 56) | ((uint64_t) B + k * load_store_width));
    // Offsets into the X and Y register files are byte-wise.
    const uint64_t offset = idx * load_store_width;
    // Four outer products into the four interleaved accumulators
    // (the original article notes this avoids pipeline hazards).
    AMX_FMA32 (reset_z | (0ull << 20) | ((offset + 0ull) << 10) | ((offset + 0ull) << 0));
    AMX_FMA32 (reset_z | (1ull << 20) | ((offset + 64ull) << 10) | ((offset + 0ull) << 0));
    AMX_FMA32 (reset_z | (2ull << 20) | ((offset + 0ull) << 10) | ((offset + 64ull) << 0));
    AMX_FMA32 (reset_z | (3ull << 20) | ((offset + 64ull) << 10) | ((offset + 64ull) << 0));
    reset_z = 0;
  }
  for (uint64_t i = 0; i < 16; ++i) {
    // Z rows 4i and 4i+1 (q = 0, 1) form C row n = i (m = 0..15, 16..31).
    AMX_STZ (load_store_2 | ((i * 4ull + 0) << 56) | ((uint64_t) C + i * load_store_width));
    // Z rows 4i+2 and 4i+3 (q = 2, 3) form C row n = 16 + i.
    AMX_STZ (load_store_2 | ((i * 4ull + 2) << 56) | ((uint64_t) C + (16 + i) * load_store_width));
  }
}
加载和存储指令的标志位说明:https://github.com/corsix/amx/blob/main/ldst.md;FMA 指令的标志位说明:https://github.com/corsix/amx/blob/main/fma.md
© THE END
转载请联系本公众号获得授权
投稿或寻求报道:[email protected]
微信扫码关注该文公众号作者
戳这里提交新闻线索和高质量文章给我们。
来源: qq
点击查看作者最近其他文章