From ca3574dc62aa468bd22c3aca2c85aee1451bf489 Mon Sep 17 00:00:00 2001 From: Viktor Shipitsin Date: Wed, 25 Feb 2026 09:35:01 -0800 Subject: [PATCH] Use a struct to manage the mapping between `AttentionImpl` enum values and their string names, simplifying `GetAttentionImplName` function. Add a test to ensure all valid `AttentionImpl` enums have a corresponding name and can be looked up. PiperOrigin-RevId: 875204584 --- gemma/configs.cc | 31 +++++++++++++++++++------------ gemma/configs_test.cc | 12 ++++++++++++ 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/gemma/configs.cc b/gemma/configs.cc index 000e2786..29ba3e94 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include "compression/types.h" // Type @@ -678,7 +679,7 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) { return Model::GEMMA3_270M; case 26: - if (layer_types & (kDeducedViT|kDeducedKqNorm)) { + if (layer_types & (kDeducedViT | kDeducedKqNorm)) { return Model::GEMMA3_1B; } return Model::GEMMA2_2B; @@ -712,22 +713,28 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) { } } -// Keep in sync with enum class AttentionImpl. -const char* kAttentionImplNames[] = { - "old", "flash", - "unknown" // keep last +constexpr std::pair kAttentionImplNameToEnum[] = { + {"old", AttentionImpl::kOld}, + {"flash", AttentionImpl::kFlash}, + {"flash_transposed_qs", AttentionImpl::kFlashTransposedQs}, + {"flash_transposed_qs_bf16", AttentionImpl::kFlashTransposedQsBF16}, }; std::string GetAttentionImplName(AttentionImpl impl) { - return kAttentionImplNames[static_cast(impl)]; + for (auto const& [name, attention_impl] : kAttentionImplNameToEnum) { + if (attention_impl == impl) return std::string(name); + } + return "unknown"; } -AttentionImpl GetAttentionImpl(const std::string& impl) { - if (impl == GetAttentionImplName(AttentionImpl::kOld)) - return AttentionImpl::kOld; - if (impl == GetAttentionImplName(AttentionImpl::kFlash)) - return AttentionImpl::kFlash; - HWY_WARN("Unknown attention implementation: %s. Using kOld.\n", impl.c_str()); +AttentionImpl GetAttentionImpl(const std::string& impl_name) { + for (auto const& [name, attention_impl] : kAttentionImplNameToEnum) { + if (name == impl_name) { + return attention_impl; + } + } + HWY_WARN("Unknown attention implementation: %s. Using kOld.\n", + impl_name.c_str()); return AttentionImpl::kOld; } diff --git a/gemma/configs_test.cc b/gemma/configs_test.cc index 0ca4a848..e6f02579 100644 --- a/gemma/configs_test.cc +++ b/gemma/configs_test.cc @@ -41,4 +41,16 @@ TEST(ConfigsTest, TestAll) { }); } +TEST(ConfigsTest, TestAttentionImpl) { + for (int i = 0; i < static_cast(AttentionImpl::kSentinel); ++i) { + AttentionImpl impl = static_cast(i); + std::string name = GetAttentionImplName(impl); + ASSERT_NE(name, "unknown"); + ASSERT_EQ(GetAttentionImpl(name), impl); + } + ASSERT_EQ(GetAttentionImplName(AttentionImpl::kSentinel), "unknown"); + ASSERT_EQ(GetAttentionImpl("unknown"), AttentionImpl::kOld); + ASSERT_EQ(GetAttentionImpl("invalid"), AttentionImpl::kOld); +} + } // namespace gcpp