The current implementation of the DISPATCH_HEAD_DIM macro in libflashinfer/include/flashinfer/attention/generic/dispatch.cuh handles only the following head_dim values (64, 128, 256, and 512):
// convert head_dim to compile-time constant
#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...)     \
  switch (head_dim) {                                  \
    case 64: {                                         \
      constexpr size_t HEAD_DIM = 64;                  \
      __VA_ARGS__                                      \
      break;                                           \
    }                                                  \
    case 128: {                                        \
      constexpr size_t HEAD_DIM = 128;                 \
      __VA_ARGS__                                      \
      break;                                           \
    }                                                  \
    case 256: {                                        \
      constexpr size_t HEAD_DIM = 256;                 \
      __VA_ARGS__                                      \
      break;                                           \
    }                                                  \
    case 512: {                                        \
      constexpr size_t HEAD_DIM = 512;                 \
      __VA_ARGS__                                      \
      break;                                           \
    }                                                  \
    default: {                                         \
      std::ostringstream err_msg;                      \
      err_msg << "Unsupported head_dim: " << head_dim; \
      FLASHINFER_ERROR(err_msg.str());                 \
    }                                                  \
  }
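This limitation currently makes it impossible to support other head_dim values such as 192: anything other than 64, 128, 256, or 512 falls through to the default case and fails with "Unsupported head_dim". We need to look at how the upstream flashinfer library handles this and add the missing support.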
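To make the request concrete, here is a minimal, standalone sketch of what an extended dispatch could look like. It is not taken from upstream flashinfer, and it has not been checked against how upstream handles this. The names SKETCH_HEAD_DIM_CASE, DISPATCH_HEAD_DIM_SKETCH, compiled_head_dim and dispatch_head_dim are hypothetical, and std::invalid_argument stands in for FLASHINFER_ERROR so that the snippet compiles on its own.

```cpp
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <stdexcept>

// Hypothetical helper: emits one switch case that binds HEAD_DIM to a
// compile-time constant and runs the caller-supplied body.
#define SKETCH_HEAD_DIM_CASE(value, HEAD_DIM, ...) \
  case value: {                                    \
    constexpr std::size_t HEAD_DIM = value;        \
    __VA_ARGS__                                    \
    break;                                         \
  }

// Hypothetical variant of DISPATCH_HEAD_DIM with an extra 192 case.
// std::invalid_argument replaces FLASHINFER_ERROR (assumption, to keep
// the sketch free of flashinfer headers).
#define DISPATCH_HEAD_DIM_SKETCH(head_dim, HEAD_DIM, ...) \
  switch (head_dim) {                                     \
    SKETCH_HEAD_DIM_CASE(64, HEAD_DIM, __VA_ARGS__)       \
    SKETCH_HEAD_DIM_CASE(128, HEAD_DIM, __VA_ARGS__)      \
    SKETCH_HEAD_DIM_CASE(192, HEAD_DIM, __VA_ARGS__)      \
    SKETCH_HEAD_DIM_CASE(256, HEAD_DIM, __VA_ARGS__)      \
    SKETCH_HEAD_DIM_CASE(512, HEAD_DIM, __VA_ARGS__)      \
    default: {                                            \
      std::ostringstream err_msg;                         \
      err_msg << "Unsupported head_dim: " << head_dim;    \
      throw std::invalid_argument(err_msg.str());         \
    }                                                     \
  }

// Stand-in for a kernel launcher templated on the head dimension.
template <std::size_t D>
std::size_t compiled_head_dim() {
  return D;
}

// Maps a runtime head_dim to the matching template instantiation.
std::size_t dispatch_head_dim(std::uint32_t head_dim) {
  std::size_t result = 0;
  DISPATCH_HEAD_DIM_SKETCH(head_dim, HEAD_DIM, {
    result = compiled_head_dim<HEAD_DIM>();
  });
  return result;
}
```

With this sketch, dispatch_head_dim(192) selects the 192 instantiation, while an unlisted value such as 96 still raises the "Unsupported head_dim" error. Adding the case only makes the dispatch possible; whether the attention kernels themselves instantiate and run correctly with HEAD_DIM = 192 still needs to be verified separately.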