diff --git a/gemma/configs.cc b/gemma/configs.cc index 000e2786..29ba3e94 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include "compression/types.h" // Type @@ -678,7 +679,7 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) { return Model::GEMMA3_270M; case 26: - if (layer_types & (kDeducedViT|kDeducedKqNorm)) { + if (layer_types & (kDeducedViT | kDeducedKqNorm)) { return Model::GEMMA3_1B; } return Model::GEMMA2_2B; @@ -712,22 +713,28 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) { } } -// Keep in sync with enum class AttentionImpl. -const char* kAttentionImplNames[] = { - "old", "flash", - "unknown" // keep last +constexpr std::pair kAttentionImplNameToEnum[] = { + {"old", AttentionImpl::kOld}, + {"flash", AttentionImpl::kFlash}, + {"flash_transposed_qs", AttentionImpl::kFlashTransposedQs}, + {"flash_transposed_qs_bf16", AttentionImpl::kFlashTransposedQsBF16}, }; std::string GetAttentionImplName(AttentionImpl impl) { - return kAttentionImplNames[static_cast(impl)]; + for (auto const& [name, attention_impl] : kAttentionImplNameToEnum) { + if (attention_impl == impl) return std::string(name); + } + return "unknown"; } -AttentionImpl GetAttentionImpl(const std::string& impl) { - if (impl == GetAttentionImplName(AttentionImpl::kOld)) - return AttentionImpl::kOld; - if (impl == GetAttentionImplName(AttentionImpl::kFlash)) - return AttentionImpl::kFlash; - HWY_WARN("Unknown attention implementation: %s. Using kOld.\n", impl.c_str()); +AttentionImpl GetAttentionImpl(const std::string& impl_name) { + for (auto const& [name, attention_impl] : kAttentionImplNameToEnum) { + if (name == impl_name) { + return attention_impl; + } + } + HWY_WARN("Unknown attention implementation: %s. Using kOld.\n", + impl_name.c_str()); return AttentionImpl::kOld; } diff --git a/gemma/configs_test.cc b/gemma/configs_test.cc index 0ca4a848..e6f02579 100644 --- a/gemma/configs_test.cc +++ b/gemma/configs_test.cc @@ -41,4 +41,16 @@ TEST(ConfigsTest, TestAll) { }); } +TEST(ConfigsTest, TestAttentionImpl) { + for (int i = 0; i < static_cast(AttentionImpl::kSentinel); ++i) { + AttentionImpl impl = static_cast(i); + std::string name = GetAttentionImplName(impl); + ASSERT_NE(name, "unknown"); + ASSERT_EQ(GetAttentionImpl(name), impl); + } + ASSERT_EQ(GetAttentionImplName(AttentionImpl::kSentinel), "unknown"); + ASSERT_EQ(GetAttentionImpl("unknown"), AttentionImpl::kOld); + ASSERT_EQ(GetAttentionImpl("invalid"), AttentionImpl::kOld); +} + } // namespace gcpp