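CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. It uses hierarchical decomposition and data movement strategies similar to those used in cuBLAS and cuDNN.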
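At the time of writing, the latest CUTLASS release is 3.3, which differs substantially from 1.3.3. Revisiting 1.3.3 is still worthwhile, because it is easier to understand:
- it matches the material presented in PROGRAMMING TENSOR CORES: NATIVE VOLTA TENSOR CORES WITH CUTLASS;
- it supports only one Tensor Core architecture, Volta;
- its Tensor Core path supports only one data type, half;
- it uses only one instruction, `HMMA.884.F16.F16`.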
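Demystifying Tensor Cores to Optimize Half-Precision Matrix Multiply points out that on the T4 GPU, once Tensor Cores are used, GEMM, which used to be compute-bound, becomes I/O-bound. The V100 has roughly three times the memory bandwidth of the T4, yet the same bandwidth shortage remains. CUTLASS therefore optimizes the data path as follows: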
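- 128-bit access granularity along the whole path: `LDG.128`, `STS.128`, `LDS.128`, and `STG.128`;
- conflict-free shared memory permutation: transposing the data requires no shared memory padding;
- software pipelining: the three instruction types `LDG.128`, `LDS.128`, and `HMMA.884.F16.F16` run concurrently, hiding data movement.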
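The conflict-free permutation point can be illustrated with a generic XOR swizzle. The host-side C++ below is a sketch under simple assumptions (32 four-byte banks, a 32x32-word tile); it is not the actual permutation used by CUTLASS's Volta884 iterators, and every name in it is hypothetical. It counts how many lanes of a warp land in the same bank when each lane reads one element of a column, which is the access pattern a transposed read produces.

```cpp
// Hypothetical sketch of XOR swizzling; not the exact Volta884 shared memory layout.
constexpr int kBanks = 32;  // shared memory banks, each 4 bytes wide
constexpr int kRows = 32;   // tile of kRows x kBanks 4-byte words, no padding

// Plain row-major offset (in words) of element (row, col).
constexpr int plain_offset(int row, int col) { return row * kBanks + col; }

// Swizzled offset: the column is permuted within its row by XOR with the row index,
// so the tile still occupies exactly kRows * kBanks words.
constexpr int swizzled_offset(int row, int col) { return row * kBanks + (col ^ row); }

// Largest number of lanes hitting the same bank when lane t reads element (t, col),
// i.e. when a warp walks down one column (a transposed read).
template <typename OffsetFn>
constexpr int column_conflict_degree(OffsetFn offset, int col) {
  int hits[kBanks] = {};
  int worst = 0;
  for (int lane = 0; lane < kRows; ++lane) {
    const int bank = offset(lane, col) % kBanks;
    hits[bank] += 1;
    if (hits[bank] > worst) worst = hits[bank];
  }
  return worst;
}
```

The usual alternative is to pad each row to 33 words, which also spreads a column across all banks but costs extra shared memory; the XOR permutation gets the same effect with no extra space.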
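The following uses the test case `TEST(Volta884_h884gemm_128x64x32_nt, 520x264x136)` to walk through the implementation of Volta884_h884gemm.
`OutputTile` is the threadblock tile; in this test it is `Shape<32, 64, 128>`, i.e. K=32, N=64, M=128 in CUTLASS's K x N x M order (a 128x64x32 tile). `WarpGemmShape` is `Shape<32, 64, 64>`, which is a fixed value for Volta884.
`run_gemm` initializes `Volta884GemmTraits::Params` and a `GemmTestbed`, runs the kernel through `Gemm::launch`, and then compares the result against the reference.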
```cpp
TEST(Volta884_h884gemm_128x64x32_nt, 520x264x136) {
  typedef cutlass::gemm::Volta884GemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<32, 64, 128>,
      cutlass::Shape<32, 64, 64>,
      half,
      half,
      half,
      2>
      GemmTraits;

  run_gemm<GemmTraits>(520, 264, 136);
}
```
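The K x N x M convention is why `Shape<32, 64, 128>` corresponds to the 128x64x32 tile in the test name. The host-side sketch below mirrors that convention with a hypothetical `TileShape` struct (a stand-in, not the CUTLASS type) and computes the launch geometry for the 520x264x136 problem. It assumes `run_gemm(m, n, k)` argument order, one threadblock per 128x64 output tile, and one warp per `WarpGemmShape` tile.

```cpp
// Hypothetical stand-in for cutlass::Shape<kD, kH, kW>, read as K x N x M.
template <int kD_, int kH_, int kW_>
struct TileShape {
  static constexpr int kD = kD_;  // K extent
  static constexpr int kH = kH_;  // N extent
  static constexpr int kW = kW_;  // M extent
};

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

using OutputTile = TileShape<32, 64, 128>;    // threadblock tile: M=128, N=64, K=32
using WarpGemmShape = TileShape<32, 64, 64>;  // warp tile: M=64, N=64, K=32

// Problem size of the test case, assuming run_gemm(m, n, k).
constexpr int kM = 520, kN = 264, kK = 136;

// Threadblocks along M and N, and the number of K-tiles each threadblock consumes.
constexpr int kGridM = ceil_div(kM, OutputTile::kW);
constexpr int kGridN = ceil_div(kN, OutputTile::kH);
constexpr int kKTiles = ceil_div(kK, OutputTile::kD);

// Warps per threadblock if each warp owns one WarpGemmShape tile.
constexpr int kWarps = (OutputTile::kW / WarpGemmShape::kW) * (OutputTile::kH / WarpGemmShape::kH);
```

None of the three problem dimensions is a multiple of its tile size, so this test exercises partial tiles in M, N and K at once.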
The figure below shows the hierarchy of the Volta884 implementation in CUTLASS.
gemm_kernel_nolb
The kernel function obtains dynamic shared memory, passes it to GemmMainloop, and then calls GemmMainloop::multiply_add to perform the computation.
```cpp
/// GEMM kernel without launch bounds specified
template <typename Gemm_>
__global__ /* __launch_bounds__(Gemm_::kThreads) */
void gemm_kernel_nolb(typename Gemm_::Params params) {
  // Dynamic shared memory base pointer
  extern __shared__ int GemmSharedStorageBase[];

  // Declare pointer to dynamic shared memory.
  typename Gemm_::SharedStorage *shared_storage =
      reinterpret_cast<typename Gemm_::SharedStorage *>(GemmSharedStorageBase);

  // Construct the GEMM object.
  Gemm_ gemm(params, *shared_storage);

  // Run GEMM.
  gemm.multiply_add();
}
```
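The kernel receives a single untyped dynamic shared memory buffer and reinterprets it as `Gemm_::SharedStorage`. Since the code later accesses both `shared_storage.main_loop` and `shared_storage.epilogue`, a natural reading is that the two phases share the same bytes through a union, so a launch only needs `sizeof(SharedStorage)` bytes. The sketch below shows that idea on the host; the struct names and sizes are hypothetical (two stages of a 128x32 A tile and a 64x32 B tile in 16-bit elements, plus a made-up epilogue buffer) and do not reproduce CUTLASS's actual storage layout.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical main-loop storage: 2 stages of A (128x32) and B (64x32),
// with uint16_t standing in for half.
struct MainLoopStorage {
  std::uint16_t a[2][128 * 32];
  std::uint16_t b[2][64 * 32];
};

// Hypothetical epilogue staging buffer (e.g. accumulators routed through shared memory).
struct EpilogueStorage {
  float tile[64 * 64];
};

// The main loop and the epilogue never run at the same time, so they can alias.
union SharedStorage {
  MainLoopStorage main_loop;
  EpilogueStorage epilogue;
};

// Dynamic shared memory bytes a launch would request for this layout.
constexpr std::size_t kSharedBytes = sizeof(SharedStorage);
```

With these made-up sizes the union needs 24 KB, the larger of its two members. A layout above 48 KB per block would additionally require opting in with `cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, bytes)` before the launch.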
GemmMainloop
GemmMainloop implements software pipelining, as shown in the figure below:
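Shared memory and registers each need two buffers, and scheduling on the SM lets the three pipelines run concurrently. Loads from global memory into shared memory require synchronization, whereas moving data from shared memory into registers does not. Because architectures before Ampere cannot copy directly from global memory to shared memory, the whole data-movement path is fairly involved. As shown in the figure below, the code calls Copy::transform in several places to produce `transformed_fragment`. The purpose is presumably type conversion, but since the operands here are always half on Volta, it has no practical effect.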
```cpp
template <typename Traits_>
struct GemmMainloop {
  //
  // Type definitions
  //

  /// The traits.
  typedef Traits_ Traits;

  /// The GEMM mainloop
  typedef typename Traits::KernelClass KernelClass;

  /// The shared storage.
  typedef typename Traits::SharedStorage SharedStorage;

  /// The scalar for A.
  typedef typename Traits::ScalarA ScalarA;
  /// The scalar for B.
  typedef typename Traits::ScalarB ScalarB;
  /// The scalar in the epilogue.
  typedef typename Traits::Epilogue::Scalar ScalarEpilogue;
  /// The scalar for C.
  typedef typename Traits::Epilogue::ScalarC ScalarC;
  /// The scalar for D.
  typedef typename Traits::Epilogue::ScalarD ScalarD;
  /// The index.
  typedef typename Traits::Index Index;

  /// Define the mainloop iteration size
  typedef typename Traits::MultiplyAdd MultiplyAdd;

  /// The number of threads.
  static int const kThreads = Traits::GemmConfig::kThreads;
```
`AccumulatorsPerWarp` is GemmConfig::AccumulatorsPerWarp, i.e. Volta884MultiplyAdd::WarpGemmShape, which is 32x64x64 (K x N x M).
Volta884MultiplyAdd::InstructionShape is 4x32x32, so `kWarpGemmSteps` is 32 / 4 = 8.
```cpp
  // Number of warp-level multiply-accumulate steps executed by each warp.
  static Index const kWarpGemmSteps =
      Traits::GemmConfig::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;

  /*
  // Make sure we have at least 2 unrolling steps or our pipeling is not going to work.
  static_assert(kWarpGemmSteps >= 2, "The pipelining assumes at least two steps");
  */

  /// Use the params object defined in traits
  typedef typename Traits::Params Params;

  //
  // Data members
  //

  /// The params.
  Params const& params;

  /// SharedStorage object
  SharedStorage& shared_storage;

  //
  // Methods
  //

  /// Ctor.
  CUTLASS_DEVICE GemmMainloop(Params const& params_, SharedStorage& shared_storage_)
      : params(params_), shared_storage(shared_storage_) {}
```
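The compile-time arithmetic above can be checked with a few constexpr values. This sketch copies the numbers quoted in the text into hypothetical constants (they are not CUTLASS names), applies the same `kWarpGemmSteps` formula and the `kGlobalStreamFirst` test that consume_tile uses further down, and also derives how many accumulator elements each of a warp's 32 threads holds for a 64x64 warp tile.

```cpp
// Hypothetical constants copying the values quoted in the text (K x N x M order).
constexpr int kWarpTileK = 32, kWarpTileN = 64, kWarpTileM = 64;  // AccumulatorsPerWarp
constexpr int kInstrK = 4;                                         // InstructionShape::kD

// Same formula as GemmMainloop::kWarpGemmSteps.
constexpr int kWarpGemmSteps = kWarpTileK / kInstrK;

// Same predicate as kGlobalStreamFirst in consume_tile.
constexpr bool kGlobalStreamFirst = (kWarpGemmSteps <= 4);

// Output elements of the 64x64 warp tile spread over the 32 threads of a warp.
constexpr int kAccumulatorsPerThread = kWarpTileM * kWarpTileN / 32;
```

So in this configuration every warp runs 8 steps per K-tile, fetches from global memory inside the step loop rather than before it, and keeps 128 accumulator values per thread.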
GemmMainloop::fetch_global
Volta884GemmTraits::GlobalLoadStream is of type GlobalLoadStreamPair.
GlobalLoadStreamPair::residue calls MMAGlobalLoadStream::residue twice (once for A, once for B) to compute the predicate masks needed for the last load of the threadblock tile.
GlobalLoadStreamPair::copy calls MMAGlobalLoadStream::copy twice to copy matrix elements from global memory into registers; the latter calls TileLoadIterator::load_post_increment.
```cpp
  /// Fetches global stream pair
  template <bool Residue>
  CUTLASS_DEVICE void fetch_global(typename Traits::GlobalLoadStream& global_to_shared_stream,
                                   Index outer_k) {
    // If residue portion and not calculating residue in prolog, update residue predicates now.
    if (Residue) {
      global_to_shared_stream.residue(outer_k);
    }
    global_to_shared_stream.copy();
  }
```
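The "pair" is a thin forwarding layer: each operation is applied to the A stream and then to the B stream. The host-side sketch below models that pattern together with the control flow of fetch_global. `CountingStream` and `StreamPair` are hypothetical stand-ins that only count calls; the real classes also hold Params, SharedStorage and tile iterators.

```cpp
#include <cassert>

// Hypothetical stand-in for one MMAGlobalLoadStream: it only records calls.
struct CountingStream {
  int residue_calls = 0;
  int copy_calls = 0;
  int commit_calls = 0;
  int last_k = 0;

  void residue(int k) { ++residue_calls; last_k = k; }
  void copy() { ++copy_calls; }      // real stream: global memory -> registers
  void commit() { ++commit_calls; }  // real stream: registers -> shared memory
};

// Hypothetical pair that forwards every operation to both operand streams.
template <typename StreamA, typename StreamB>
struct StreamPair {
  StreamA stream_a;
  StreamB stream_b;

  void residue(int k) { stream_a.residue(k); stream_b.residue(k); }
  void copy() { stream_a.copy(); stream_b.copy(); }
  void commit() { stream_a.commit(); stream_b.commit(); }
};

// Same control flow as GemmMainloop::fetch_global.
template <bool Residue, typename Pair>
void fetch_global(Pair& pair, int outer_k) {
  if (Residue) {
    pair.residue(outer_k);  // update predicates for a partial K-tile
  }
  pair.copy();
}
```

Note that fetch_global never calls commit: it only fills registers, and the write to shared memory happens separately.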
GemmMainloop::consume_tile
If `kWarpGemmSteps` is at most 4, `kGlobalStreamFirst` is true, and the data for the next iteration is loaded from global memory first.
```cpp
  /// Computes a warp-level GEMM on data held in shared memory
  template <bool Residue, bool LastIteration>
  CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_to_shared_stream,
                                   typename Traits::SharedStream& shared_load_stream,
                                   typename MultiplyAdd::Accumulators& accumulators,
                                   Index outer_k) {
    // Whether to load global stream before loading shared stream
    const bool kGlobalStreamFirst = (kWarpGemmSteps <= 4);

    // Load data for the next iteration of the main loop (unless it's the last iteration).
    if (kGlobalStreamFirst && !LastIteration) {
      fetch_global<Residue>(global_to_shared_stream, outer_k);
    }
```
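Inside the loop, each step first loads the inputs of the next step (`(step + 1) % kWarpGemmSteps`) from shared memory; the register fragments are double buffered.
MMASharedLoadStream::copy calls Volta884WarpMultiplicandLoadIterator::load to load the data into registers.
One question: if GemmMainloop::fetch_global was not called in the preceding step, is the copy from shared memory still safe? It is, because fetch_global only moves data from global memory into registers; the shared memory stage read here was already written by an earlier commit (in the prologue, or at step `kWarpGemmSteps - 2` of the previous tile).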
```cpp
    CUTLASS_PRAGMA_UNROLL
    for (int step = 0; step < kWarpGemmSteps; ++step) {
      // Trigger the copy from shared memory for the next A/B values.
      shared_load_stream.copy((step + 1) % kWarpGemmSteps);
```
If `kGlobalStreamFirst` is false, GemmMainloop::fetch_global is called at the first step of the loop to load the inputs.
```cpp
      // Load data for the next iteration of the main loop (unless it's the last iteration).
      if (!kGlobalStreamFirst && (step == 0) && !LastIteration) {
        fetch_global<Residue>(global_to_shared_stream, outer_k);
      }
```
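At the second-to-last step (`step == kWarpGemmSteps - 2`), the data fetched for the next tile must be placed in shared memory:
Volta884GemmTraits::shared_load_fence uses the externally supplied `StageCount` to decide whether to synchronize the threads.
GlobalLoadStreamPair::commit calls GlobalLoadStream::commit for each of the two matrices to copy them into shared memory.
Volta884GemmTraits::shared_store_fence synchronizes the threads.
MMASharedLoadStream::inc_stage increments `stage_index`.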
```cpp
      if (step == kWarpGemmSteps - 2) {
        // Make sure the data from shared memory has been entirely consumed.
        Traits::shared_load_fence(true);

        global_to_shared_stream.commit();

        // Make sure the data is in shared memory.
        Traits::shared_store_fence(true);

        // Move to the next stage for the load (if it makes sense).
        shared_load_stream.inc_stage();
      }
```
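The text only says that shared_load_fence decides from `StageCount` whether to synchronize. One plausible reading, modeled below with hypothetical names (this is not CUTLASS code), is that with two shared memory stages the commit writes into the buffer the warps are not currently reading, so the barrier before the commit can be skipped, while a single stage would be overwritten in place and needs that barrier. The store side always needs a barrier before other warps read the new data.

```cpp
#include <cassert>

// Hypothetical model of the shared memory stage bookkeeping.
template <int StageCount>
struct StageModel {
  int read_stage = 0;  // stage that warps currently load fragments from (stage_index)

  // Stage that the commit to shared memory would write next.
  int write_stage() const { return (read_stage + 1) % StageCount; }

  // A barrier before the commit is needed only if the commit would overwrite
  // the stage that is still being read, i.e. with a single stage.
  static constexpr bool kNeedLoadFence = (StageCount < 2);

  // Models MMASharedLoadStream::inc_stage: switch to the freshly written stage.
  void inc_stage() { read_stage = (read_stage + 1) % StageCount; }
};
```

With two stages, the last reads of the stage being overwritten happened before the previous tile's store barrier, which is why skipping the load-side barrier is safe in this model.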
MMASharedLoadStream::commit calls Copy to perform the copy (Volta884WarpMultiplicandLoadIterator::Fragment is simply a Fragment).
Volta884MultiplyAdd::multiply_add computes the warp tile.
```cpp
      // Make sure the values are available for the current iteration to do the multiply-add.
      shared_load_stream.commit(step);

      // Do the math on the fragments of the current iteration.
      MultiplyAdd multiply_add;
      multiply_add.multiply_add(shared_load_stream.fragment_a(step),
                                shared_load_stream.fragment_b(step),
                                accumulators,
                                accumulators);
    }
  }
```
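To make the order of operations inside consume_tile concrete, the host-side sketch below replays its control flow for one non-final tile and records a trace instead of moving data. It uses `kWarpGemmSteps = 8` from the text (so `kGlobalStreamFirst` is false). The event labels are hypothetical: LDS for the shared-memory fragment load, LDG for fetch_global, STS and SYNC for the commit plus shared_store_fence, and HMMA for the warp-level multiply-add. The shared_load_fence call and the fragment commit are left out.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Replays the schedule of GemmMainloop::consume_tile for a non-final tile.
std::vector<std::string> consume_tile_trace(int warp_gemm_steps) {
  assert(warp_gemm_steps >= 2);  // the pipelining assumes at least two steps
  std::vector<std::string> trace;
  const bool global_stream_first = (warp_gemm_steps <= 4);
  const bool last_iteration = false;

  if (global_stream_first && !last_iteration) trace.push_back("LDG");
  for (int step = 0; step < warp_gemm_steps; ++step) {
    // Prefetch the fragments of the next step from shared memory.
    trace.push_back("LDS" + std::to_string((step + 1) % warp_gemm_steps));
    if (!global_stream_first && step == 0 && !last_iteration) trace.push_back("LDG");
    if (step == warp_gemm_steps - 2) {
      trace.push_back("STS");   // commit registers to the next shared memory stage
      trace.push_back("SYNC");  // shared_store_fence
    }
    trace.push_back("HMMA" + std::to_string(step));
  }
  return trace;
}
```

For 8 steps the trace starts `LDS1, LDG, HMMA0` and ends `LDS0, HMMA7`: the shared-memory load for step s+1 is always issued before the math for step s, and the global fetch is issued early so its latency overlaps the remaining HMMA steps.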
GemmMainloop::multiply_add
make_Coord_from_shape creates a Coord object from a shape.
IdentityBlockSwizzle::get_threadblock_offset returns the offset of the current threadblock in the 2D output.
Volta884GemmTraits::ClearAccumulators is ClearAccumulators.
IdentityBlockSwizzle::get_threadblock_bounds returns the 3D bounds of the threadblock.
```cpp
  /// Do the GEMM.
  CUTLASS_DEVICE void multiply_add() {
    // Swizzle the IDs of the block (to enable better cache behavior).
    typename Traits::BlockSwizzle block_swizzle;
    Coord<3> threadblock_offset =
        block_swizzle.get_threadblock_offset(make_Coord_from_shape<typename Traits::OutputTile>());

    // We may want to use shared memory to clear the registers.
    typedef typename Traits::ClearAccumulators ClearAccumulators;

    // Get the bounds for each thread, it maybe different than problem_size
    Coord<3> bounds = block_swizzle.get_threadblock_bounds(params.problem_size,
                                                          params.partitionK_range);
```
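An identity swizzle uses the block index unchanged. Assuming `blockIdx.x` selects the M tile and `blockIdx.y` the N tile (an assumption of this sketch, not something the text states), the threadblock offset in the K x N x M coordinate order can be modeled on the host as follows; `Coord3` and `threadblock_offset` are hypothetical names.

```cpp
#include <array>
#include <cassert>

// Coordinates in K x N x M order, like the Coord<3> values used by the mainloop.
using Coord3 = std::array<int, 3>;

// Hypothetical identity mapping: block (block_x, block_y) starts at
// K = 0, N = block_y * tile_n, M = block_x * tile_m.
Coord3 threadblock_offset(int block_x, int block_y, const Coord3& output_tile) {
  return Coord3{0, block_y * output_tile[1], block_x * output_tile[2]};
}
```

For the 520x264x136 test with a K x N x M tile of 32x64x128, block (4, 4) starts at N = 256 and M = 512, so only 8 of its rows and 8 of its columns lie inside the problem; such edge tiles are presumably handled by the iterators' predicates. A non-identity swizzle would remap block IDs to improve cache reuse, which is what the comment in the code refers to.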
`params.global_to_shared_stream` is a GlobalLoadStreamPair::Params.
`shared_storage.main_loop.global_to_shared_stream` is a GlobalLoadStreamPair::SharedStorage.
`shared_storage.main_loop.threadblock_tile` is a GlobalLoadStreamPair::ThreadblockTileStorage, i.e. a ZipTileAllocation; ZipTileAllocation::reference returns a ZipTensorRef pointing to the data.
`global_to_shared_stream` is Volta884GemmTraits::GlobalLoadStream, i.e. a GlobalLoadStreamPair.
GlobalLoadStreamPair::add_batch_offset calls GlobalLoadStream::add_batch_offset for both matrices to set the batch offset of the iterators.
```cpp
    // The streams to read A/B from global memory to shared memory.
    typename Traits::GlobalLoadStream global_to_shared_stream(
        params.global_to_shared_stream,
        shared_storage.main_loop.global_to_shared_stream,
        shared_storage.main_loop.threadblock_tile.reference(),
        bounds,
        threadblock_offset);

    // update A and B pointer offset based on batch_id and batch_stride_offset
    global_to_shared_stream.add_batch_offset(block_swizzle.get_batch_id());

    // Create the accumulator clear.
    ClearAccumulators clear;
```
GlobalLoadStreamPair::move_to_residue: if the residue is handled in the prologue, it calls MMAGlobalLoadStream::move_to_residue to move the pointers; otherwise it calls GlobalLoadStreamPair::residue directly.
GlobalLoadStreamPair::copy calls MMAGlobalLoadStream::copy, which calls TileLoadIterator::load_post_increment to load fragments of A and B into Fragment registers.
GlobalLoadStreamPair::commit calls MMAGlobalLoadStream::commit, which calls Copy::transform to copy the data and then Volta884ThreadblockMultiplicandStoreIterator::store_post_increment to store it to shared memory.
Volta884GemmTraits::shared_store_fence synchronizes the threads within the threadblock.
GlobalLoadStreamPair::rollback calls MMAGlobalLoadStream::rollback, which calls TileLoadIterator::initialize_predicates to initialize the predicate vector and then moves the offset.
```cpp
    // Deal with residue in prolog.
    // global_to_shared_stream.move_to_residue(params.problem_size[0], Traits::OutputTile::kD);
    global_to_shared_stream.move_to_residue(bounds[0], Traits::OutputTile::kD);

    // Fetch the fragments for A and B from global memory.
    global_to_shared_stream.copy();

    // Copy the elements to shared memory (after transformation if needed).
    global_to_shared_stream.commit();

    // Make sure the data is in shared memory.
    Traits::shared_store_fence(false);

    // Rollback to the beginning of the first tile (if residue exists).
    // global_to_shared_stream.rollback(params.problem_size[0] % Traits::OutputTile::kD);
    global_to_shared_stream.rollback(bounds[0] % Traits::OutputTile::kD);
```
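For the test's K = 136 and `OutputTile::kD` = 32, the prologue deals with a partial K-slice of 136 % 32 = 8 elements, and `rollback` is passed that same remainder. The sketch below models only the predicate side of this step as a hypothetical per-element bit mask over one kD-wide slice, enabling the first `k_residue` K indices. The real TileLoadIterator predicates are organized per access rather than per element, so treat this as an illustration, not CUTLASS's layout.

```cpp
#include <cstdint>

// Hypothetical per-element predicate mask for one K-slice of width tile_k (<= 32):
// bit i is set when K index i of the slice lies inside the problem.
constexpr std::uint32_t residue_mask(int k_residue, int tile_k) {
  // A residue of 0 means K divides evenly, so the whole slice is valid.
  return (k_residue == 0 || k_residue >= tile_k)
             ? (tile_k == 32 ? 0xFFFFFFFFu : ((1u << tile_k) - 1u))
             : ((1u << k_residue) - 1u);
}

constexpr int kProblemK = 136;                // bounds[0] in the test
constexpr int kTileK = 32;                    // Traits::OutputTile::kD
constexpr int kResidue = kProblemK % kTileK;  // value passed to rollback
constexpr std::uint32_t kPrologMask = residue_mask(kResidue, kTileK);
```

Only 8 of the 32 K positions are enabled for the partial slice; every later tile is full, which is why the main loop can run without residue predicates in this mode.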
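`shared_load_stream` is of type Volta884GemmTraits::SharedStream, i.e. SharedStreamPair.
SharedStreamPair::copy calls MMASharedLoadStream::copy, which calls Volta884WarpMultiplicandLoadIterator::load to load from shared memory.
`accumulators` is of type Volta884MultiplyAdd::Accumulators, i.e. a Fragment.
ClearAccumulators::clear calls Fragment::clear to zero the storage.
What is `outer_k`? In the code below it starts at `bounds[0] - Traits::OutputTile::kD`, the K extent that remains after the first tile, and the main loop decreases it by `Traits::OutputTile::kD` per tile.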
```cpp
    // The stream of data from shared memory to fragments.
    typename Traits::SharedStream shared_load_stream(
        params.shared_stream,
        shared_storage.main_loop.threadblock_tile.reference());

    // Trigger the copy from shared memory for the 1st stream.
    shared_load_stream.copy(0);

    // Allocate the accumulators.
    typename MultiplyAdd::Accumulators accumulators;

    // Clear the accumulators.
    clear.clear(accumulators);

    // Initial index
    // Index outer_k = params.problem_size[0] - Traits::OutputTile::kD;
    // problem_size[0] might be bigger than bounds[0]
    Index outer_k = bounds[0] - Traits::OutputTile::kD;
```
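If the residue is computed in the prologue, the main loop needs no residue predicates; only the last iteration is peeled off (`LastIteration = true`) so that it does not prefetch from global memory.
Each call to GemmMainloop::consume_tile processes one K-block of size `k = Traits::OutputTile::kD`.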
```cpp
    // Check if we are computing residue in prolog or not.
    if (Traits::GemmConfig::kResidueInProlog) {
      // Execute all mainloop iterations but the last one.

      CUTLASS_GEMM_LOOP
      for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) {
        CUTLASS_GEMM_LOOP_HEADER
        consume_tile<false, false>(
            global_to_shared_stream, shared_load_stream, accumulators, outer_k);
      }

      consume_tile<false, true>(
          global_to_shared_stream, shared_load_stream, accumulators, outer_k);
```
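Otherwise, the residue predicates are updated inside the main loop: by default every iteration takes the residue path, and when `kResidueSeparate` is true only the last two iterations do.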
```cpp
    } else {
      // When kResidueSeparate = true, execute all mainloop iterations but the last two without any
      // consideration for K-residue or predicate updates. This improves the steady state of some
      // kernels.
      if (Traits::GemmConfig::kResidueSeparate) {
        CUTLASS_GEMM_LOOP
        for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
          CUTLASS_GEMM_LOOP_HEADER
          consume_tile<false, false>(
              global_to_shared_stream, shared_load_stream, accumulators, outer_k);
        }
      }

      // Execute remaining tiles with K-residue predicate updates enabled.
      CUTLASS_GEMM_LOOP
      for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
        CUTLASS_GEMM_LOOP_HEADER
        consume_tile<true, false>(
            global_to_shared_stream, shared_load_stream, accumulators, outer_k);
      }
    }
```
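For the test's K = 136 and kD = 32, the loops can be replayed on the host to see which consume_tile variant each tile gets. The sketch below copies the loop bounds from the code above into a hypothetical `count_iterations` helper; it only counts calls and moves no data.

```cpp
// Number of consume_tile calls of each kind made by GemmMainloop::multiply_add.
struct LoopCounts {
  int steady;   // consume_tile<false, false>: no residue predicate updates
  int residue;  // consume_tile<true, false>: residue predicates updated
  int last;     // consume_tile<false, true>: final tile, no global prefetch
};

// Replays the loop structure for K extent k (bounds[0]) and tile depth kd (OutputTile::kD).
constexpr LoopCounts count_iterations(int k, int kd, bool residue_in_prolog,
                                      bool residue_separate) {
  LoopCounts c{0, 0, 0};
  int outer_k = k - kd;
  if (residue_in_prolog) {
    for (; outer_k > 0; outer_k -= kd) ++c.steady;
    ++c.last;
  } else {
    if (residue_separate) {
      for (; outer_k > kd; outer_k -= kd) ++c.steady;
    }
    for (; outer_k > -kd; outer_k -= kd) ++c.residue;
  }
  return c;
}
```

Here `outer_k` takes the values 104, 72, 40, 8 and -24. All three configurations consume ceil(136 / 32) = 5 tiles; they differ only in how many of those tiles pay for residue predicate updates.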
Finally, an MMAEpilogue object is created and MMAEpilogue::epilogue is called.
```cpp
    typedef typename Traits::Epilogue Epilogue;
    Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.problem_size.knm());
    epilogue.epilogue(accumulators, threadblock_offset, block_swizzle.get_batch_id());
  }
};
```
References:
- [DOC] Where does cutlass’ detailed GEMM kernel? #526
- Dissecting the NVIDIA Volta GPU Architecture via Microbenchmarking
- Modeling Deep Learning Accelerator Enabled GPUs
- gpgpu-sim_distribution
- Understanding Tensor Core (理解Tensor Core)
- Flexible Performant GEMM Kernels on GPUs
- CUDA Tensor Core Programming (CUDA Tensor Core编程)
- PROGRAMMING TENSOR CORES: NATIVE VOLTA TENSOR CORES WITH CUTLASS
- The NVIDIA Titan V Deep Learning Deep Dive: It’s All About The Tensor Cores
- 9.7.13.4.1. Matrix Fragments for mma.m8n8k4 with .f16 floating point type
- Numerical Behavior of NVIDIA Tensor Cores
- CUDA Ampere Tensor Core HGEMM Matrix Multiplication Optimization Notes: Up To 131 TFLOPS! (CUDA Ampere Tensor Core HGEMM 矩阵乘法优化笔记)
- If we have two or four memory requests by a warp, do they need coalesced access/contiguity? #328
- Do bank conflicts increase when using more shared memory?
- How does parameter computeType affect the computation?
- 2.1.10. GEMM Algorithms Numerical Behavior
- Using cuBLAS (cuBLAS的使用)
- Some Evaluation Tests of RAFT on Knowhere [1] (RAFT在Knowhere上的一些评估测试[1])
- cudnn-frontend/tree/main/samples/samples/conv_sample.cpp
- Is a union in C++ actually a class?
- A Generalized Micro-kernel Abstraction for GPU Linear Algebra
- Implementing Strassen’s Algorithm with CUTLASS on NVIDIA Volta GPUs
- Double-buffering in shared memory, details? #227
- Efficient GEMM in CUDA
- Thread synchronization with syncwarp
- Using CUDA Warp-Level Primitives
- CUDA Microarchitecture and Instruction Set (3): SASS Instruction Set Classification (CUDA微架构与指令集(3)-SASS指令集分类)
- VOLTA Architecture and performance optimization
- How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog
- Determining registers holding the data after executing LDG.E.128
- Liu Bing, Zheng Peng | GPU Programming and Optimization: Best Practices (刘冰、郑鹏|GPU编程和优化-最佳实践分享)