About UNFUSED_PADDED_MHA vs FUSED_MHA
- FUSED_MHA executes the kernels in a different way (the same approach as described in the linked article; it is explained in the next section).
- The kernels for UNFUSED_PADDED_MHA are implemented in src/fastertransformer/kernels/unfused_attention_kernels.cu (a conceptual sketch of the steps they cover follows this list).
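To make "fused vs. unfused" concrete, the sketch below is a plain CPU reference of what one attention head computes; it is not FasterTransformer code, only an outline. In the unfused path each numbered step is issued as its own GPU operation (a cuBLAS batched GEMM for QK^T, softmax/mask/transpose kernels from unfused_attention_kernels.cu, then a second batched GEMM with V), while the fused path performs all of them inside a single TensorRT-style fMHA kernel.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Plain single-head reference: out = softmax(Q * K^T / sqrt(d)) * V.
// Q, K, V, out are seq_len x head_dim, row-major. Purely illustrative.
void attention_reference(const std::vector<float>& Q,
                         const std::vector<float>& K,
                         const std::vector<float>& V,
                         std::vector<float>&       out,
                         int seq_len,
                         int head_dim)
{
    const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
    std::vector<float> scores(static_cast<size_t>(seq_len) * seq_len);

    // (1) S = scale * Q * K^T        -- first batched GEMM in the unfused path
    for (int i = 0; i < seq_len; ++i) {
        for (int j = 0; j < seq_len; ++j) {
            float s = 0.0f;
            for (int d = 0; d < head_dim; ++d) {
                s += Q[i * head_dim + d] * K[j * head_dim + d];
            }
            scores[i * seq_len + j] = s * scale;
        }
    }

    // (2) P = softmax(S), row-wise   -- softmax (+ mask/bias) kernel in the unfused path
    for (int i = 0; i < seq_len; ++i) {
        float* row = &scores[i * seq_len];
        const float max_val = *std::max_element(row, row + seq_len);
        float sum = 0.0f;
        for (int j = 0; j < seq_len; ++j) {
            row[j] = std::exp(row[j] - max_val);
            sum += row[j];
        }
        for (int j = 0; j < seq_len; ++j) {
            row[j] /= sum;
        }
    }

    // (3) O = P * V                  -- second batched GEMM in the unfused path
    out.assign(static_cast<size_t>(seq_len) * head_dim, 0.0f);
    for (int i = 0; i < seq_len; ++i) {
        for (int j = 0; j < seq_len; ++j) {
            for (int d = 0; d < head_dim; ++d) {
                out[i * head_dim + d] += scores[i * seq_len + j] * V[j * head_dim + d];
            }
        }
    }
}
```

Which of the four attention modes is actually used is decided by getAttentionType: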
```cpp
enum class AttentionType {
    UNFUSED_MHA,
    UNFUSED_PADDED_MHA,
    FUSED_MHA,
    FUSED_PADDED_MHA
};

/* NOTE:
1. only swin-style relative position bias is supported currently
2. gpt-style (causal-mask) models support any-sequence-length fmha, so we don't need to call isValidSeqLen at run-time
3. bert/vit can also support any-seq-length fmha
*/
template<typename T>
AttentionType getAttentionType(size_t     size_per_head,
                               const int  sm,
                               const bool remove_padding,
                               const int  max_seq_len,
                               const bool is_fuse                          = true,
                               const bool with_swin_relative_position_bias = false,
                               const bool causal_mask                      = false)
{
    if (std::is_same<T, half>::value && is_fuse) {
        // Bert/Vit
        if (!causal_mask) {
            if (!with_swin_relative_position_bias
                && (((sm == kSM_70 || sm == kSM_72) && size_per_head == 64)
                    || ((sm == kSM_75 || sm == kSM_80 || sm == kSM_86)
                        && (size_per_head == 64 || size_per_head == 32)))) {
                return remove_padding ? AttentionType::FUSED_MHA : AttentionType::FUSED_PADDED_MHA;
            }
            else if (with_swin_relative_position_bias && (sm == kSM_75 || sm == kSM_80 || sm == kSM_86)
                     && max_seq_len <= 256 && size_per_head == 32) {
                return remove_padding ? AttentionType::FUSED_MHA : AttentionType::FUSED_PADDED_MHA;
            }
        }
        // GPT and its variants
        else {
            // FMHA_ENABLE only affects gpt-style models (causal-mask)
            char* fused_qkv = std::getenv("FMHA_ENABLE");
            if (fused_qkv != nullptr && std::string(fused_qkv) == "ON") {
                if ((sm == kSM_70 || sm == kSM_72 || sm == kSM_75 || sm == kSM_80 || sm == kSM_86 || sm == kSM_89)
                    && (size_per_head == 32 || size_per_head == 40 || size_per_head == 64 || size_per_head == 80
                        || size_per_head == 128 || size_per_head == 144 || size_per_head == 160
                        || size_per_head == 256)) {
                    return remove_padding ? AttentionType::FUSED_MHA : AttentionType::UNFUSED_PADDED_MHA;
                }
            }
        }
    }

    // Everything else falls back to the unfused kernels.
    return remove_padding ? AttentionType::UNFUSED_MHA : AttentionType::UNFUSED_PADDED_MHA;
}
```
- To actually get FUSED_MHA, the parameters have to satisfy the conditions above: the model must run in FP16 (T = half), is_fuse must be true, remove_padding must be enabled, and sm / size_per_head must be one of the supported combinations; GPT-style (causal-mask) models additionally require the environment variable FMHA_ENABLE=ON. A hypothetical call is sketched below.
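As a concrete, hypothetical example (assuming the FasterTransformer headers that declare AttentionType, getAttentionType, kSM_80 and half are included), the following configuration hits the non-causal FP16 branch above and returns FUSED_MHA:

```cpp
// Hypothetical call site: a parameter combination that selects FUSED_MHA.
const size_t size_per_head  = 64;      // head size supported by the fused kernels on sm75/80/86
const int    sm             = kSM_80;  // e.g. an A100-class GPU
const bool   remove_padding = true;    // with padding kept, the result would be FUSED_PADDED_MHA
const int    max_seq_len    = 384;

AttentionType type = getAttentionType<half>(size_per_head,
                                            sm,
                                            remove_padding,
                                            max_seq_len,
                                            /*is_fuse=*/true,
                                            /*with_swin_relative_position_bias=*/false,
                                            /*causal_mask=*/false);
// type == AttentionType::FUSED_MHA
```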
- See also the BERT guide: https://github.com/NVIDIA/FasterTransformer/blob/main/docs/bert_guide.md
The kernel definitions and invocations therefore sit in the forward path:
https://github1s.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/bert/Bert.cc#L494
which calls the forward function of FusedAttentionLayer.
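For orientation: roughly speaking, Bert.cc constructs a FusedAttentionLayer when the attention type is one of the two fused variants and an UnfusedAttentionLayer otherwise, then calls forward() through the common base class. The helper below is a hypothetical illustration of that selection; it does not exist in the repo.

```cpp
// Hypothetical helper: which attention types are served by FusedAttentionLayer.
// FUSED_MHA and FUSED_PADDED_MHA go through the TRT fused runner; the two
// unfused types go through the GEMM + softmax kernel path instead.
bool useFusedAttentionLayer(AttentionType type)
{
    return type == AttentionType::FUSED_MHA || type == AttentionType::FUSED_PADDED_MHA;
}
```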
In the fused part of that forward function, dispatcher_fp16 is a pointer to the MHARunner base class; the polymorphism is set up by calling .reset() on it with a concrete runner instance.
The call eventually reaches pimpl->run. The inner class that pimpl points to is defined at
https://github1s.com/NVIDIA/FasterTransformer/blob/main/3rdparty/trt_fused_multihead_attention/qkvToContext.cu#L62
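Putting the pieces together: dispatcher_fp16 is declared against the MHARunner base class, .reset() installs the concrete runner (FusedMHARunnerFP16v2 in the repo), and that runner forwards its work to a pimpl inner class whose definition, including the actual fused-kernel launch, lives in qkvToContext.cu. The block below is a self-contained miniature of this structure with stub bodies and simplified signatures; it mirrors the pattern, not the real FasterTransformer/TensorRT code.

```cpp
#include <cstdio>
#include <memory>

// Simplified stand-in for the MHARunner interface that dispatcher_fp16 holds
// (the real class also exposes setup(), workspace-size queries, etc.).
class MHARunner {
public:
    virtual ~MHARunner() = default;
    virtual void run(const void* qkv, void* output) = 0;
};

// Stand-in for FusedMHARunnerFP16v2. In the repo, the inner class behind `pimpl`
// is defined in 3rdparty/trt_fused_multihead_attention/qkvToContext.cu.
class FusedMHARunnerFP16v2 : public MHARunner {
    struct mhaImpl {
        void run(const void* /*qkv*/, void* /*output*/) {
            // The real implementation selects and launches the fused MHA CUDA kernel here.
            std::printf("fused MHA kernel launched\n");
        }
    };
    std::unique_ptr<mhaImpl> pimpl{new mhaImpl{}};

public:
    void run(const void* qkv, void* output) override {
        pimpl->run(qkv, output);  // the "eventually pimpl->run is called" step
    }
};

int main()
{
    // Base-class smart pointer + reset(): this is what gives the layer its
    // polymorphic dispatch to the fused runner.
    std::unique_ptr<MHARunner> dispatcher_fp16;
    dispatcher_fp16.reset(new FusedMHARunnerFP16v2());

    float qkv[8] = {}, out[8] = {};
    dispatcher_fp16->run(qkv, out);  // virtual call -> FusedMHARunnerFP16v2::run -> pimpl->run
    return 0;
}
```

Keeping the kernel-launch details behind a pimpl means the attention layer only depends on the small MHARunner interface, while the TensorRT kernel machinery stays isolated in the 3rdparty .cu file.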