I'm opening this issue here to make communication easier, since I need clarification on a few points.
@ggerganov Could you take a quick look at my code in flash-matrix.cu? I've implemented a kernel that parallelizes along the sequence dimension (increasing occupancy and reducing the tail effect) for a query batch size of 1 (inference) and a head_dim of 128. This configuration is particularly hard to speed up when the batch size is very small. I've already implemented (QK^T*scale+mask)V, and it works very well for sequence lengths up to 16K.
However, implementing the softmax is proving very difficult for me; I can't find a way to fit it into this kernel. Since you've worked on building the kernel and understand the FlashAttention-2 paper better, perhaps you could point me in the right direction.
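One way to get the softmax back when the sequence is split across blocks is to apply the same rescaling trick FlashAttention uses inside a block, but across the splits (often called split-KV or Flash-Decoding). Each chunk j keeps three things: its max score m_j, the sum l_j of exp(s_i - m_j), and the unnormalized output o_j = sum of exp(s_i - m_j) * v_i. A final pass then takes M = max_j m_j and computes output = (sum_j exp(m_j - M) * o_j) / (sum_j exp(m_j - M) * l_j). The sketch below is plain host C++ so it can be checked against a naive reference. It is not how flash-matrix.cu is structured, and the names `Partial`, `attend_chunk`, `combine` and `attend_reference` are hypothetical. It assumes one query row, row-major K and V of shape [n_kv x D], and an additive mask whose entries are 0 or -INFINITY.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical per-chunk partial result (one thread block's output in a split-KV kernel).
struct Partial {
    float m;               // max score seen in this chunk (-INFINITY if every key was masked)
    float l;               // sum of exp(s_i - m) over the chunk
    std::vector<float> o;  // unnormalized sum of exp(s_i - m) * v_i, length D
};

// Phase 1: online softmax over keys [begin, end) for a single query row.
Partial attend_chunk(const std::vector<float>& q, const std::vector<float>& K,
                     const std::vector<float>& V, const std::vector<float>& mask,
                     int D, int begin, int end, float scale) {
    Partial p{-INFINITY, 0.0f, std::vector<float>(D, 0.0f)};
    for (int i = begin; i < end; ++i) {
        float s = 0.0f;
        for (int d = 0; d < D; ++d) s += q[d] * K[i * D + d];
        s = s * scale + mask[i];
        if (s == -INFINITY) continue;        // masked key contributes nothing
        float m_new = std::max(p.m, s);
        float corr = std::exp(p.m - m_new);  // rescales old accumulators; 0 on the first key
        float w = std::exp(s - m_new);
        p.l = p.l * corr + w;
        for (int d = 0; d < D; ++d) p.o[d] = p.o[d] * corr + w * V[i * D + d];
        p.m = m_new;
    }
    return p;
}

// Phase 2: merge the partials into softmax(QK^T*scale+mask)V for this row.
std::vector<float> combine(const std::vector<Partial>& parts, int D) {
    float M = -INFINITY;
    for (const auto& p : parts) M = std::max(M, p.m);
    std::vector<float> out(D, 0.0f);
    if (M == -INFINITY) return out;       // every key masked: return zeros by convention
    float L = 0.0f;
    for (const auto& p : parts) {
        if (p.m == -INFINITY) continue;   // skip fully masked chunks (avoids inf - inf)
        float w = std::exp(p.m - M);      // bring this chunk onto the global max
        L += w * p.l;
        for (int d = 0; d < D; ++d) out[d] += w * p.o[d];
    }
    assert(L > 0.0f);                     // the chunk holding M contributes at least 1
    for (int d = 0; d < D; ++d) out[d] /= L;
    return out;
}

// Naive two-pass reference; assumes at least one unmasked key.
std::vector<float> attend_reference(const std::vector<float>& q, const std::vector<float>& K,
                                    const std::vector<float>& V, const std::vector<float>& mask,
                                    int D, int n_kv, float scale) {
    std::vector<float> s(n_kv);
    float m = -INFINITY;
    for (int i = 0; i < n_kv; ++i) {
        float acc = 0.0f;
        for (int d = 0; d < D; ++d) acc += q[d] * K[i * D + d];
        s[i] = acc * scale + mask[i];
        m = std::max(m, s[i]);
    }
    std::vector<float> out(D, 0.0f);
    float l = 0.0f;
    for (int i = 0; i < n_kv; ++i) {
        float w = std::exp(s[i] - m);  // exp(-inf) = 0 for masked keys
        l += w;
        for (int d = 0; d < D; ++d) out[d] += w * V[i * D + d];
    }
    for (int d = 0; d < D; ++d) out[d] /= l;
    return out;
}
```

To use it, call `attend_chunk` once per chunk of keys (e.g. chunks of 256) and pass the resulting vector of partials to `combine`. On the GPU, each chunk would presumably be one thread block that writes (m_j, l_j, o_j) to global memory, with the merge done by a small second kernel; the chunk size only affects parallelism, not the result.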