Implementing online-softmax #1

@FSSRepo

I'm opening this issue in this project to make communication easier, since I need clarification on a few points.

@ggerganov Could you take a quick look at my code in flash-matrix.cu? I've implemented a kernel that parallelizes along the sequence dimension (increasing occupancy and reducing the tail effect) for a query batch size of 1 (inference) and a head_dim of 128, a case that is particularly hard to speed up at very small batch sizes. I've already implemented (QK^T*scale+mask)V, and it works very well for sequence lengths up to 16K.
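For reference, here is one way to write down why the softmax-free version splits cleanly along the sequence. The notation is introduced here for illustration only: a single query $q$, keys $k_i$ and values $v_i$ for $i = 1, \dots, N$, an additive mask, and the sequence divided into contiguous chunks $\mathcal{I}_1, \dots, \mathcal{I}_C$ (one per thread block).

```latex
\begin{aligned}
s_i &= \text{scale}\,(q \cdot k_i) + \text{mask}_i, \qquad i = 1, \dots, N \\
(QK^\top \cdot \text{scale} + \text{mask})\,V &= \sum_{i=1}^{N} s_i\, v_i
  = \sum_{c=1}^{C} \sum_{i \in \mathcal{I}_c} s_i\, v_i \\
\operatorname{softmax}(s)_i &= \frac{e^{s_i - m}}{\sum_{j=1}^{N} e^{s_j - m}},
  \qquad m = \max_{j} s_j
\end{aligned}
```

Without softmax, each chunk's inner sum is independent and the partial vectors simply add. With softmax, every weight depends on the global max $m$ and the global denominator, neither of which a single chunk knows; that missing piece is what online softmax supplies.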

However, implementing the softmax is proving very difficult for me (I can't seem to find a way to do it). Since you have worked on building the kernel and have a better understanding of the FlashAttention-2 (FA 2.0) paper, perhaps you could guide me a bit.
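A common way to make softmax fit a sequence-parallel split is the online-softmax trick used by FlashAttention: each chunk keeps a running max `m`, a running normalizer `l`, and an unnormalized accumulator, rescaling both by `exp(m_old - m_new)` whenever a larger score appears. Chunks are then merged with the same rescaling, and the division by `l` happens once at the end. Below is a minimal, single-threaded C++ sketch of that math, assuming one query (batch size 1), row-major `K`/`V`, and an additive mask that uses `-inf` for masked keys. The names `Partial`, `attend_chunk`, `merge`, and `attend_split` are hypothetical and not taken from flash-matrix.cu; this shows what a CUDA kernel would need to compute, not a tuned kernel.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Hypothetical names for illustration; not taken from flash-matrix.cu.
constexpr float kNegInf = -std::numeric_limits<float>::infinity();

// Running softmax state for one chunk of the KV sequence.
struct Partial {
    float m;                 // max score seen so far
    float l;                 // sum of exp(s_i - m) over the chunk
    std::vector<float> acc;  // sum of exp(s_i - m) * v_i (length d), not yet normalized
};

// One query against keys/values [begin, end): q has d floats, K and V are
// row-major N x d, mask holds N additive values (kNegInf = masked out).
Partial attend_chunk(const std::vector<float>& q, const std::vector<float>& K,
                     const std::vector<float>& V, const std::vector<float>& mask,
                     int d, int begin, int end, float scale) {
    Partial p{kNegInf, 0.0f, std::vector<float>(d, 0.0f)};
    for (int i = begin; i < end; ++i) {
        float s = 0.0f;
        for (int j = 0; j < d; ++j) s += q[j] * K[i * d + j];
        s = s * scale + mask[i];
        if (s == kNegInf) continue;                // masked key contributes nothing
        const float m_new = std::max(p.m, s);
        const float corr = std::exp(p.m - m_new);  // rescales old state (0 on the first key)
        const float w = std::exp(s - m_new);       // weight of the new key
        p.l = p.l * corr + w;
        for (int j = 0; j < d; ++j) p.acc[j] = p.acc[j] * corr + w * V[i * d + j];
        p.m = m_new;
    }
    return p;
}

// Combine two chunk results (e.g. produced by different thread blocks).
Partial merge(const Partial& a, const Partial& b) {
    if (a.m == kNegInf) return b;  // a saw no unmasked keys
    if (b.m == kNegInf) return a;
    const float m = std::max(a.m, b.m);
    const float ca = std::exp(a.m - m), cb = std::exp(b.m - m);
    Partial r{m, a.l * ca + b.l * cb, std::vector<float>(a.acc.size())};
    for (size_t j = 0; j < r.acc.size(); ++j) r.acc[j] = a.acc[j] * ca + b.acc[j] * cb;
    return r;
}

// Split the sequence into n_chunks pieces, merge the partial states,
// then normalize once at the end.
std::vector<float> attend_split(const std::vector<float>& q, const std::vector<float>& K,
                                const std::vector<float>& V, const std::vector<float>& mask,
                                int N, int d, int n_chunks, float scale) {
    const int chunk = (N + n_chunks - 1) / n_chunks;
    Partial total{kNegInf, 0.0f, std::vector<float>(d, 0.0f)};
    for (int c = 0; c < n_chunks; ++c) {
        const int begin = c * chunk, end = std::min(N, begin + chunk);
        if (begin < end) total = merge(total, attend_chunk(q, K, V, mask, d, begin, end, scale));
    }
    std::vector<float> o(d, 0.0f);
    if (total.l > 0.0f)
        for (int j = 0; j < d; ++j) o[j] = total.acc[j] / total.l;  // single final division
    return o;
}
```

In a sequence-parallel kernel, each block would run something like `attend_chunk` over its slice and write `(m, l, acc)` to global memory; a small follow-up reduction would then apply `merge` and divide by `l`. Calling `attend_split` with different `n_chunks` should give the same output up to rounding, which is a convenient correctness check for a GPU version.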
